import sys
sys.path.append('/data/coding') # 添加包所在的根目录到 Python 路径中,防止找不到 tensor2robot 和 robotics_transformer
import tensorflow as tf
import numpy as np
import tf_agents
from tf_agents.networks import sequential
from keras.layers import Dense
from tf_agents.agents.dqn import dqn_agent
from tf_agents.utils import common
from typing import Type
from tf_agents.networks import network
from tensor2robot.utils import tensorspec_utils
from tf_agents.specs import tensor_spec
from robotics_transformer import sequence_agent
from tf_agents.trajectories import time_step as ts
from tensorflow_datasets.core.data_sources import array_record
import tensorflow_datasets as tfds
from robotics_transformer import transformer_network
from robotics_transformer import transformer_network_test_set_up
import collections
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
import reverb
"""
加载数据集
"""
# --- Paths -------------------------------------------------------------------
datasets_path = "/data/coding/moxing/colour4/19"  # pre-processed sub-dataset (input)
#datasets_can_use_path = "/data/coding/moxing_use/colour4/19"  # post-processed sub-dataset
datasets_val_path = "/data/coding/moxing_val/colour4/18"  # validation-set save path

# --- Validation split --------------------------------------------------------
# element_index values of the episodes reserved for validation; samples whose
# element_index matches one of these entries are dropped later in the script.
# NOTE(review): the leading -1 is kept from the original list; it looks like a
# placeholder and never equals a bytes element_index — TODO confirm intent.
_VAL_ABDOMEN_IDS = (
    109, 38, 74, 27, 14, 122, 21, 82, 93, 54, 80, 79,
    105, 37, 7, 32, 3, 46, 95, 30, 55, 66, 10,
)
_VAL_NECK_IDS = (14, 16, 80, 29, 82, 47, 85, 43, 31, 66, 57, 55, 9, 27, 46)
val_data_list = (
    [-1]
    + [f"abdomen_{n}".encode() for n in _VAL_ABDOMEN_IDS]
    + [f"neck_{n}".encode() for n in _VAL_NECK_IDS]
)
# ---------------------------------------------------------------------------
# Phase 1: load the saved dataset step by step, drop validation samples and
# collect every field into its own Python list.
# ---------------------------------------------------------------------------
# Field names in the order used by every per-field list of this script
# (data_list, dataset_list, secquence_list, batch_list); index 14 is "num_steps".
FIELD_NAMES = (
    "image",
    "natural_language_embedding",
    "natural_language_instruction",
    "base_displacement_vector",
    "base_displacement_vertical_rotation",
    "gripper_closedness_action",
    "rotation_delta",
    "terminate_episode",
    "world_vector",
    "discounted_return",
    "return",
    "reward",
    "step_id",
    "element_index",
    "num_steps",
    "is_first",
    "is_last",
    "step_type",
    "next_step_type",
)
load_dataset = tf.data.Dataset.load(datasets_path)  # load the saved dataset
load_dataset = load_dataset.batch(1, drop_remainder=True)  # one step per batch
# Build the lookup set once: O(1) membership per sample instead of scanning
# val_data_list for every step.
val_data_set = set(val_data_list)
# One list per field; all lists grow in lockstep (one entry per kept step).
data_list = [[] for _ in FIELD_NAMES]
index_1 = 0  # number of steps kept in phase 1
# Iterate the dataset directly instead of counting it first with reduce(),
# which cost a complete extra pass over the data.
for sample_data in load_dataset:
    # NOTE: tf.squeeze without an axis removes *every* size-1 dimension, not
    # only the batch dimension added by batch(1).
    reshaped_element_index = tf.squeeze(sample_data['element_index'])
    element_index_value = reshaped_element_index.numpy()
    # Skip validation-set samples before doing any further work.
    if element_index_value in val_data_set:
        print(f"{element_index_value} 与列表中的某个元素相等,需要剔除")
        continue
    for field_values, field_name in zip(data_list, FIELD_NAMES):
        field_values.append(tf.squeeze(sample_data[field_name]))
    index_1 += 1
    print(index_1)
