OK,果然容易忘记。
好记性不如烂笔头啊好记性不如烂笔头啊好记性不如烂笔头啊
1.tensorflow中get_shape()改为Pytorch对应的函数
①tensor.get_shape()本身获取tensor的维度信息并以元组的形式返回,由于元组内容不可更改,故该函数常常跟.as_list()连用,返回一个tensor维度信息的列表,以供后续操作使用。
②Numpy 里,V.shape 得到整型元组类型的 V 的维度。
③ PyTorch 里,V.size() 得到对象
torch.Size 对象为 Python 列表:
2.tf.concat()在pytorch对应
在pytorch中,常见的拼接函数主要是两个,分别是:
torch.stack() 函数stack()对序列数据内部的张量进行扩维拼接,指定维度由程序员选择、大小是生成后数据的维度区间。通常为了保留–[序列(先后)信息] 和 [张量的矩阵信息] 才会使用stack。
torch.cat()
3.tf.expand_dims()对应
unsqueeze()函数起升维的作用,参数表示在哪个地方加一个维度。。
squeeze:删减维度 与unsqueeze作用刚好相反。
4.tf.tile()
tile(a1,a2,…,an)
通过复制原张量中的第i轴的向量并填充将第i轴向量个数扩大为原本的ai倍
tf.tile和torch.repeat的使用是一样的. 在翻译项目的时候,直接替换即可.
5.tf.matmul()对应
①torch.mm() 当torch.mm用于大于二