torch-1 tensor & optim

本文解析了PyTorch源码的关键部分,介绍了Tensor类的功能和实现细节,以及优化器模块如SGD的工作原理和使用方法。同时探讨了模型保存与加载的方法。

开个新坑, pytorch源码阅读…从python代码开始读起.

torch/

1.tensor.py

继承自torch._C._TensorBase , 包括各种操作,TODO:随后看cpp代码

  • __abs__, __iter__之类的内建方法

  • requires_grad属性是否需要求导

  • backward(self, gradient=None, retain_graph=None, create_graph=False) retain_graph表示是否在backward之后free内存

  • register_hook(self, hook) 每次gradients被计算的时候,这个hook都被调用。返回的handle提供remove hook的能力

  • v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True)
    h = v.register_hook(lambda grad: grad * 2)  # double the gradient
    v.backward(torch.Tensor([1, 1, 1]))
    #先计算原始梯度,再进hook,获得一个新梯度。
    print(v.grad.data) #output [2,2,2]
    h.remove()  # removes the hook, 返回的句柄
    
2.random.py //TODO default_generator
3.serialization.py 模型的load, store等方法

torch/optim

一系列优化方法的集合, 基类是optimizer.py, 其余op都是继承这个类, 基础上实现op.step(), 初始化默认参数由__init__提供. 包括SGD, Adam, RMSProp等, 以SGD为例:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad() #初始化
loss_fn(model(input), target).backward()
optimizer.step()
内部方法

state_dict() & load_state_dict()更新state, param两个成员, 提供serialize的方法. 理解是可以训练到某个过程中进行op参数的存储, 下次可以继续, 避免训练失败重新训练

add_param_group() transfer learning中将freeze固定层的参数加入训练时, 可以用该方法.

lr_scheduler 用来进行lr的调整, 动态decay

scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
for epoch in range(100):
    scheduler.step()
    train(...)
    validate(...)

tips:

  • id(k)获取object的单一标识,作为dict的key.
  • isinstance(obj, class or tuple) 判断obj是否是class的实例

06ef05bc-004d-4561-b9ba-842076c9884b

