Torch数据加载维度不一致报错

本文详细阐述了如何处理RuntimeError关于数据维度不一致的问题,通过数据预处理、检查并过滤不符合要求的样本,以及使用自定义collate_fn函数,确保PyTorch DataLoader的稳定运行。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1、ERROR

RuntimeError: stack expects each tensor to be equal size, but got [238, 128] at entry 0 and [55, 128] at entry 128。

2、解决方法

  1. 方案一: 数据预处理时,最好排除特征数据维度不一致的样本;
  2. 方案二: 如果实在无法排除,又在数据加载时报错,可以尝试:
    1) 首先在get_item函数构建中判断数据维度,不符合目标维度的数据返回None类型。
         if X.shape == (238,128):
             return X,y
         else:
             return None,None
  1. 然后在数据加载的时候过滤掉None的样本
    注意: 需要import 默认的拼接方式
from torch.utils.data.dataloader import default_collate # 导入默认的拼接方式
def my_collate_fn(batch):
    '''
    batch中每个元素形如(data, label)
    '''
    # 过滤为None的数据
    batch = list(filter(lambda x:x[0] is not None, batch))
    if len(batch) == 0: return torch.Tensor()
    return default_collate(batch) # 用默认方式拼接过滤后的batch数据        

3)数据加载

data_loader = DataLoader(dataset=dataset,
                         batch_size=batch_size,
                         collate_fn=my_collate_fn,
                         sampler=sampler)

以上三步,即可顾虑掉维度不符合要求的样本。

3、参考文献

“pytorch 函数DataLoader“: https://blog.youkuaiyun.com/TH_NUM/article/details/80877687

<think>好的,我需要解决用户的问题,他们遇到了一个ValueError,错误提示是“not enough values to unpack (expected 3, got 2)”。根据之前的对话,问题出在数据维度符合预期,代码期望三维数据,但实际得到的是二维数据。现在用户询问如何输出数据维度的代码。 首先,我需要确认用户需要的是如何在Python中打印数据的形状。常见的情况是使用NumPy数组或Pandas DataFrame,所以应该以这两种数据结构为例。用户可能在处理时间序列数据,例如使用GAN生成数据,或者是其他机器学习任务,所以需要展示如何检查数据维度。 接下来,要考虑到用户可能对Python数据结构太熟悉,所以需要给出简单直接的代码示例。比如使用shape属性,这在NumPy和Pandas中都是标准的。同时,可能需要解释这些代码的作用,确保用户理解每个步骤。 然后,可能需要补充一些调试建议,比如在关键位置打印数据形状,帮助用户定位问题所在。例如在调用visualize函数之前打印real_dataset和generated_data的形状,或者在visualization_metrics.py文件中检查ori_data的形状。 另外,用户可能遇到的情况包括数据预处理时的错误reshape,或者生成器输出没有正确调整维度。因此,示例代码需要展示如何调整数据维度,比如使用reshape方法将二维数据转换为三维。 最后,要确保回答结构清晰,分步骤说明,并给出示例,这样用户能够逐步排查问题。同时,语言要简洁,避免使用过于专业的术语,让同水平的用户都能理解。</think>以下是用于检查数据维度Python代码示例,帮助定位类似`ValueError`的维度匹配问题: --- ### **1. 基础方法:打印数据维度** ```python # 检查任意Python变量(NumPy数组、Tensor、Pandas DataFrame等) print("原始数据维度:", real_dataset.shape) # 输出例如 (100, 10, 5) print("生成数据维度:", generated_data.shape) # 输出例如 (50, 5) → 发现维度匹配! ``` --- ### **2. 针对NumPy数组** ```python import numpy as np # 示例数据 real_data = np.random.rand(100, 10, 5) # 三维:(样本数, 时间步长, 特征) fake_data = np.random.rand(100, 5) # 二维 → 会引发错误 # 检查维度 print("真实数据维度:", real_data.shape) # 输出 (100, 10, 5) print("生成数据维度:", fake_data.shape) # 输出 (100, 5) → 问题根源! ``` --- ### **3. 针对PyTorch张量** ```python import torch tensor_data = torch.randn(32, 10, 3) # (batch_size, seq_len, features) print("张量维度:", tensor_data.shape) # 输出 torch.Size([32, 10, 3]) ``` --- ### **4. 针对Pandas DataFrame** ```python import pandas as pd df = pd.DataFrame({"A": [1,2,3], "B": [4,5,6]}) print("DataFrame维度(行, 列):", df.shape) # 输出 (3, 2) ``` --- ### **5. 调试场景示例** 在报错的代码附近添加检查: ```python def visualize(dataset, generator): # 添加检查点 print("[Debug] 输入数据维度:", dataset.shape) # 例如期望 (N, seq_len, features) generated_data = [] for _ in range(samples): temp = generator.predict(...) generated_data.append(temp) generated_data = np.array(generated_data) print("[Debug] 生成数据维度:", generated_data.shape) # 例如实际是 (N, features) visualization(dataset, generated_data, "tsne") # 此处可能触发错误 ``` --- ### **6. 维度调整示例** 若发现数据维度错误(如二维需转三维): ```python # 假设原始数据是二维 (N, features),需添加时间步长维度(例如时间步长=1) adjusted_data = real_dataset.reshape(real_dataset.shape[0], 1, real_dataset.shape[1]) print("调整后维度:", adjusted_data.shape) # 输出 (N, 1, features) ``` --- ### **关键点总结** 1. **何时检查**:在数据加载、预处理、生成器输出后、函数调用前等关键节点。 2. **期望维度**: - **时间序列**:`(样本数, 时间步长, 特征数)` - **图像数据**:`(样本数, 高度, 宽度, 通道数)` - **表格数据**:`(样本数, 特征数)` 3. **修复方法**:通过`reshape`、`np.expand_dims`或模型结构调整维度
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值