iDP3复现代码模型训练全流程(五)——gr1_dex_dataset_3d.py

此脚本定义了 GR1DexDataset3D 的数据集类,用于处理3D数据,它继承自 BaseDataset 类,并整合了多个组件来实现数据加载、处理和规范化

此脚本主要用于加载、处理和规范化3D点云数据:

  • 通过重放缓冲区(ReplayBuffer)管理数据,并支持训练集和验证集的划分
  • 样本采样与规范化:利用 SequenceSampler 进行数据序列的采样,并使用 LinearNormalizer 对动作数据进行规范化
  • 最终的数据被转换为 PyTorch 张量,确保可以直接用于深度学习模型

详细分析如下:

1 库引用

from typing import Dict
import torch
import numpy as np
import copy
from diffusion_policy_3d.common.pytorch_util import dict_apply
from diffusion_policy_3d.common.replay_buffer import ReplayBuffer
from diffusion_policy_3d.common.sampler import (SequenceSampler, get_val_mask, downsample_mask)
from diffusion_policy_3d.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer, StringNormalizer
from diffusion_policy_3d.dataset.base_dataset import BaseDataset
import diffusion_policy_3d.model.vision_3d.point_process as point_process
from termcolor import cprint

2 类定义 GR1DexDataset3D 与 __init__  构造函数

该类继承自 BaseDataset,在其基础上增加了特定的数据预处理、采样和规范化功能。该数据集类的核心目的是管理和处理 3D 点云数据

class GR1DexDataset3D(BaseDataset):
    def __init__(self,
                 zarr_path, 
                 horizon=1,  # 序列长度
                 pad_before=0,  # 序列前的填充
                 pad_after=0,  # 序列后的填充
                 seed=42,  # 随机种子
                 val_ratio=0.0,  # 验证集比例
                 max_train_episodes=None,  # 最大训练集样本数
                 task_name=None,  # 任务名称
                 num_points=4096,  # 每个点云的采样点数
                 ):
        super().__init__()  # 调用父类的构造函数
        cprint(f'Loading GR1DexDataset from {zarr_path}', 'green')  # 输出加载数据的提示信息
        self.task_name = task_name
        self.num_points = num_points  # 设置点云采样点数
        
        # 定义缓冲区的键,包含状态、动作和点云数据
        buffer_keys = ['state', 'action']
        buffer_keys.append('point_cloud')  # 添加点云数据

__init__ 构造函数负责初始化数据集对象,并完成以下几项任务:

(1)加载数据

        # 从指定路径加载数据到重放缓冲区
        self.replay_buffer = ReplayBuffer.copy_from_path(zarr_path, keys=buffer_keys)

通过调用 ReplayBuffer.copy_from_path() 方法,从给定的路径 zarr_path 加载数据

ReplayBuffer 是一个缓存机制,用于存储和管理过程中的经验数据

buffer_keys 定义了要加载的数据字段(例如:状态 state、动作 action 和点云数据 point_cloud)

(2)训练与验证集掩码(mask)生成

        # 生成验证集掩码(val_mask)用于划分验证集
        val_mask = get_val_mask(
            n_episodes=self.replay_buffer.n_episodes, 
            val_ratio=val_ratio,  # 按比例选择验证集
            seed=seed)  # 设置随机种子
        train_mask = ~val_mask  # 训练集掩码是验证集掩码的反集

get_val_mask 用于生成验证集掩码,根据给定的验证比例 (val_ratio) 随机选择一部分样本作为验证集

train_mask 通过反转验证掩码来创建训练集掩码,确保训练和验证数据互不重叠

(3)样本采样器

        # 如果指定了最大训练集样本数,则对训练集掩码进行下采样
        train_mask = downsample_mask(
            mask=train_mask, 
            max_n=max_train_episodes,  # 最大训练样本数
            seed=seed)
        
        # 创建一个序列采样器,按给定的序列长度、填充等设置采样参数
        self.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer, 
            sequence_length=horizon,  # 序列长度
            pad_before=pad_before,  # 序列前填充
            pad_after=pad_after,  # 序列后填充
            episode_mask=train_mask)  # 训练集掩码

SequenceSampler 是一个用于从重放缓冲区中提取数据序列的采样器。它会按照给定的 horizon(序列长度)来采样数据,并在序列的前后添加一定的填充(pad_before 和 pad_after),以保证输入序列的一致性

训练集掩码 (train_mask) 确保只从训练数据中进行采样

(4)其他参数

        # 保存训练集掩码、序列长度和填充参数
        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after

self.horizon、self.pad_before 和 self.pad_after 用于控制数据序列的长度和填充

3 get_validation_dataset 方法

    def get_validation_dataset(self):
        # 返回一个新的验证数据集对象,采用验证集掩码进行采样
        val_set = copy.copy(self)  # 创建数据集副本
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer, 
            sequence_length=self.horizon,
            pad_before=self.pad_before, 
            pad_after=self.pad_after,
            episode_mask=~self.train_mask)  # 使用验证集掩码
        val_set.train_mask = ~self.train_mask  # 设置训练集掩码为验证集掩码
        return val_set

