理解tf.squeeze()

本文详细介绍了TensorFlow中squeeze函数的使用方法及注意事项。通过具体示例解释了如何利用此函数来移除张量中维度为1的轴,包括默认移除所有单位维度的情况以及指定特定维度进行移除的操作。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

squeeze(
    input,
    axis=None,
    name=None,
    squeeze_dims=None
)

该函数返回一个张量,这个张量是将原始input中所有维度为1的那些维都删掉的结果
axis可以用来指定要删掉的维度为1的轴,此处要注意指定的维度必须确保其是1,否则会报错(签名中的 squeeze_dims 是 axis 的已弃用别名,二者不要同时指定)

>>>y = tf.squeeze(inputs, [0, 1], name='squeeze')
>>>ValueError: Can not squeeze dim[0], expected a dimension of 1, got 32 for 'squeeze' (op: 'Squeeze') with input shapes: [32,1,1,3].

例子:

#  't' 是一个维度是[1, 2, 1, 3, 1, 1]的张量
tf.shape(tf.squeeze(t))   # [2, 3], 默认删除所有为1的维度

# 't' 是一个维度[1, 2, 1, 3, 1, 1]的张量
tf.shape(tf.squeeze(t, [2, 4]))  # [1, 2, 3, 1],标号从零开始,只删掉了2和4维的1
import sys

# Make sibling project packages (tensor2robot, robotics_transformer) importable.
sys.path.append('/data/coding')

import tensorflow as tf
import numpy as np
import tf_agents
from tf_agents.networks import sequential
from keras.layers import Dense
from tf_agents.agents.dqn import dqn_agent
from tf_agents.utils import common
from typing import Type
from tf_agents.networks import network
from tensor2robot.utils import tensorspec_utils
from tf_agents.specs import tensor_spec
from robotics_transformer import sequence_agent
from tf_agents.trajectories import time_step as ts
from tensorflow_datasets.core.data_sources import array_record
import tensorflow_datasets as tfds
from robotics_transformer import transformer_network
from robotics_transformer import transformer_network_test_set_up
import collections
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
import reverb

# --------------------------------------------------------------------------
# Dataset paths.
# --------------------------------------------------------------------------
datasets_path = "/data/coding/moxing/colour4/19"                # raw sub-dataset
# datasets_can_use_path = "/data/coding/moxing_use/colour4/19"  # training output path
datasets_val_path = "/data/coding/moxing_val/colour4/18"        # validation output path

# `element_index` values that belong to the validation split; steps carrying
# one of these identifiers are skipped in stage 1 below.
val_data_list = [
    -1,
    b'abdomen_109', b'abdomen_38', b'abdomen_74', b'abdomen_27', b'abdomen_14',
    b'abdomen_122', b'abdomen_21', b'abdomen_82', b'abdomen_93', b'abdomen_54',
    b'abdomen_80', b'abdomen_79', b'abdomen_105', b'abdomen_37', b'abdomen_7',
    b'abdomen_32', b'abdomen_3', b'abdomen_46', b'abdomen_95', b'abdomen_30',
    b'abdomen_55', b'abdomen_66', b'abdomen_10',
    b'neck_14', b'neck_16', b'neck_80', b'neck_29', b'neck_82', b'neck_47',
    b'neck_85', b'neck_43', b'neck_31', b'neck_66', b'neck_57', b'neck_55',
    b'neck_9', b'neck_27', b'neck_46',
]

# The 19 per-step fields, in the fixed index order used throughout this
# script (index 13 = element_index, index 14 = num_steps).
FIELD_NAMES = [
    "image",
    "natural_language_embedding",
    "natural_language_instruction",
    "base_displacement_vector",
    "base_displacement_vertical_rotation",
    "gripper_closedness_action",
    "rotation_delta",
    "terminate_episode",
    "world_vector",
    "discounted_return",
    "return",
    "reward",
    "step_id",
    "element_index",
    "num_steps",
    "is_first",
    "is_last",
    "step_type",
    "next_step_type",
]
ELEMENT_INDEX_IDX = FIELD_NAMES.index("element_index")
NUM_STEPS_IDX = FIELD_NAMES.index("num_steps")
WINDOW = 6  # number of consecutive steps per training sequence

# --------------------------------------------------------------------------
# Stage 1: load every step, squeeze the batch dimension, drop validation
# episodes, and accumulate each field into its own flat list.
# --------------------------------------------------------------------------
load_dataset = tf.data.Dataset.load(datasets_path)
load_dataset = load_dataset.batch(1, drop_remainder=True)  # one step per element

# data_list[i] holds every kept step of field FIELD_NAMES[i].
data_list = [[] for _ in FIELD_NAMES]

index_1 = 0  # running count of kept steps (progress indicator)
for sample_data in load_dataset:
    # tf.squeeze removes the leading size-1 batch dimension added by .batch(1).
    squeezed = [tf.squeeze(sample_data[name]) for name in FIELD_NAMES]

    # Skip any step whose element_index marks it as validation data.
    element_index = squeezed[ELEMENT_INDEX_IDX].numpy()
    if element_index in val_data_list:
        print(f"{element_index} 与列表中的某个元素相等,需要剔除")
        continue

    for field_values, value in zip(data_list, squeezed):
        field_values.append(value)
    index_1 += 1
    print(index_1)

print("第一阶段结束!")
print("------------------------------------------------------------------------------")

