此脚本定义了 GR1DexDataset3D 的数据集类,用于处理3D数据,它继承自 BaseDataset 类,并整合了多个组件来实现数据加载、处理和规范化
此脚本主要用于加载、处理和规范化3D点云数据:
- 通过重放缓冲区(ReplayBuffer)管理数据,并支持训练集和验证集的划分
- 样本采样与规范化:利用 SequenceSampler 进行数据序列的采样,并使用 LinearNormalizer 对动作数据进行规范化
- 最终的数据被转换为 PyTorch 张量,确保可以直接用于深度学习模型
详细分析如下:
1 库引用
from typing import Dict
import torch
import numpy as np
import copy
from diffusion_policy_3d.common.pytorch_util import dict_apply
from diffusion_policy_3d.common.replay_buffer import ReplayBuffer
from diffusion_policy_3d.common.sampler import (SequenceSampler, get_val_mask, downsample_mask)
from diffusion_policy_3d.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer, StringNormalizer
from diffusion_policy_3d.dataset.base_dataset import BaseDataset
import diffusion_policy_3d.model.vision_3d.point_process as point_process
from termcolor import cprint
2 类定义 GR1DexDataset3D 与 __init__ 构造函数
该类继承自 BaseDataset,在其基础上增加了特定的数据预处理、采样和规范化功能。该数据集类的核心目的是管理和处理 3D 点云数据
class GR1DexDataset3D(BaseDataset):
def __init__(self,
zarr_path,
horizon=1, # 序列长度
pad_before=0, # 序列前的填充
pad_after=0, # 序列后的填充
seed=42, # 随机种子
val_ratio=0.0, # 验证集比例
max_train_episodes=None, # 最大训练集样本数
task_name=None, # 任务名称
num_points=4096, # 每个点云的采样点数
):
super().__init__() # 调用父类的构造函数
cprint(f'Loading GR1DexDataset from {zarr_path}', 'green') # 输出加载数据的提示信息
self.task_name = task_name
self.num_points = num_points # 设置点云采样点数
# 定义缓冲区的键,包含状态、动作和点云数据
buffer_keys = ['state', 'action']
buffer_keys.append('point_cloud') # 添加点云数据
__init__ 构造函数负责初始化数据集对象,并完成以下几项任务:
(1)加载数据
# 从指定路径加载数据到重放缓冲区
self.replay_buffer = ReplayBuffer.copy_from_path(zarr_path, keys=buffer_keys)
通过调用 ReplayBuffer.copy_from_path() 方法,从给定的路径 zarr_path 加载数据
ReplayBuffer 是一个缓存机制,用于存储和管理过程中的经验数据
buffer_keys 定义了要加载的数据字段(例如:状态 state、动作 action 和点云数据 point_cloud)
(2)训练与验证集掩码(mask)生成
# 生成验证集掩码(val_mask)用于划分验证集
val_mask = get_val_mask(
n_episodes=self.replay_buffer.n_episodes,
val_ratio=val_ratio, # 按比例选择验证集
seed=seed) # 设置随机种子
train_mask = ~val_mask # 训练集掩码是验证集掩码的反集
get_val_mask 用于生成验证集掩码,根据给定的验证比例 (val_ratio) 随机选择一部分样本作为验证集
train_mask 通过反转验证掩码来创建训练集掩码,确保训练和验证数据互不重叠
(3)样本采样器
# 如果指定了最大训练集样本数,则对训练集掩码进行下采样
train_mask = downsample_mask(
mask=train_mask,
max_n=max_train_episodes, # 最大训练样本数
seed=seed)
# 创建一个序列采样器,按给定的序列长度、填充等设置采样参数
self.sampler = SequenceSampler(
replay_buffer=self.replay_buffer,
sequence_length=horizon, # 序列长度
pad_before=pad_before, # 序列前填充
pad_after=pad_after, # 序列后填充
episode_mask=train_mask) # 训练集掩码
SequenceSampler 是一个用于从重放缓冲区中提取数据序列的采样器。它会按照给定的 horizon(序列长度)来采样数据,并在序列的前后添加一定的填充(pad_before 和 pad_after),以保证输入序列的一致性
训练集掩码 (train_mask) 确保只从训练数据中进行采样
(4)其他参数
# 保存训练集掩码、序列长度和填充参数
self.train_mask = train_mask
self.horizon = horizon
self.pad_before = pad_before
self.pad_after = pad_after
self.horizon、self.pad_before 和 self.pad_after 用于控制数据序列的长度和填充
3 get_validation_dataset 方法
def get_validation_dataset(self):
# 返回一个新的验证数据集对象,采用验证集掩码进行采样
val_set = copy.copy(self) # 创建数据集副本
val_set.sampler = SequenceSampler(
replay_buffer=self.replay_buffer,
sequence_length=self.horizon,
pad_before=self.pad_before,
pad_after=self.pad_after,
episode_mask=~self.train_mask) # 使用验证集掩码
val_set.train_mask = ~self.train_mask # 设置训练集掩码为验证集掩码
return val_set
该方法生成一个新的验证集对象,主要通过复制当前数据集对象,并修改其采样器,使其只使用验证集掩码 (~self.train_mask) 来采样数据
返回的新对象就是训练集和验证集之间的切换
4 get_normalizer 方法
该方法用于创建数据的规范化器(Normalizer)
(1)动作规范化器
def get_normalizer(self, mode='limits', **kwargs):
# 获取数据的规范化器
data = {'action': self.replay_buffer['action']} # 获取动作数据
normalizer = LinearNormalizer() # 创建一个线性规范化器
normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs) # 拟合规范化器
LinearNormalizer 用于规范化动作数据。它将使用历史数据拟合规范化器,从而将动作数据转换到一个适当的尺度
(2)点云和代理位置规范化器
# 对点云和代理位置不进行任何规范化,使用恒等变换
normalizer['point_cloud'] = SingleFieldLinearNormalizer.create_identity()
normalizer['agent_pos'] = SingleFieldLinearNormalizer.create_identity()
return normalizer
对于 point_cloud 和 agent_pos,它们的规范化器被设置为“恒等”变换,即数据不进行任何变化
5 __len__ 方法
def __len__(self) -> int:
# 返回数据集的长度,即采样器中样本的数量
return len(self.sampler)
该方法返回数据集的大小,即从采样器中获取样本的数量
6 _sample_to_data 方法
def _sample_to_data(self, sample):
# 将采样到的数据转换为所需的数据格式
agent_pos = sample['state'][:,].astype(np.float32) # 提取代理位置
point_cloud = sample['point_cloud'][:,].astype(np.float32) # 提取点云数据
# 对点云数据进行均匀采样,确保点云的点数为指定数量
point_cloud = point_process.uniform_sampling_numpy(point_cloud, self.num_points)
# 将数据组织成一个字典,包含观察(代理位置和点云)和动作
data = {
'obs': {
'agent_pos': agent_pos,
'point_cloud': point_cloud,
},
'action': sample['action'].astype(np.float32)} # 转换动作为float32
return data
该方法将采样到的原始数据(如 state、point_cloud 和 action)处理成所需的数据格式:
- agent_pos 由状态 (state) 提取并转换为 float32 类型
- point_cloud 同样被转换为 float32,并且通过 uniform_sampling_numpy 函数进行均匀采样,使其大小固定为 self.num_points
- 返回的数据字典包含两个主要部分:观察(obs)和动作(action)
6 __getitem__ 方法
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
# 获取指定索引的样本数据,并转换为PyTorch张量
sample = self.sampler.sample_sequence(idx) # 从采样器中获取数据序列
data = self._sample_to_data(sample) # 将采样数据转换为所需格式
# 定义一个转换函数,将numpy数组转换为PyTorch张量
to_torch_function = lambda x: torch.from_numpy(x) if x.__class__.__name__ == 'ndarray' else x
# 将数据字典中的所有数据项转换为PyTorch张量
torch_data = dict_apply(data, to_torch_function)
return torch_data # 返回转换后的数据
该方法通过给定的索引 (idx) 从采样器中获取一个样本序列,并将其转换为 PyTorch 张量
dict_apply 用于将数据字典中的每个字段转换为 PyTorch 张量,保证所有数据可以直接用于深度学习模型训练
7 数据处理的关键部分总结
点云数据的均匀采样:point_cloud 数据通过 uniform_sampling_numpy 进行均匀采样,确保点云数据的尺寸一致,便于后续处理和训练
序列采样与填充:使用 SequenceSampler 对数据进行序列化采样,同时根据需要进行填充。这对于序列数据(如状态、动作序列)处理非常重要
数据规范化:数据在进入模型前会进行规范化处理,确保输入数据的尺度一致,避免不同特征尺度差异导致模型训练不稳定
(1)__init__ 构造函数
- 加载数据并初始化参数
- 通过 ReplayBuffer 加载数据,将数据划分为训练集和验证集
- 配置一个序列采样器来从数据中提取带有填充的序列
(2)get_validation_dataset
- 返回一个新的验证数据集对象,并设置其采样器使用验证集掩码
(3)get_normalizer
- 创建和返回一个数据规范化器,用于规范化动作数据、点云数据和代理位置数据。
(4)__len__
- 返回数据集的大小,即样本的数量
(5)_sample_to_data
- 将采样到的数据(状态、点云和动作)转化为标准化的格式,准备进行后续处理
(6)__getitem__
- 通过索引获取一个数据样本,并将其转换为 PyTorch 张量,方便模型训练