MxNet系列——how_to——torch

本文介绍如何将MXNet用作Torch的前后端,包括使用MXNet.NDArray调用Torch张量数学函数及将Torch神经网络模块嵌入MXNet符号图的方法。此外还提供编译支持Torch的MXNet步骤,并给出具体代码示例。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

博客新址: http://blog.xuezhisd.top
邮箱:xuezhisd@126.com


如何将MXNet用作Torch的前后端

本章节描述了如何将MXNet用作Torch的两个主要功能(前端和后端):

  • 使用MXNet.NDArray来调用Torch的张量数学函数。

  • 将Torch的神经网络模块(层)嵌入到MXNet的符号图中。

编译支持Torch的MXNet

  • 参照 官方教程 来安装Torch
    • 如果还没有安装Torch,将配置文件 make/config.mk (Linux) 或 make/osx.mk (Mac) 复制到MXNet根目录中,并命名为 config.mk。取消文件 config.mk 中的两行注释: TORCH_PATH = $(HOME)/torchMXNET_PLUGINS += plugin/torch/torch.mk
    • 此处默认Torch安装在当前用户的主目录下(TORCH_PATH = $(HOME)/torch)。如果Torch没有安装在此目录,将参数 TORCH_PATH 修改成torch的安装目录。
  • 运行命令 make clean && make 来构建可以使用Torch的MXNet。

与张量相关的数学函数

mxnet.th模块支持调用Torch的张量数学函数和mxnet.nd.NDArray一起使用。查看 完整代码

import mxnet as mx
x = mx.th.randn(2, 2, ctx=mx.cpu(0))
print x.asnumpy()
y = mx.th.abs(x)
print y.asnumpy()

x = mx.th.randn(2, 2, ctx=mx.cpu(0))
print x.asnumpy()
mx.th.abs(x, x) # 原地计算
print x.asnumpy()

使用命令 help(mx.th) 获取更多帮助。

现在我们已经支持网页 Torch’s documentation page.上的最常用的函数。如果你发现你需要的函数还没有支持,你可以通过参考已经登记的函数,轻易地将它登记在页面 mxnet_root/plugin/torch/torch_function.cc 上。

Torch 模块 (网络层)

MXNet通过mxnet.symbol.TorchModule模块来支持Torch的神经网络模块。比如,下面的代码定义了一个对MNIST数据库进行分类的3层DNN网络。 (完整代码):

data = mx.symbol.Variable('data')
fc1 = mx.symbol.TorchModule(data_0=data, lua_string='nn.Linear(784, 128)', num_data=1, num_params=2, num_outputs=1, name='fc1')
act1 = mx.symbol.TorchModule(data_0=fc1, lua_string='nn.ReLU(false)', num_data=1, num_params=0, num_outputs=1, name='relu1')
fc2 = mx.symbol.TorchModule(data_0=act1, lua_string='nn.Linear(128, 64)', num_data=1, num_params=2, num_outputs=1, name='fc2')
act2 = mx.symbol.TorchModule(data_0=fc2, lua_string='nn.ReLU(false)', num_data=1, num_params=0, num_outputs=1, name='relu2')
fc3 = mx.symbol.TorchModule(data_0=act2, lua_string='nn.Linear(64, 10)', num_data=1, num_params=2, num_outputs=1, name='fc3')
mlp = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')

下面,分析一下上述代码。首先 data = mx.symbol.Variable('data') 定义一个符号变量作为输入的占位符。然后,fc1 = mx.symbol.TorchModule(data_0=data, lua_string='nn.Linear(784, 128)', num_data=1, num_params=2, num_outputs=1, name='fc1') 将符号变量data传递给Torch的NN模块。如果你想使用Torch的Criterion作为损失函数,只需将最后一行替换成:

logsoftmax = mx.symbol.TorchModule(data_0=fc3, lua_string='nn.LogSoftMax()', num_data=1, num_params=0, num_outputs=1, name='logsoftmax')
# Torch的标签从1开始
label = mx.symbol.Variable('softmax_label') + 1
mlp = mx.symbol.TorchCriterion(data=logsoftmax, label=label, lua_string='nn.ClassNLLCriterion()', name='softmax')

nn模块的输入数据的命名估规则是 data_i,其中 i = 0 … num_data-1lua_string 是一个用来创建模块对象的单行Lua语句;对于Torch的内建模块,形式如nn.module_name(arguments) 所示。如果你要使用自定义模块,将它放在一个.lua脚本中,然后加载它:当你的脚本返回一个torch.nn对象时,使用命令 require 'module_file.lua 加载它;当你的脚本返回一个torch.nn类时,使用 (require 'module_file.lua')() 加载它。

`torch gather` 和 MXNet的`F.gather_nd` 都是用来从张量中选取特定索引元素的功能,但在使用上有一些细微差别。 1. **PyTorch (torch)**: `gather`函数主要用于沿着给定维度`dim`获取指定索引处的元素。它接受两个参数,第一个参数是源张量(`input`),第二个参数是一个长度匹配源张量该维度的整数切片(`index`)。例如,如果你有一个三维张量和一维索引,你可以选择每个索引对应的列。输出的张量将与输入的形状相同,除了指定的维度会变为1。 ```python input = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) index = torch.tensor([0, 1]) output = torch.gather(input, dim=1, index=index) # 输出: tensor([[1, 4], # [5, 8]]) ``` 2. **MXNet (F.gather_nd)**: `F.gather_nd`与`torch`的`gather`类似,但它可以处理更高维度的索引,允许你通过多个索引来取出多维张量的元素。这个函数需要一个数据张量`(data)`、一个形状匹配的数据张量`(indices)`作为索引以及一个轴`(axis)`。这个函数返回的是一个由索引指定元素组成的张量,形状是 `(indices.shape[:-1] + data.shape[axis+1:])`。 ```python import mxnet as mx data = mx.nd.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) indices = mx.nd.array([[0, 1], [1, 0]]) # 二维索引 output = mx.nd.F.gather_nd(data, indices) # 输出: [[1, 3], [6, 5]] ``` **区别总结**: - PyTorch的`gather`更适用于单个或低维度索引的情况。 - MXNet的`F.gather_nd`支持高维度或多维索引,适合提取多位置的元素。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值