libtorch,torch::nn::Module类改变parameters值和buffers值,以及深拷贝

#include <torch/torch.h>

// 定义一个继承自 torch::nn::Module 的 MLP 类
struct MLP : torch::nn::Module {
    // 构造函数
    MLP(int input_size, int output_size) {
        // 初始化线性层并注册为模块的参数
        fc = register_module("fc", torch::nn::Linear(input_size, output_size));
    }

    // 前向传播函数
    torch::Tensor forward(torch::Tensor x) {
        // 通过线性层传递输入张量
        x = fc->forward(x);
        return x;
    }

    // 线性层
    torch::nn::Linear fc{nullptr};
};

int main() {
    // 设置随机数种子以确保结果的可复现性
    torch::manual_seed(0);

    // 创建一个 MLP 实例,假设输入大小为 10,输出大小为 5
    MLP MyModule(10, 5);

    // 创建一个随机输入张量
    torch::Tensor input = torch::randn({1, 10});

    // 前向传播输入张量
    torch::Tensor output = MyModule.forward(input);

    // 打印输出张量
    std::cout << output << std::endl;

    // 打印模型的所有参数
    for (auto& parameter : MyModule.parameters()) {
        std::cout << parameter << std::endl;
    }

    return 0;
}

改变值

for (int i = 0; i < MyModule.parameters().size(); ++i) {
	MyModule.parameters()[i].data() = torch::rand({ 10, 5 });
}
for (int i = 0; i < MyModule.buffers().size(); ++i) {
	MyModule.buffers()[i].data() = torch::rand({ 10, 5 });
}

深拷贝如下

MLP MyModule_clone(10, 5);
for (int i = 0; i < MyModule.parameters().size(); ++i) {
	MyModule_clone.parameters()[i].data() = MyModule.parameters()[i].clone();
}
for (int i = 0; i < MyModule.buffers().size(); ++i) {
	MyModule_clone.buffers()[i].data() = MyModule.buffers()[i].clone();
}

根据提供的引用内容,torch.nn.modules.module.Module的__getattr__方法用于获取给定name的成员。它首先从self.__dict__['_parameters']、self.__dict__['_buffers']以及self.__dict__['_modules']中查找,找到后返回该成员;如果找不到,则会引发AttributeError异常,报错信息为"'{}' object has no attribute '{}'".format(type(self).__name__, name)。 在这种情况下,错误信息提示是"torch.nn.modules.module.ModuleAttributeError: 'DataParallel' object has no attribute 'get_params'"。这表示DataParallel对象没有名为'get_params'的属性。这个错误通常发生在尝试访问一个不存在的属性时。可能是代码中使用了错误的属性名或者没有正确设置属性。要解决这个错误,你可以检查代码中的属性名是否正确拼写,并确保属性已正确设置。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *3* [in <module>报错_【PyTorchtorch.nn.Module 源码分析](https://blog.youkuaiyun.com/weixin_39851809/article/details/110105784)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *2* [undefined](undefined)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值