#include <torch/torch.h>
// 定义一个继承自 torch::nn::Module 的 MLP 类
struct MLP : torch::nn::Module {
// 构造函数
MLP(int input_size, int output_size) {
// 初始化线性层并注册为模块的参数
fc = register_module("fc", torch::nn::Linear(input_size, output_size));
}
// 前向传播函数
torch::Tensor forward(torch::Tensor x) {
// 通过线性层传递输入张量
x = fc->forward(x);
return x;
}
// 线性层
torch::nn::Linear fc{nullptr};
};
int main() {
// 设置随机数种子以确保结果的可复现性
torch::manual_seed(0);
// 创建一个 MLP 实例,假设输入大小为 10,输出大小为 5
MLP MyModule(10, 5);
// 创建一个随机输入张量
torch::Tensor input = torch::randn({1, 10});
// 前向传播输入张量
torch::Tensor output = MyModule.forward(input);
// 打印输出张量
std::cout << output << std::endl;
// 打印模型的所有参数
for (auto& parameter : MyModule.parameters()) {
std::cout << parameter << std::endl;
}
return 0;
}
改变值
for (int i = 0; i < MyModule.parameters().size(); ++i) {
MyModule.parameters()[i].data() = torch::rand({ 10, 5 });
}
for (int i = 0; i < MyModule.buffers().size(); ++i) {
MyModule.buffers()[i].data() = torch::rand({ 10, 5 });
}
深拷贝如下
MLP MyModule_clone(10, 5);
for (int i = 0; i < MyModule.parameters().size(); ++i) {
MyModule_clone.parameters()[i].data() = MyModule.parameters()[i].clone();
}
for (int i = 0; i < MyModule.buffers().size(); ++i) {
MyModule_clone.buffers()[i].data() = MyModule.buffers()[i].clone();
}