pytorch与tensorflow API速查表
| 方法名称 | pytroch | tensorflow | numpy |
|---|---|---|---|
| 裁剪 | torch.clamp(x, min, max) | tf.clip_by_value(x, min, max) | np.clip(x, min, max) |
| 取最大值 | torch.max(x, dim)[0] | tf.max(x, axis) | np.max(x, axis) |
| 取最小值 | torch.min(x, dim)[0] | tf.min(x, axis) | np.min(x , axis) |
| 取两个tensor的最大值 | torch.max(x, y) | tf.maximum(x, y) | np.maximum(x, y) |
| 取两个tensor的最小值 | torch.min(x, y) | torch.minimum(x, y) | np.minmum(x, y) |
| 取最大值索引 | torch.max(x, dim)[1] | tf.argmax(x, axis) | np.argmax(x, axis) |
| 取最小值索引 | torch.min(x, dim)[1] | tf.argmin(x, axis) | np.argmin(x, axis) |
| 比较(x > y) | torch.gt(x, y) | tf.greater(x, y) | np.greater(x, y) |
| 比较(x < y) | torch.le(x, y) | tf.less(x, y) | np.less(x, y) |
| 比较(x==y) | torch.eq(x, y) | tf.equal(x, y) | np.equal(x, y) |
| 比较(x!=y) | torch.ne(x, y) | tf.not_equal(x, y) | np.not_queal(x , y) |
| 取符合条件值的索引 | torch.nonzero(cond) | tf.where(cond) | np.where(cond) |
| 多个tensor聚合 | torch.cat([x, y], dim) | tf.concat([x,y], axis) | np.concatenate([x,y], axis) |
| 堆叠成一个tensor | torch.stack([x1, x2], dim) | tf.stack([x1, x2], axis) | np.stack([x, y], axis) |
| tensor切成多个tensor | torch.split(x1, split_size_or_sections, dim) | tf.split(x1, num_or_size_splits, axis) | np.split(x1, indices_or_sections, axis) |
| – | torch.unbind(x1, dim) | tf.unstack(x1,axis) | NULL |
| 随机扰乱 | torch.randperm(n)1 | tf.random_shuffle(x) | np.random.shuffle(x)2 np.random.permutation(x)3 |
| 前k个值 | torch.topk(x, n, sorted, dim) | tf.nn.top_k(x, n, sorted) | NULL |
本文提供了PyTorch与TensorFlow常用API的对比表格,涵盖了裁剪、取最大值、比较等操作,帮助读者快速掌握两种深度学习框架间的转换。
7531

被折叠的 条评论
为什么被折叠?



