20250301_代码笔记_函数class CVRPEnv: def step(self, selected)


前言

细读函数class CVRPEnv: def step(self, selected)

/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPEnv.py

注:selectedstep()函数的输入参数。

优化版本:20250306-笔记-精读class CVRPEnv:step(self, selected)


1. 时间步数控制(time_step < 4)

if self.time_step < 4:
    self.time_step = self.time_step + 1
    self.selected_count = self.selected_count + 1
    self.at_the_depot = (selected == 0)

功能:

  • 判断当前的time_step是否小于4。如果是,则执行该时间段的逻辑。
    • self.time_step增加1,表示进入下一个时间步。
    • self.selected_count表示已经选择的节点数量,递增。
    • self.at_the_depot记录当前的节点是否为配送中心(depot)。如果selected == 0(即选择了配送中心),则设置为True

2. 时间步数3的特定操作

if self.time_step == 3:
    self.last_current_node = self.current_node.clone()
    self.last_load = self.load.clone()

功能:

  • time_step == 3时,记录当前节点和负载的状态,用于后续的操作。
    • self.last_current_nodeself.last_load保存当前状态,方便后续在步骤4时进行比较或更新。

3. 时间步数4的特定操作

if self.time_step == 4:
    self.last_current_node = self.current_node.clone()
    self.last_load = self.load.clone()
    self.visited_ninf_flag[:, :, self.problem_size+1][(~self.at_the_depot) & (self.last_current_node != 0)] = 0

功能:

  • time_step == 4时,保存当前节点和负载的状态。
    • 进一步处理visited_ninf_flag,标记已访问的节点,特别是那些不在配送中心并且之前没有访问过的节点。

4. 更新当前节点和已选择节点列表

self.current_node = selected
self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)

功能:
selected(当前选择的节点)设置为self.current_node
更新self.selected_node_list,将当前选择的节点添加到已选择的节点列表中。


5. 更新负载

demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
gathering_index = selected[:, :, None]
selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
self.load -= selected_demand
self.load[self.at_the_depot] = 1  # refill loaded at the depot

功能:

  • 获取当前选择节点的需求量,并更新self.load(负载):
    • demand_list:扩展depot_node_demand以便按批次、POMO索引进行操作。
    • gathering_index:用于从demand_list中获取对应于selected的需求量。
    • self.load -= selected_demand:根据选择的节点减少当前负载。
    • 如果当前位置在配送中心(at_the_depot),则将负载重置为1,表示负载已满。

6. 更新访问标记

self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0  # depot is considered unvisited, unless you are AT the depot

功能:

  • 更新visited_ninf_flag以防止重复访问已选择的节点:
    • 对于已选择的节点,标记为负无穷(表示已访问)。
    • 对于配送中心(at_the_depot),设置为未访问(0),除非当前就在配送中心。

7. 更新负无穷掩码

self.ninf_mask = self.visited_ninf_flag.clone()
round_error_epsilon = 0.00001
demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
_2 = torch.full((demand_too_large.shape[0], demand_too_large.shape[1], 1), False)
demand_too_large = torch.cat((demand_too_large, _2), dim=2)
self.ninf_mask[demand_too_large] = float('-inf')

功能:

  • 更新ninf_mask,用于掩盖负载不足以承载需求量的节点:
    • demand_too_large表示负载不足以承载节点的需求,如果某个节点的需求量大于当前负载,则将其标记为True
    • 使用self.ninf_mask将这些节点掩盖为负无穷(-inf)。

8. 更新步骤状态

self.step_state.selected_count = self.time_step
self.step_state.load = self.load
self.step_state.current_node = self.current_node
self.step_state.ninf_mask = self.ninf_mask

功能:

  • 更新self.step_state,将当前步骤的状态(如selected_countload、current_nodeninf_mask)同步到步骤状态中。

9. 时间步大于等于 4 的复杂操作

9.1. 动作模式分类(action0, action1, action2, action3)

action0_bool_index = ((self.mode == 0) & (selected != self.problem_size + 1))
action1_bool_index = ((self.mode == 0) & (selected == self.problem_size + 1))  # regret
action2_bool_index = self.mode == 1
action3_bool_index = self.mode == 2