该方法生成一个新的验证集对象,主要通过复制当前数据集对象,并修改其采样器,使其只使用验证集掩码 (~self.train_mask) 来采样数据

返回的新对象就是训练集和验证集之间的切换

4 get_normalizer 方法

该方法用于创建数据的规范化器(Normalizer)

(1)动作规范化器

    def get_normalizer(self, mode='limits', **kwargs):
        # 获取数据的规范化器
        data = {'action': self.replay_buffer['action']}  # 获取动作数据
        normalizer = LinearNormalizer()  # 创建一个线性规范化器
        normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)  # 拟合规范化器

LinearNormalizer 用于规范化动作数据。它将使用历史数据拟合规范化器,从而将动作数据转换到一个适当的尺度

(2)点云和代理位置规范化器

        # 对点云和代理位置不进行任何规范化,使用恒等变换
        normalizer['point_cloud'] = SingleFieldLinearNormalizer.create_identity()
        normalizer['agent_pos'] = SingleFieldLinearNormalizer.create_identity()
        
        return normalizer

对于 point_cloud 和 agent_pos,它们的规范化器被设置为“恒等”变换,即数据不进行任何变化

5 __len__ 方法

    def __len__(self) -> int:
        # 返回数据集的长度,即采样器中样本的数量
        return len(self.sampler)

该方法返回数据集的大小,即从采样器中获取样本的数量

6 _sample_to_data 方法

    def _sample_to_data(self, sample):
        # 将采样到的数据转换为所需的数据格式
        agent_pos = sample['state'][:,].astype(np.float32)  # 提取代理位置
        point_cloud = sample['point_cloud'][:,].astype(np.float32)  # 提取点云数据
        # 对点云数据进行均匀采样,确保点云的点数为指定数量
        point_cloud = point_process.uniform_sampling_numpy(point_cloud, self.num_points)
        # 将数据组织成一个字典,包含观察(代理位置和点云)和动作
        data = {
            'obs': {
                'agent_pos': agent_pos,
                'point_cloud': point_cloud,
                },
            'action': sample['action'].astype(np.float32)}  # 转换动作为float32
           
        return data

该方法将采样到的原始数据(如 state、point_cloud 和 action)处理成所需的数据格式:

  • agent_pos 由状态 (state) 提取并转换为 float32 类型
  • point_cloud 同样被转换为 float32,并且通过 uniform_sampling_numpy 函数进行均匀采样,使其大小固定为 self.num_points
  • 返回的数据字典包含两个主要部分:观察(obs)和动作(action)

6 __getitem__ 方法

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # 获取指定索引的样本数据,并转换为PyTorch张量
        sample = self.sampler.sample_sequence(idx)  # 从采样器中获取数据序列
        data = self._sample_to_data(sample)  # 将采样数据转换为所需格式
        # 定义一个转换函数,将numpy数组转换为PyTorch张量
        to_torch_function = lambda x: torch.from_numpy(x) if x.__class__.__name__ == 'ndarray' else x
        # 将数据字典中的所有数据项转换为PyTorch张量
        torch_data = dict_apply(data, to_torch_function)
        return torch_data  # 返回转换后的数据

该方法通过给定的索引 (idx) 从采样器中获取一个样本序列,并将其转换为 PyTorch 张量

dict_apply 用于将数据字典中的每个字段转换为 PyTorch 张量,保证所有数据可以直接用于深度学习模型训练

7 数据处理的关键部分总结

点云数据的均匀采样:point_cloud 数据通过 uniform_sampling_numpy 进行均匀采样,确保点云数据的尺寸一致,便于后续处理和训练

序列采样与填充:使用 SequenceSampler 对数据进行序列化采样,同时根据需要进行填充。这对于序列数据(如状态、动作序列)处理非常重要

数据规范化:数据在进入模型前会进行规范化处理,确保输入数据的尺度一致,避免不同特征尺度差异导致模型训练不稳定

(1)__init__ 构造函数

  • 加载数据并初始化参数
  • 通过 ReplayBuffer 加载数据,将数据划分为训练集和验证集
  • 配置一个序列采样器来从数据中提取带有填充的序列

(2)get_validation_dataset

  • 返回一个新的验证数据集对象,并设置其采样器使用验证集掩码

(3)get_normalizer

  • 创建和返回一个数据规范化器,用于规范化动作数据、点云数据和代理位置数据。

(4)__len__

  • 返回数据集的大小,即样本的数量

(5)_sample_to_data

  • 将采样到的数据(状态、点云和动作)转化为标准化的格式,准备进行后续处理

(6)__getitem__

  • 通过索引获取一个数据样本,并将其转换为 PyTorch 张量,方便模型训练
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值