A中是有空间结构的,当前代码结构为class ReplayBuffer { public: ReplayBuffer(size_t capacity); void push(torch::Tensor state, int action, float reward, torch::Tensor next_state, int done); std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> sample(int batch_size); size_t size() const; private: size_t capacity_; std::deque<torch::Tensor> states_; std::deque<int> actions_; std::deque<float> rewards_; std::deque<torch::Tensor> next_states_; std::deque<bool> dones_; }; class QNet : public torch::nn::Module { public: QNet(int state_dim, int action_dim, int hidden_size = 128); torch::Tensor forward(torch::Tensor x); torch::OrderedDict<std::string, torch::Tensor> state_dict() const; void load_state_dict(const torch::OrderedDict<std::string, torch::Tensor>& state_dict); void save(torch::serialize::OutputArchive& archive) const override { torch::nn::Module::save(archive); archive.write("fc1.weight", fc1->weight); archive.write("fc1.bias", fc1->bias); archive.write("fc2.weight", fc2->weight); archive.write("fc2.bias", fc2->bias); archive.write("fc3.weight", fc3->weight); archive.write("fc3.bias", fc3->bias); } void load(torch::serialize::InputArchive& archive) override { torch::nn::Module::load(archive); archive.read("fc1.weight", fc1->weight); archive.read("fc1.bias", fc1->bias); archive.read("fc2.weight", fc2->weight); archive.read("fc2.bias", fc2->bias); archive.read("fc3.weight", fc3->weight); archive.read("fc3.bias", fc3->bias); } private: torch::nn::Linear fc1{ nullptr }; torch::nn::Linear fc2{ nullptr }; torch::nn::Linear fc3{ nullptr }; }; class DQNAgent { public: DQNAgent(int state_dim, int action_dim, double gamma = 0.99, double lr = 1e-3, int buffer_size = 10000, int batch_size = 64); int select_action(torch::Tensor state, float epsilon); void update(); void save_model(const std::string modelname); void load_model(const std::string modelname); void Buffer_push( ); void set_reward_calculator(); private: int state_dim_; int action_dim_; QNet policy_net; QNet target_net; ReplayBuffer buffer; torch::optim::Adam optimizer; }; 如何设计
最新发布
12-04
<think> 我们面对的问题是如何在DQN中处理具有空间结构的state(例如N*N矩阵A)以及向量B、C、D(N*1)。当前的代码框架中,ReplayBuffer存储的是torch::Tensor,而QNet是一个全连接网络。为了利用A的空间结构,我们需要修改网络架构,使其能够处理矩阵形式的数据(例如使用卷积层)并整合向量数据。 设计思路: 1. 修改state的表示:在存储和输入网络之前,我们需要将state的不同部分组合成一个统一的结构。这里有两种常见方式: a) 将矩阵A和向量B、C、D分别处理,然后合并(多分支网络)。 b) 将向量B、C、D进行变换后与A的某个维度拼接(例如,将向量重复成矩阵,然后作为额外通道),然后使用卷积网络。 2. 修改QNet:由于当前QNet是全连接网络,如果直接展平A(N*N)和B、C、D(N*1)然后拼接,会丢失A的空间结构,且当N较大时输入维度会很大(N*N+3*N)。因此,我们考虑使用卷积网络处理A,然后与处理后的向量特征合并。 具体修改步骤: 步骤1:重新设计state的存储和采样 - 在ReplayBuffer中,我们目前用单独的deque存储state和next_state,每个state是一个torch::Tensor。我们需要确保在push的时候,state是一个包含所有信息的张量,或者我们可以修改为存储多个张量(但这样会修改buffer接口,需要谨慎)。 - 为了最小化改动,我们可以将state设计为一个包含两个部分的元组(或者字典),但当前接口是单个torch::Tensor。因此,我们选择将不同的部分拼接成一个张量,但是这样可能会破坏空间结构。另一种方法是使用多个张量,但这需要修改buffer的结构。 考虑到现有代码中ReplayBuffer的push函数是: void push(torch::Tensor state, int action, float reward, torch::Tensor next_state, int done); 我们有两种选择: 选择1:修改ReplayBuffer,使其能够存储多个张量作为state(例如,将state拆分为A, B, C, D四个张量存储)。但这需要大幅修改buffer和sample函数,以及DQNAgent中与buffer交互的部分。 选择2:在将state存入buffer之前,将其处理成一个张量。例如: 将A(N*N)保持为矩阵,将B、C、D(每个都是N*1)拼接成N*3的矩阵,然后将它们沿着通道维度拼接,形成一个N*N*4的张量(其中前三个通道是B、C、D的扩展?)或者将B、C、D重复成N*N的矩阵,然后与A在通道维拼接(这样得到4个通道,N*N*4)。但是这样会使得矩阵A和向量B、C、D在空间上重复,可能不合理。 选择3(推荐):使用多输入网络,修改ReplayBuffer存储state为多个张量。这需要修改buffer的结构,使其能够存储多个张量(例如,将state存储为std::tuple或自定义结构)。但为了保持代码的通用性,我们可以将state定义为std::vector<torch::Tensor>,但这样会改变接口,需要调整。 由于时间关系,我们采用一种折中方法:在存入buffer之前,将A, B, C, D组合成一个字典或元组,然后存储为torch::IValue?但是这样会复杂化。或者,我们可以在buffer中存储两个张量:一个用于A(空间部分),一个用于向量(B, C, D拼接成的向量)。这样,我们需要修改ReplayBuffer为存储两个state张量(当前状态和下一个状态都有两个部分)。 考虑到修改的复杂性,我们暂时不修改ReplayBuffer,而是将state的不同部分在存入buffer之前合并成一个张量,但这样可能会丢失结构信息。因此,我们采用另一种方式:将A视为图像,将B, C, D作为额外的通道(通过上采样拼接)。具体做法: 假设A是单通道N*N,那么我们可以将B, C, D分别通过插值上采样到N*N(变成三个通道),然后与A拼接,这样得到4个通道(N*N*4)。然后,这个4通道图像可以直接输入到卷积网络中。 步骤: 1. 将B, C, D(每个都是N*1)通过重复操作变成N*N(例如,将列向量在列方向重复N次)。 2. 将A(N*N)扩展为1通道,然后与B、C、D扩展的3个通道拼接,得到4通道的N*N图像。 但是,这样做会使得向量数据在空间上重复,可能会引入冗余,并且要求向量B,C,D的长度与A的边长相同(都是N)。 步骤2:修改QNet,使用卷积层处理空间信息 - 将原来的全连接网络改为卷积网络+全连接网络的结构。 - 首先,输入是4通道的图像(N*N*4),然后经过几个卷积层,然后展平,再与可能存在的其他特征(如果有的话)合并,最后通过全连接层输出Q值。 但是,注意:我们这里没有其他特征了,因为向量部分已经被转换成通道了。 修改后的QNet(假设N=10,即10*10的图像): 我们可以这样设计: 卷积层1:输入通道4,输出通道32,卷积核3x3,步长1,填充1 -> 输出32x10x10 激活函数(如ReLU) 卷积层2:输入32,输出64,卷积核3x3,步长1,填充1 -> 输出64x10x10 激活函数 展平:64*10*10 = 6400 全连接层1:6400 -> 512 激活函数 全连接层2:512 -> action_dim 但是,如果N很大,展平后的维度会很大,这时我们可以在卷积层后加一个全局平均池化(或最大池化)来降维。 另外,如果N是可变的(即不同的样本N可能不同),那么我们不能使用固定大小的全连接层。这时,我们可以使用自适应池化(如自适应平均池化)将特征图统一到固定大小,然后再展平。 步骤3:修改DQNAgent中与state相关的操作 - 在select_action和update中,state的输入形式已经变成了4通道图像。因此,在将state输入网络之前,需要确保它被正确地转换为4通道图像(即按照步骤1的方法预处理)。 步骤4:修改ReplayBuffer中存储的state - 在push之前,我们需要将原始的A(矩阵)和B、C、D(向量)合并成一个4通道的图像张量。同样,next_state也要同样处理。 - 这样,ReplayBuffer中存储的state就是4通道图像张量(形状为[4, N, N]),因此我们在buffer中存储的states_和next_states_都是这种形式。 但是,这里有一个问题:不同的样本可能N不同(比如环境变化导致状态大小变化),那么我们就不能使用固定大小的网络。因此,我们需要确保N是固定的。如果环境中的N是变化的,那么我们需要调整方案(例如,将图像缩放到固定大小,或者使用可以处理可变输入的网络结构,如全卷积网络,并在最后使用全局池化)。 假设我们的环境固定N,那么我们可以按照上述步骤进行。 代码修改示例: 1. 修改ReplayBuffer的push函数:在调用push之前,agent需要将state(由A, B, C, D组成)转换为4通道图像张量。因此,在DQNAgent的Buffer_push函数(或者现在在select_action中获取的state)中,我们需要进行转换。 2. 在QNet中,我们重新定义网络结构: class QNet : public torch::nn::Module { public: QNet(int action_dim, int hidden_size = 128); // 因为输入已经是图像,所以不需要state_dim(但需要知道图像的通道数,这里固定为4) torch::Tensor forward(torch::Tensor x) { // 输入x的形状:[batch, 4, N, N] x = torch::relu(conv1(x)); // [batch, 32, N, N] x = torch::relu(conv2(x)); // [batch, 64, N, N] x = torch::relu(conv3(x)); // 可以再加一层,或者使用池化 // 使用自适应平均池化,将特征图变成固定大小,例如1x1,那么输出为[batch, 64, 1, 1] x = adaptive_pool(x); x = x.view({x.size(0), -1}); // 展平 x = torch::relu(fc1(x)); x = fc2(x); return x; } private: torch::nn::Conv2d conv1, conv2, conv3; torch::nn::Linear fc1, fc2; torch::nn::AdaptiveAvgPool2d adaptive_pool; }; 3. 在DQNAgent的构造函数中,初始化QNet时,不再需要state_dim,因为输入是图像。但是需要知道图像的大小(N)和通道数(4)。我们可以将网络结构设计成与N无关(通过自适应池化),这样即使N变化,网络也能工作(输出固定维度的特征)。 具体地,我们可以在最后一个卷积层后使用自适应平均池化,将特征图池化为1x1,那么无论输入图像大小是多少,卷积后的特征图经过自适应池化后都是1x1,然后通道数是最后一个卷积层的输出通道数,展平后就是通道数,然后输入全连接层。 例如: 自适应池化:torch::nn::AdaptiveAvgPool2d(torch::nn::AdaptiveAvgPool2dOptions({1, 1})) 这样,如果最后一个卷积层输出64个通道,那么池化后就是64*1*1=64,然后通过全连接层。 4. 在将state输入网络之前,确保state已经转换成4通道图像。假设我们有一个函数将原始的A(N*N)和B、C、D(每个都是N*1)转换成4通道图像: torch::Tensor state_to_tensor(const Matrix& A, const Vector& B, const Vector& C, const Vector& D) { // 假设A, B, C, D都是torch::Tensor,且A是2D(N, N),B, C, D是1D(N) // 将B, C, D从1D变成2D(N,1),然后重复成(N, N) auto B_mat = B.unsqueeze(1).repeat({1, N}); // 变成N*N,注意:unsqueeze(1)变成列向量,然后在列方向重复 auto C_mat = C.unsqueeze(1).repeat({1, N}); auto D_mat = D.unsqueeze(1).repeat({1, N}); // 将A扩展为1通道,B_mat, C_mat, D_mat分别作为通道,然后拼接 auto state_tensor = torch::stack({A, B_mat, C_mat, D_mat}, /*dim=*/0); // 形状为[4, N, N] return state_tensor; // 返回的张量是[4, N, N] } 注意:在存储到ReplayBuffer时,我们存储的是这个state_tensor(4通道图像),而不是原始分开的A, B, C, D。 5. 相应地,在环境返回状态时,我们需要调用这个转换函数,将原始状态转换成4通道图像,然后再调用push。 6. 在DQNAgent的select_action函数中,传入的state也是转换后的图像张量,然后直接输入policy_net。 总结:通过将矩阵A和向量B、C、D转换成多通道图像,我们保留了A的空间结构,并且能够使用卷积网络进行特征提取。同时,通过自适应池化,我们可以处理不同尺寸的输入(如果N变化),但通常我们假设N固定。 注意:这种转换可能会引入冗余(因为向量被重复了N次),但这是为了兼容卷积网络而做出的折中。 如果不想引入冗余,我们可以采用双分支网络:一个分支用CNN处理A,另一个分支用全连接处理B,C,D(拼接后的向量),然后将两个分支的特征合并,再通过全连接层输出Q值。这种结构更合理,但需要修改ReplayBuffer存储两个部分(或者将两个部分存储在一个元组中,然后buffer中存储两个state部分)。这需要较大的改动。 考虑到时间,我们这里采用多通道图像的方法。下面给出双分支网络的思路: 双分支网络设计: 步骤1:修改ReplayBuffer,使其存储state为两个张量:A(矩阵)和vec(由B,C,D拼接的向量)。 void push(torch::Tensor A, torch::Tensor vec, int action, ...); // 这样需要修改接口 步骤2:QNet有两个分支: branch1: 输入A(N*N),经过卷积网络后展平(或池化)得到特征向量1 branch2: 输入vec(3N),经过全连接网络得到特征向量2 将两个特征向量拼接,再通过全连接层输出Q值。 这样,我们就避免了向量重复成矩阵的冗余,并且可以分别处理。但是,ReplayBuffer的修改涉及所有使用的地方,包括sample函数返回的batch也要返回两个state部分(A_batch和vec_batch)。 由于当前代码框架中ReplayBuffer的接口是单个state,所以如果要采用双分支,可能需要重构。因此,我们可以在不改变ReplayBuffer的情况下,将两个部分合并成一个张量,但在存储时分开,在取出后再分开?这需要设计一个复合张量。例如,将A和vec存储为一个元组,但ReplayBuffer不支持元组,我们可以将A和vec拼接成一个一维张量,然后在取出后再根据尺寸分割?但这样破坏了A的空间结构(因为展平了A)。 因此,为了利用空间结构,我们可能不得不修改ReplayBuffer以支持多张量存储。由于这是一个较大的改动,这里我们暂时不展开。 结论:根据项目时间和复杂度,选择多通道图像方案(即转换成一个图像)是一种可行的折中。 最后,我们修改QNet为卷积网络,并添加自适应池化以处理不同尺寸(如果需要)。 代码修改: 1. 修改QNet的定义,使用卷积层和自适应池化。 2. 在DQNAgent中,将原始状态转换为多通道图像后再存储和输入网络。 3. 注意:在训练时,从buffer中sample出来的state就是多通道图像,直接输入网络。 注意:如果N是可变的,那么同一个batch内的N必须相同(因为神经网络要求输入张量是规则的)。因此,我们需要确保一个batch内的状态具有相同的N。如果环境中N是变化的,那么我们需要在采样batch时进行padding或者调整大小,但这会增加复杂性。因此,最好固定环境中的N。 如果环境中的N是变化的,我们可以采用以下方法: - 在存储时,将每个状态的N记录下来,然后在采样batch时,将状态图像调整到该batch中最大的N(通过插值),然后使用掩码标记有效区域?这很复杂。 - 或者,我们使用全局池化,使得网络可以处理任意大小的输入(只要卷积层使用得当,比如使用padding='same',然后通过全局池化得到固定长度的特征向量)。这样,我们就可以处理任意大小的输入。 因此,我们在卷积层后使用自适应平均池化到1x1,这样无论输入图像大小是多少,输出都是固定长度(通道数)的特征向量。 例如,最后卷积层的输出通道数为64,那么自适应池化后就是64维向量。 这样,我们就可以处理任意N(但N不能太小,否则卷积层会报错,因为经过多层卷积后特征图尺寸会缩小,但如果我们使用padding='same',可以保持特征图尺寸不变,直到池化)。 修改后的QNet: class QNet : public torch::nn::Module { public: QNet(int action_dim, int hidden_size = 128) { // 输入图像通道数:4 conv1 = register_module("conv1", torch::nn::Conv2d(torch::nn::Conv2dOptions(4, 32, 3).padding(1))); conv2 = register_module("conv2", torch::nn::Conv2d(torch::nn::Conv2dOptions(32, 64, 3).padding(1))); conv3 = register_module("conv3", torch::nn::Conv2d(torch::nn::Conv2dOptions(64, 64, 3).padding(1))); adaptive_pool = register_module("adaptive_pool", torch::nn::AdaptiveAvgPool2d(torch::nn::AdaptiveAvgPool2dOptions({1,1}))); fc1 = register_module("fc1", torch::nn::Linear(64, hidden_size)); fc2 = register_module("fc2", torch::nn::Linear(hidden_size, action_dim)); } torch::Tensor forward(torch::Tensor x) { x = torch::relu(conv1(x)); // [batch, 32, H, W] 其中H,W是输入的高和宽(这里假设输入是正方形) x = torch::relu(conv2(x)); // [batch, 64, H, W] x = torch::relu(conv3(x)); // [batch, 64, H, W] x = adaptive_pool(x); // [batch, 64, 1, 1] x = x.view({x.size(0), -1}); // [batch, 64] x = torch::relu(fc1(x)); // [batch, hidden_size] x = fc2(x); // [batch, action_dim] return x; } private: torch::nn::Conv2d conv1{nullptr}, conv2{nullptr}, conv3{nullptr}; torch::nn::AdaptiveAvgPool2d adaptive_pool{nullptr}; torch::nn::Linear fc1{nullptr}, fc2{nullptr}; }; 注意:这里假设输入图像是正方形(H=W),如果不是,需要调整。 在DQNAgent的构造函数中,初始化QNet时,我们不再需要state_dim,只需要action_dim: policy_net = std::make_shared<QNet>(action_dim); target_net = std::make_shared<QNet>(action_dim); 然后,在将state存储到buffer之前,需要将原始状态(A, B, C, D)转换为一个4通道的图像张量(形状为[4, H, W])。注意:这里H和W等于A的行数和列数(假设为N),并且B, C, D的长度也是N。 转换函数示例(在Agent内部实现一个转换函数): torch::Tensor DQNAgent::convert_state(const Matrix& A, const Vector& B, const Vector& C, const Vector& D) { // A: torch::Tensor of shape [N, N] // B, C, D: torch::Tensor of shape [N] int N = A.size(0); // 将B, C, D扩展成N*N矩阵 auto B_mat = B.unsqueeze(1).repeat({1, N}); // [N, N] auto C_mat = C.unsqueeze(1).repeat({1, N}); auto D_mat = D.unsqueeze(1).repeat({1, N}); // 将A和三个矩阵堆叠成4个通道 // 注意:unsqueeze(0)增加一个通道维,然后沿着通道维拼接 auto state_tensor = torch::stack({A, B_mat, C_mat, D_mat}, 0); // [4, N, N] return state_tensor; } 然后,在环境返回一个状态(由A, B, C, D组成)后,调用这个函数转换为图像,然后调用buffer.push: torch::Tensor state_tensor = convert_state(A, B, C, D); buffer.push(state_tensor, action, reward, next_state_tensor, done); 注意:next_state_tensor也要同样转换。 在select_action中,同样需要将当前状态(A, B, C, D)转换为图像张量,然后输入policy_net。 这样,我们就完成了修改。 注意:在训练时,从buffer中sample出来的state_batch的形状为[batch_size, 4, N, N](但注意,在buffer中存储的state_tensor是[4, N, N],在sample时,我们将多个state_tensortorch::stack成[batch_size, 4, N, N])。 但是,ReplayBuffer的sample函数需要修改吗?目前,sample返回的states是torch::Tensor,形状为[batch_size, ...],其中...就是每个state的形状。因为我们存储的是[4, N, N],所以sample返回的states就是[batch_size, 4, N, N]。所以不需要修改sample函数,只要我们在push时存储的是单个张量(形状为[4, N, N]),那么sample返回的就会是正确形状。 但是,注意:在ReplayBuffer中,我们存储的是deque<torch::Tensor> states_,每个元素是[4, N, N]。在sample时,我们使用torch::stack(states)会得到一个[batch_size, 4, N, N]的张量,这正是我们需要的。 因此,ReplayBuffer的sample函数不需要修改。 最后,注意:如果不同的状态有不同的N(即N不是固定的),那么在一个batch中,N必须相同。否则无法stack。因此,我们需要确保环境中的状态都是相同的N,或者在存储时固定N(例如,通过裁剪或缩放)。如果N是变化的,那么我们需要在采样batch后,将图像调整到统一大小(比如使用插值),但这样会增加计算量。因此,最好在环境设计时固定N。 以上是针对有空间结构的state的一种改进方案。 如果时间允许,推荐使用双分支网络,但这里我们暂时采用多通道图像方案。 根据上述设计,我们修改代码。 注意:示例代码中,ReplayBuffer的push函数参数是torch::Tensor state和torch::Tensor next_state,我们并没有改变,所以只需要在调用push之前将state转换好即可。 因此,我们只需要修改QNet和DQNAgent的状态转换部分,以及QNet的网络结构。 由于代码量较大,这里仅给出关键修改点。 如果后续有性能问题,可以考虑双分支网络,但当前方案已经能够利用空间结构。 最后,我们提出相关问题: </think>### 改进方案:针对空间结构state的DQN设计 当state包含空间结构矩阵(A)和向量(B、C、D)时,需要修改网络架构和处理流程。以下是基于您代码结构的改进方案: #### 1. **状态预处理模块** ```cpp // 状态转换工具函数 torch::Tensor state_to_tensor(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& C, const torch::Tensor& D) { // 处理空间结构矩阵A (N×N) auto A_processed = A.unsqueeze(0); // 增加通道维度 → [1, N, N] // 处理向量B,C,D (N×1) auto vec = torch::cat({B, C, D}, /*dim=*/0); // 拼接 → [3N] auto vec_repeated = vec.repeat({N, N}).view({N, N, 3}); // → [N, N, 3] auto vec_processed = vec_repeated.permute({2, 0, 1}); // → [3, N, N] // 合并通道 → [4, N, N] return torch::cat({A_processed, vec_processed}, /*dim=*/0); } ``` #### 2. **改进的QNet架构(CNN+FC混合网络)** ```cpp class QNet : public torch::nn::Module { public: QNet(int action_dim, int conv_out_size = 64) { // 卷积分支处理空间特征 conv_branch = torch::nn::Sequential( torch::nn::Conv2d(torch::nn::Conv2dOptions(4, 32, 3).padding(1)), torch::nn::ReLU(), torch::nn::MaxPool2d(2), torch::nn::Conv2d(torch::nn::Conv2dOptions(32, 64, 3).padding(1)), torch::nn::ReLU(), torch::nn::AdaptiveAvgPool2d(1) // 自适应池化 → [64, 1, 1] ); // 全连接分支 fc_branch = torch::nn::Sequential( torch::nn::Linear(64, 128), torch::nn::ReLU(), torch::nn::Linear(128, action_dim) ); } torch::Tensor forward(torch::Tensor x) { // x: [batch, 4, N, N] auto features = conv_branch->forward(x); // → [batch, 64, 1, 1] features = features.view({features.size(0), -1}); // 展平 → [batch, 64] return fc_branch->forward(features); // → [batch, action_dim] } private: torch::nn::Sequential conv_branch{nullptr}; torch::nn::Sequential fc_branch{nullptr}; }; ``` #### 3. **DQNAgent关键修改点** ```cpp class DQNAgent { public: // 构造函数需删除state_dim参数 DQNAgent(int action_dim, double gamma=0.99, ...) : policy_net(action_dim), target_net(action_dim) {...} int select_action(torch::Tensor raw_state, float epsilon) { // 原始状态转换为张量 auto state_tensor = state_to_tensor( raw_state["A"], raw_state["B"], raw_state["C"], raw_state["D"] ).unsqueeze(0); // 增加batch维度 // 输入网络处理 auto q_values = policy_net->forward(state_tensor); // ...后续动作选择逻辑 } void Buffer_push(EnvState state) { auto state_tensor = state_to_tensor(state.A, state.B, state.C, state.D); auto next_state_tensor = state_to_tensor(next_state.A, ...); buffer.push(state_tensor, action, reward, next_state_tensor, done); } }; ``` #### 4. **ReplayBuffer调整** ```cpp class ReplayBuffer { public: // 存储转换后的张量 void push(torch::Tensor processed_state, ...) { states_.push_back(processed_state); // ...其他部分不变 } // 采样时直接返回处理好的张量 auto sample(int batch_size) { auto states = torch::stack(std::vector<torch::Tensor>( states_.begin(), states_.begin()+batch_size)); // ...其他部分不变 } }; ``` ### 关键改进说明: 1. **通道化处理**:将空间矩阵作为第1通道,三个向量拼接后扩展为3个通道 2. **自适应池化**:`nn::AdaptiveAvgPool2d`处理不同尺寸输入 3. **端到端处理**: - 环境输出原始状态 → Agent转换 → 存入Buffer - 采样后直接输入网络,无需额外预处理 4. **网络架构**: - 前段CNN提取空间特征 - 后段FC输出动作价值 - 全局池化保证输出维度固定 ### 性能优化建议: 1. **批处理加速**:在`state_to_tensor`中使用向量化操作 2. **缓存机制**:对环境输出状态进行缓存,避免重复计算 3. **量化压缩**:对存储的状态张量使用`torch::quantize_per_tensor`
评论 2
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值