功能:

  • action0_bool_index:表示选择了正常节点且当前模式为0的情况。selected != self.problem_size + 1是用来排除选择了特殊的节点(即problem_size + 1,通常是用于后悔机制或终止标记)。
  • action1_bool_index:表示当前模式为0并且选择了后悔节点(selected == self.problem_size + 1)。在这种情况下,智能体执行后悔操作(Regret)。
  • action2_bool_index:表示当前模式为1,意味着智能体正在执行某种特定的操作。
  • action3_bool_index:表示当前模式为2,意味着智能体正在执行另一种特定操作。

9.2. 动作模式的索引提取

action1_index = torch.nonzero(action1_bool_index)
action2_index = torch.nonzero(action2_bool_index)
action4_index = torch.nonzero((action3_bool_index & (self.current_node != 0)))

功能:

  • action1_index:获取所有满足action1_bool_index的索引,即执行后悔操作的智能体。
  • action2_index:获取所有满足action2_bool_index的索引,执行操作模式1的智能体。
  • action4_index:获取所有满足action3_bool_index并且当前节点不为0(即不在配送中心)的智能体。

9.3. 更新选择计数

self.selected_count = self.selected_count + 1
self.selected_count[action1_bool_index] = self.selected_count[action1_bool_index] - 2

功能:

  • self.selected_count增加1,表示当前有一个新的节点被选择。
  • 对于执行后悔操作的智能体(action1_bool_index),减少选择计数2,因为它们撤回了之前的选择。

9.4. 节点更新

self.last_is_depot = (self.last_current_node == 0)
_ = self.last_current_node[action1_index[:, 0], action1_index[:, 1]].clone()
temp_last_current_node_action2 = self.last_current_node[action2_index[:, 0], action2_index[:, 1]].clone()
self.last_current_node = self.current_node.clone()
self.current_node = selected.clone()
self.current_node[action1_index[:, 0], action1_index[:, 1]] = _.clone()

功能:

  • self.last_is_depot:检查之前的节点是否为配送中心(0)。
  • 对于执行后悔操作的智能体(action1_index),保存它们上一步的current_node,然后将self.current_node更新为selected(当前选择的节点),并恢复后悔节点的选择。
  • self.last_current_node:保存当前的current_node,以便后悔操作或其他模式下的恢复。
  • temp_last_current_node_action2:用于存储执行操作模式2的智能体的节点,以便后续更新。

9.5. 更新已选择节点列表

self.selected_node_list = torch.cat((self.selected_node_list, selected[:, :, None]), dim=2)

功能:

  • 将当前选择的节点(selected)添加到已选择节点列表self.selected_node_list中。

9.6. 更新负载(load)

self.at_the_depot = (selected == 0)
demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
_3 = torch.full((demand_list.shape[0], demand_list.shape[1], 1), 0)
demand_list = torch.cat((demand_list, _3), dim=2)
gathering_index = selected[:, :, None]
selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
_1 = self.last_load[action1_index[:, 0], action1_index[:, 1]].clone()
self.last_load = self.load.clone()
self.load -= selected_demand
self.load[action1_index[:, 0], action1_index[:, 1]] = _1.clone()
self.load[self.at_the_depot] = 1  # refill loaded at the depot

功能:

  • self.at_the_depot:判断是否选择了配送中心(selected == 0)。
  • demand_list:扩展depot_node_demand以便按批次、POMO索引进行操作。
  • 对于后悔操作的智能体(action1_index),负载被恢复为之前的值。
  • 通过gather方法获取当前选择节点的需求量,并更新self.load
  • 如果当前节点是配送中心,负载重置为1,表示重新加载。

9. 7. 更新访问标记(visited_ninf_flag)

self.visited_ninf_flag[:, :, self.problem_size + 1][self.last_is_depot] = 0
self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
self.visited_ninf_flag[action2_index[:, 0], action2_index[:, 1], temp_last_current_node_action2] = float(0)
self.visited_ninf_flag[action4_index[:, 0], action4_index[:, 1], self.problem_size + 1] = float(0)
self.visited_ninf_flag[:, :, self.problem_size + 1][self.at_the_depot] = float('-inf')
self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0

