tensorflow转pytorch笔记;tf.gather_nd(x,y)转pytorch

这篇博客详细记录了将TensorFlow操作转换为PyTorch代码的过程,包括transpose、expand_dims、concat、tile、range、reduce_sum、clip_by_value、multinomial、equal、embedding_lookup和one_hot等函数的对应转换。还提供了一个自定义的gather_nd转换实现。内容对于熟悉TensorFlow并想过渡到PyTorch的开发者非常实用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

记录了将tensorflow转pytorch时,一些常用的函数转换:

不能直接转换

  1. tf.transpose(input,[1, 0, 2]) -> input.permute([1, 0, 2]) 不能直接换成torch.transpose,因为操作不了多维
  2. tf.expand_dims(input), axis=1)->input.unsqueeze(1)
  3. tf.concat([content1,content2], axis=1->torch.cat((content1,content2), dim=1) 记得把axis换成dim
  4. tf.tile(input, [2, 1])-> input.repeat([2, 1])
  5. tf.range(10)->torch.arange(0)
  6. tf.reduce_sum(x, axis=1, keep_dims=True)-> torch.sum(x,dim=1,keepdim=True)
  7. tf.clip_by_value(x, min, max)->torch.clamp(x, min, max)
  8. tf.multinomial(logits=a, num_samples=1)->torch.multinomial(input=a, num_samples=1, replacement=False)
  9. tf.equal(x, y)->torch.eq(x, y)
  10. tf.nn.embedding_lookup(W_fe, Feature_input + 1)-> torch.index_select(W_fe, 0, Feature_input + 1)
  11. tf.one_hot()->functional.one_hot()

tf.gather_nd(x,y)转换

参考文章

    def gather_nd(self,params, indices):
        ''' 4D example params: tensor shaped [n_1, n_2, n_3, n_4] --> 4 dimensional indices: tensor shaped [m_1, m_2, m_3, m_4, 4] --> multidimensional list of 4D indices returns: tensor shaped [m_1, m_2, m_3, m_4] ND_example params: tensor shaped [n_1, ..., n_p] --> d-dimensional tensor indices: tensor shaped [m_1, ..., m_i, d] --> multidimensional list of d-dimensional indices returns: tensor shaped [m_1, ..., m_1] '''
        out_shape = indices.shape[:-1]
        indices = indices.unsqueeze(0).transpose(0, -1) # roll last axis to fring
        ndim = indices.shape[0]
        indices = indices.long()
        idx = torch.zeros_like(indices[0], device=indices.device).long()
        m = 1
        for i in range(ndim)[::-1]:
            idx += indices[i] * m
            m *= params.size(i)
        out = torch.take(params, idx)
        return out.view(out_shape)

可以直接转换

  1. tf.reshape()->torch.reshape()
  2. tf.log()
  3. tf.squeeze
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值