# --------------------------------------------------------------------------
# Stage 2: walk the flat lists episode by episode (each episode's length is
# stored in its num_steps field) and cut every episode into WINDOW-step
# chunks. The trailing partial chunk is replaced by the episode's final
# WINDOW steps so all chunks have equal length.
# --------------------------------------------------------------------------
dataset_list = [[] for _ in FIELD_NAMES]

index_2 = 0  # index of the first step of the current episode
while index_2 < len(data_list[NUM_STEPS_IDX]):
    # Scalar tensor -> plain int so it can drive Python slicing/arithmetic.
    episode_num_steps = int(data_list[NUM_STEPS_IDX][index_2])

    # Slice this episode's steps out of every field list.
    episode = [field[index_2:index_2 + episode_num_steps] for field in data_list]

    for field_idx, steps in enumerate(episode):
        # NOTE(review): assumes every episode has at least WINDOW steps —
        # shorter episodes would produce a short final chunk; confirm upstream.
        chunks = [steps[i:i + WINDOW] for i in range(0, len(steps), WINDOW)]
        chunks.pop()                             # drop the (possibly short) last chunk
        chunks.append(steps[len(steps) - WINDOW:])  # always keep the final WINDOW steps
        dataset_list[field_idx].extend(chunks)

    index_2 += episode_num_steps

print("第二阶段结束!")
print("------------------------------------------------------------------------------")

# --------------------------------------------------------------------------
# Stage 3: stack each WINDOW-step chunk into a single tensor per field,
# giving one sequence tensor per chunk.
# --------------------------------------------------------------------------
secquence_list = [[] for _ in FIELD_NAMES]
for window_idx in range(len(dataset_list[NUM_STEPS_IDX])):
    for field_idx in range(len(FIELD_NAMES)):
        secquence_list[field_idx].append(
            tf.stack(dataset_list[field_idx][window_idx], axis=0))

print("第三阶段结束!")
print("------------------------------------------------------------------------------")

# --------------------------------------------------------------------------
# Stage 4: stack all sequences into one batch tensor per field, assemble the
# feature dictionary, and save the resulting dataset.
# --------------------------------------------------------------------------
batch_list = [tf.stack(sequences, axis=0) for sequences in secquence_list]

data = {name: tensor for name, tensor in zip(FIELD_NAMES, batch_list)}

dataset = tf.data.Dataset.from_tensor_slices(data)
# dataset.save(datasets_can_use_path)  # training-set output (enable as needed)
dataset.save(datasets_val_path)        # validation-set output
最新发布
05-31
def preprocess_train(example_batch, static_data_batch, labels_batch):
    """Convert one sample's image, static features, and label to float32 tensors.

    The image is first split into patches and concatenated by the project
    helper `split_and_concat` (defined elsewhere).
    """
    # Split the RGB image into patches and concatenate them.
    pixel_values = split_and_concat(example_batch["image"].convert("RGB"))
    pixel_values = tf.convert_to_tensor(pixel_values, dtype=tf.float32)
    static_data_batch = tf.convert_to_tensor(static_data_batch, dtype=tf.float32)
    labels_batch = tf.convert_to_tensor(labels_batch, dtype=tf.float32)
    return pixel_values, static_data_batch, labels_batch


def create_dataset(dataset, labels, batch_size):
    """Build a batched tf.data.Dataset of ((pixel_values, static_data), labels).

    `dataset` is a (image_data, static_data) pair of parallel sequences;
    `labels` is the matching label sequence. Samples are preprocessed one at a
    time and emitted in batches of `batch_size`.
    """
    image_data, static_data = dataset

    def generator():
        pixel_values_list = []
        static_data_list = []
        labels_list = []
        for example_batch, static_data_batch, labels_batch in zip(
                image_data, static_data, labels):
            pixel_values, static_data_batch, labels_batch = preprocess_train(
                example_batch, static_data_batch, labels_batch)
            # Collapse the trailing size-1 label axis to a scalar per sample.
            labels_batch = tf.squeeze(labels_batch, axis=-1)

            pixel_values_list.append(pixel_values)
            static_data_list.append(static_data_batch)
            labels_list.append(labels_batch)

            # Emit a full batch once batch_size samples have accumulated.
            if len(pixel_values_list) == batch_size:
                pixel_values_batch = tf.stack(pixel_values_list, axis=0)   # (batch_size, 64, 64, 102)
                static_data_batch = tf.stack(static_data_list, axis=0)     # (batch_size, static_data.shape[1])
                labels_batch = tf.stack(labels_list, axis=0)               # (batch_size,)
                pixel_values_list = []
                static_data_list = []
                labels_list = []
                yield (pixel_values_batch, static_data_batch), labels_batch
        # NOTE(review): any trailing samples that do not fill a complete batch
        # are silently dropped here (drop_remainder semantics) — confirm this
        # is intended.

    dataset = tf.data.Dataset.from_generator(
        generator,
        output_signature=(
            (
                tf.TensorSpec(shape=(None, 64, 64, 102), dtype=tf.float32),   # image patches
                tf.TensorSpec(shape=(None, static_data.shape[1]), dtype=tf.float32),  # static features
            ),
            tf.TensorSpec(shape=(None,), dtype=tf.float32),                   # labels
        ),
    )
    dataset = dataset.prefetch(tf.data.AUTOTUNE)  # overlap preprocessing with training
    return dataset

对于一个二分类的图像数据集,且数据集很大并且正负极不均衡,上面代码是否需要进行修改
03-08
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值