print("第一阶段结束!")
print("------------------------------------------------------------------------------")
# ---------------------------------------------------------------------------
# Phase 2: cut every episode into fixed-length windows of consecutive steps.
# Episodes are assumed to be stored contiguously in data_list, with the
# episode length repeated in the "num_steps" field (data_list[14]) of each step.
# ---------------------------------------------------------------------------
WINDOW_SIZE = 6  # number of consecutive steps per training sequence


def _window_starts(length, size=WINDOW_SIZE):
    """Return the start offsets of the ``size``-step windows covering ``length`` steps.

    Windows are laid out back to back from offset 0. If ``length`` is not a
    multiple of ``size``, one extra window ending exactly at the last step is
    added (it overlaps the previous window). Returns ``[]`` when
    ``length < size``, because no full window fits.
    """
    if length < size:
        return []
    starts = list(range(0, length - size + 1, size))
    if starts[-1] != length - size:
        starts.append(length - size)
    return starts


# One list of windows per field; every field receives the same windows.
dataset_list = [[] for _ in data_list]
index_2 = 0  # index in data_list of the first step of the current episode
total_steps = len(data_list[14])
while index_2 < total_steps:
    # num_steps is a scalar tensor; convert it so all index arithmetic stays in
    # plain Python ints.
    episode_num_steps = int(data_list[14][index_2])
    if episode_num_steps <= 0:
        # Would otherwise never advance (or crash with an opaque IndexError).
        raise ValueError(
            f"invalid num_steps={episode_num_steps} at step {index_2}"
        )
    end_index = min(index_2 + episode_num_steps, total_steps)
    episode_len = end_index - index_2
    starts = _window_starts(episode_len)
    if not starts:
        # A full window cannot be built; a shorter one would break tf.stack later.
        print(f"skipping episode at step {index_2}: "
              f"{episode_len} steps < window size {WINDOW_SIZE}")
    for field_steps, field_windows in zip(data_list, dataset_list):
        episode = field_steps[index_2:end_index]
        field_windows.extend(episode[s:s + WINDOW_SIZE] for s in starts)
    index_2 += episode_num_steps
print("第二阶段结束!")
print("------------------------------------------------------------------------------")
# ---------------------------------------------------------------------------
# Phase 3: turn every window (a Python list of per-step tensors) into one
# tensor by stacking its steps along a new leading axis.
# ---------------------------------------------------------------------------
# secquence_list[f][w] is the stacked tensor of field f for window w.
secquence_list = [
    [tf.stack(window, axis=0) for window in field_windows]
    for field_windows in dataset_list
]
print("第三阶段结束!")
print("------------------------------------------------------------------------------")
# ---------------------------------------------------------------------------
# Phase 4: stack all windows of each field into one tensor whose leading axis
# indexes the window, build the feature dict and save it as a tf.data dataset.
# ---------------------------------------------------------------------------
# batch_list[f] is the tf.stack of every window of field f.
batch_list = [tf.stack(field_windows, axis=0) for field_windows in secquence_list]
# Output feature names, in the same field order as batch_list.
output_keys = (
    "image",
    "natural_language_embedding",
    "natural_language_instruction",
    "base_displacement_vector",
    "base_displacement_vertical_rotation",
    "gripper_closedness_action",
    "rotation_delta",
    "terminate_episode",
    "world_vector",
    "discounted_return",
    "return",
    "reward",
    "step_id",
    "element_index",
    "num_steps",
    "is_first",
    "is_last",
    "step_type",
    "next_step_type",
)
if len(batch_list) != len(output_keys):
    # zip() below would silently drop fields on a mismatch.
    raise ValueError(
        f"expected {len(output_keys)} fields, got {len(batch_list)}"
    )
data = dict(zip(output_keys, batch_list))
# Create the dataset (one element per window) and save it.
dataset = tf.data.Dataset.from_tensor_slices(data)
# NOTE(review): phase 1 *removes* every sample whose element_index is in
# val_data_list, so this is the training split, yet it is written to
# datasets_val_path. The commented-out save targets datasets_can_use_path,
# whose definition near the top is also commented out — confirm the target.
#dataset.save(datasets_can_use_path)  # training-set save path
dataset.save(datasets_val_path)  # validation-set path (enable as needed)