功能:

  • self.visited_ninf_flag:更新节点的访问状态:
    • 对于执行后悔操作的智能体,恢复其节点访问状态。
    • 对于操作模式2的智能体,将其上一步的节点标记为未访问。
    • 对于操作模式3的智能体,确保结束时访问标记正确。
    • 如果节点为配送中心,设置为已访问。

9.8. 更新负无穷掩码(ninf_mask)

self.ninf_mask = self.visited_ninf_flag.clone()
round_error_epsilon = 0.00001
demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
self.ninf_mask[demand_too_large] = float('-inf')

功能:

  • 更新ninf_mask,根据当前负载和节点需求量来掩盖不满足条件的节点(负载不足以承载需求的节点被标记为-inf)。

9.9. 更新完成状态(finished)

newly_finished = (self.visited_ninf_flag == float('-inf'))[:, :, :self.problem_size + 1].all(dim=2)
self.finished = self.finished + newly_finished

功能:

  • 检查是否所有节点都已被访问,对于完成的智能体,更新self.finished标志。

9.10. 更新模式(mode)

self.mode[action1_bool_index] = 1
self.mode[action2_bool_index] = 2
self.mode[action3_bool_index] = 0
self.mode[self.finished] = 4

功能:

  • 根据不同的动作模式更新self.mode,包括后悔操作、执行操作1、执行操作2和完成操作。

9.11. 更新完成后的掩码调整

self.ninf_mask[:, :, 0][self.finished] = 0
self.ninf_mask[:, :, self.problem_size + 1][self.finished] = float('-inf')

功能:

  • 如果智能体完成任务,调整掩码,确保完成的节点不会被选择。

10. 完成状态与奖励

done = self.finished.all()
if done:
    reward = -self._get_travel_distance()  # note the minus sign!
else:
    reward = None

功能:

  • 检查所有智能体是否已完成任务(即访问所有节点)。
  • 如果所有任务完成,则计算奖励(这里通过负的旅行距离来表示,越短的路径越好)。
  • 如果未完成任务,则奖励为None

11. 返回值

return self.step_state, reward, done

功能:

  • 返回当前步骤状态、奖励和完成标志。

附录

