1、ERROR
RuntimeError: stack expects each tensor to be equal size, but got [238, 128] at entry 0 and [55, 128] at entry 128。
2、解决方法
- 方案一: 数据预处理时,最好排除特征数据维度不一致的样本;
- 方案二: 如果实在无法排除,又在数据加载时报错,可以尝试:
1) 首先在get_item函数构建中判断数据维度,不符合目标维度的数据返回None类型。
if X.shape == (238,128):
return X,y
else:
return None,None
- 然后在数据加载的时候过滤掉None的样本
注意: 需要import 默认的拼接方式
from torch.utils.data.dataloader import default_collate # 导入默认的拼接方式
def my_collate_fn(batch):
'''
batch中每个元素形如(data, label)
'''
# 过滤为None的数据
batch = list(filter(lambda x:x[0] is not None, batch))
if len(batch) == 0: return torch.Tensor()
return default_collate(batch) # 用默认方式拼接过滤后的batch数据
3)数据加载
data_loader = DataLoader(dataset=dataset,
batch_size=batch_size,
collate_fn=my_collate_fn,
sampler=sampler)
以上三步,即可顾虑掉维度不符合要求的样本。
3、参考文献
“pytorch 函数DataLoader“: https://blog.youkuaiyun.com/TH_NUM/article/details/80877687