【debug锦集】Expected input batch_size (1) to match target batch_size (0)

我是雪天鱼,一名FPGA技术爱好者,B站账号:https://space.bilibili.com/397002941?spm_id_from=333.1007.0.0

(1)由于输入到 loss函数的label不满足要求格式导致

loss = criterion(output, label)

这里的label是tensor类型,元素个数要与种类个数对应,如四分类,则 label=[1,0,0,0]、[0,1,0,0]、[0,0,1,0]、[0,0,0,1] 这样 ,即与预测的输出是一个样子,size和dtype都要一致:
预测输出:
在这里插入图片描述
报错details:
在这里插入图片描述
解决方法:
在这里插入图片描述

手动调整下label即可

### 解决 PyTorch 中输入批次大小与目标批次大小不匹配的问题 在处理序列到序列(Seq2Seq)模型时,遇到输入批次大小与目标批次大小不一致的情况是比较常见的问题。这种错误通常发生在训练过程中,特别是在定义损失函数计算时。 为了确保输入和目标具有相同的批次大小,在构建数据加载器时应特别注意如何打包和填充序列。一种常见方法是在创建 `DataLoader` 实例之前先通过自定义 collate 函数来调整每批中的样本数量,使得它们保持一致[^1]。 另外,如果使用的是动态解码方式,则需要注意编码器输出的最后一维代表时间步数而不是批量维度;因此,在某些情况下可能需要转置张量或将额外的时间步设为零填充以保证形状兼容性[^2]。 下面给出一段简单的代码片段用于展示如何设置 DataLoader 和 Collate_fn 来防止此类错误: ```python from torch.utils.data import Dataset, DataLoader class SeqDataset(Dataset): def __init__(self, src_data, tgt_data): self.src_data = src_data self.tgt_data = tgt_data def __len__(self): return len(self.src_data) def __getitem__(self, idx): source_seq = self.src_data[idx] target_seq = self.tgt_data[idx] # 假定已经进行了必要的预处理如tokenization等 sample = {"source": source_seq, "target": target_seq} return sample def my_collate(batch_list): """ 自定义collate函数用来处理不同长度的序列, 并确保每个batch内的inputtarget有相同的第一维度(batch size). """ max_len_input = max([len(item['source']) for item in batch_list]) padded_inputs = [] max_len_output = max([len(item['target']) for item in batch_list]) padded_outputs = [] for b in batch_list: pad_width_in = (0,max_len_input-len(b["source"])) pad_width_out = (0,max_len_output-len(b["target"])) p_input = F.pad(torch.tensor(b["source"]),pad_width_in,value=PAD_TOKEN_ID).unsqueeze(0) p_target = F.pad(torch.tensor(b["target"]),pad_width_out,value=PAD_TOKEN_ID).unsqueeze(0) padded_inputs.append(p_input) padded_outputs.append(p_target) inputs_tensor = torch.cat(padded_inputs,dim=0) outputs_tensor = torch.cat(padded_outputs,dim=0) return {'inputs': inputs_tensor,'targets':outputs_tensor} dataset = SeqDataset(src_sequences,tgt_sequences) dataloader = DataLoader(dataset,batch_size=BATCH_SIZE,collate_fn=my_collate) ```
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

雪天鱼

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值