函数代码

    def step(self, selected):
        # selected.shape: (batch, pomo)

        #时间步数控制
        if self.time_step<4:

            # 控制时间步的递增
            self.time_step=self.time_step+1
            self.selectex_count = self.selected_count+1

            #判断是否在配送中心
            self.at_the_depot = (selected == 0)

            #特定时间步的操作
            if self.time_step==3:
                self.last_current_node = self.current_node.clone()
                self.last_load = self.load.clone()
            if self.time_step == 4:
                self.last_current_node = self.current_node.clone()
                self.last_load = self.load.clone()
                self.visited_ninf_flag[:, :, self.problem_size+1][(~self.at_the_depot)&(self.last_current_node!=0)] = 0
            
            #更新当前节点和已选择节点列表
            self.current_node = selected
            self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)

            #更新需求和负载
            demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
            gathering_index = selected[:, :, None]
            selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
            self.load -= selected_demand
            self.load[self.at_the_depot] = 1  # refill loaded at the depot

            #更新访问标记(防止重复选择已访问的节点)
            self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
            self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0  # depot is considered unvisited, unless you are AT the depot

            #更新负无穷掩码(屏蔽需求量超过当前负载的节点)
            self.ninf_mask = self.visited_ninf_flag.clone()
            round_error_epsilon = 0.00001
            demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
            _2=torch.full((demand_too_large.shape[0],demand_too_large.shape[1],1),False)
            demand_too_large = torch.cat((demand_too_large, _2), dim=2)
            self.ninf_mask[demand_too_large] = float('-inf')

            #更新步骤状态,将更新后的状态同步到 self.step_state
            self.step_state.selected_count = self.time_step
            self.step_state.load = self.load
            self.step_state.current_node = self.current_node
            self.step_state.ninf_mask = self.ninf_mask


        #时间步大于等于 4 的复杂操作
        else:
            #动作模式分类
            action0_bool_index = ((self.mode == 0) & (selected != self.problem_size + 1))
            action1_bool_index = ((self.mode == 0) & (selected == self.problem_size + 1))  # regret
            action2_bool_index = self.mode == 1
            action3_bool_index = self.mode == 2
            
            action1_index = torch.nonzero(action1_bool_index)
            action2_index = torch.nonzero(action2_bool_index)

            action4_index = torch.nonzero((action3_bool_index & (self.current_node != 0)))

            #更新选择计数
            self.selected_count = self.selected_count+1
            #后悔模式
            self.selected_count[action1_bool_index] = self.selected_count[action1_bool_index] - 2

            #节点更新
            self.last_is_depot = (self.last_current_node == 0)

            _ = self.last_current_node[action1_index[:, 0], action1_index[:, 1]].clone()
            temp_last_current_node_action2 = self.last_current_node[action2_index[:, 0], action2_index[:, 1]].clone()
            self.last_current_node = self.current_node.clone()
            self.current_node = selected.clone()
            self.current_node[action1_index[:, 0], action1_index[:, 1]] = _.clone()

            #更新已选择节点列表
            self.selected_node_list = torch.cat((self.selected_node_list, selected[:, :, None]), dim=2)

            #更新负载
            self.at_the_depot = (selected == 0)
            demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
            # shape: (batch, pomo, problem+1)
            _3 = torch.full((demand_list.shape[0], demand_list.shape[1], 1), 0)
            #扩展需求列表 demand_list 
            demand_list = torch.cat((demand_list, _3), dim=2)
            gathering_index = selected[:, :, None]
            # shape: (batch, pomo, 1)
            selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
            _1 = self.last_load[action1_index[:, 0], action1_index[:, 1]].clone()
            self.last_load= self.load.clone()
            # shape: (batch, pomo)
            self.load -= selected_demand
            self.load[action1_index[:, 0], action1_index[:, 1]] = _1.clone()
            self.load[self.at_the_depot] = 1  # refill loaded at the depot

            #更新访问标记
            self.visited_ninf_flag[:, :, self.problem_size+1][self.last_is_depot] = 0
            self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
            self.visited_ninf_flag[action2_index[:, 0], action2_index[:, 1], temp_last_current_node_action2] = float(0)
            self.visited_ninf_flag[action4_index[:, 0], action4_index[:, 1], self.problem_size + 1] = float(0)
            self.visited_ninf_flag[:, :, self.problem_size+1][self.at_the_depot] = float('-inf')
            self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0


            # 更新负无穷掩码
            self.ninf_mask = self.visited_ninf_flag.clone()
            round_error_epsilon = 0.00001
            demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
            # shape: (batch, pomo, problem+1)
            self.ninf_mask[demand_too_large] = float('-inf')

            # 更新完成状态
            # 检查哪些智能体已经完成所有节点的访问。
            # 更新完成标记 self.finished。
            newly_finished = (self.visited_ninf_flag == float('-inf'))[:,:,:self.problem_size+1].all(dim=2)
            # shape: (batch, pomo)
            self.finished = self.finished + newly_finished
            # shape: (batch, pomo)

            #更新模式
            self.mode[action1_bool_index] = 1
            self.mode[action2_bool_index] = 2
            self.mode[action3_bool_index] = 0
            self.mode[self.finished] = 4

            # 更新完成后的掩码调整
            self.ninf_mask[:, :, 0][self.finished] = 0
            self.ninf_mask[:, :, self.problem_size+1][self.finished] = float('-inf')

            # 更新步骤状态
            self.step_state.selected_count = self.time_step
            self.step_state.load = self.load
            self.step_state.current_node = self.current_node
            self.step_state.ninf_mask = self.ninf_mask



        # returning values
        done = self.finished.all()
        if done:
            reward = -self._get_travel_distance()  # note the minus sign!
        else:
            reward = None

        return self.step_state, reward, done

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值