Pytorch中使用transforms.ToTensor和transforms.ToPILImage的numpy.ndarray与Tensor和PILImage的转换举例

本文详细解析了Pytorch中numpy、tensor与PIL图像之间的数据转换过程,包括numpy到tensor的维度变化、数据类型调整及归一化,numpy到PIL的类型需求满足,以及tensor到PIL的维度调整,帮助读者理解并掌握不同数据类型间的转换技巧。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

作为刚接触Pytorch的新人,在刚开始时候处理三种数据tensor,numpy,PIL有点弄晕了,而且中间的转换过程需要注意类型,仅此记录下来,如果能有幸帮到大家那更好了。

进入正题,我们首先看看numpy to tensor的情况:

import numpy as np
from torchvision import transforms as T

x = np.ones((128, 228, 1))   # 这里numpy维度是HWC
print(x.shape)   # (128, 228, 1)  
x_tensor = T.ToTensor()(x) 
print(x_tensor.size())   # torch.Size([1, 128, 228]),这里tensor维度是CHW
print(torch.from_numpy(x).size())   # torch.Size([128, 228, 1])注意这里维度不变
x_numpy = x_tensor.numpy()   # 将tensor转换为numpy,但是维度也不改变(1, 128, 228)

所以这种情况比较好理解,就是将HWC变成了CHW,当然这个过程还进行了矩阵运算将其值进行了归一化。

接下来看看numpy to PIL的情况:

x = np.ones((128, 228, 1))
print(x.shape)
x_pil = T.ToPILImage()(np.uint8(x))   # 注意这里的np.uint8(),为了使其满足类型需求,在这里是必须的
print(x_pil.size)
x_numpy = np.array(x_pil)
print(x_numpy.shape)

# 输出为:
# (128, 228, 1)
# (228, 128)
# (128, 228)

接下来看看tensor to PIL的情况:

x = torch.ones((1, 128, 228))
print(x.shape)
x = x.float()   # 有时候从numpy到tensor后再转化为PIL会报错,这一步就是使其满足数据类型,在这里这一句可以去掉。
x_pil = T.ToPILImage()(x)
print(x_pil.size)
x_tensor = T.ToTensor()(x_pil)
print(x_tensor.size())

# 输出为:
# torch.Size([1, 128, 228])
# (228, 128)
# torch.Size([1, 128, 228])

可以看到ToPILImage实际上将CHW变成了WH,也去掉了通道这个维度
经过评论区的网友提醒,ToPILImage实际上将CHW变成了CWH,只是因为显示原因没有显示C通道,而原本的size并没有改变,因为恢复回去还是3D的。但是Numpy那里确实是发生了变化,大家可以多多注意下。

可能随着版本的更新,上面这些Case可能都会改变,建议大家自己动手尝试下看看是否与本文中结果一致

<think>好的,我现在需要帮用户解决他们提供的Python代码中的ValueError问题。错误信息是“pic should be 2/3 dimensional. Got 4 dimensions.”。首先,我得仔细看看用户提供的代码,理解其功能,然后找出哪里出了问题。 用户给出的函数是color_hsv,接收一个tensor1参数。函数的第一步是调用tensor1.squeeze(0),这应该是为了去除第0维的维度,比如将形状从(1, C, H, W)变为(C, H, W)。接下来,用transforms.ToPILImage()将张量转换为PIL图像,然后转换成HSV模式,再转回张量返回。 错误发生在transform_to_pil(tensor1)这一步,因为ToPILImage期望的输入是2D(灰度)或3D(RGB)的张量,但这里可能传入了一个4维的张量。用户提到错误信息显示得到了4维的数据,这说明即使调用了squeeze(0),张量可能仍然保留了多余的维度。 首先,我需要检查squeeze(0)是否真的去除了第0维。例如,如果原始张量的形状是(1, 3, 224, 224),squeeze(0)后变成(3, 224, 224),这是3D的,应该没问题。但如果原始张量的形状是(1, 1, 3, 224, 224),那么squeeze(0)只能去掉第一个1,变成(1, 3, 224, 224),仍然是4维的,这时候就会报错。因此,用户可能需要更彻底地压缩所有单维度,比如使用squeeze()不带参数,或者检查输入张量的形状。 接下来,可能的情况是输入的tensor1在调用squeeze(0)之后仍然有4维,说明原来的维度可能不止一个单维度。比如,输入张量可能是(1, 1, C, H, W),squeeze(0)后变成(1, C, H, W),仍然4维。这时候需要进一步压缩,或者用户需要调整squeeze的参数。 另外,可能用户在使用squeeze(0)时,原张量的第0维并不是1,导致squeeze无效。例如,如果原张量是(2, 3, 224, 224),squeeze(0)不会有变化,这时候需要其他处理方式。 可能的解决方案包括: 1. 使用squeeze()而不指定维度,自动去除所有大小为1的维度。 2. 检查输入张量的形状,确保在调用squeeze之后变成3维。 3.转换为PIL图像之前,确保张量的维度正确,比如通过打印形状来调试。 4. 如果输入是批量数据,需要单独处理每个图像,而不是传入一个批次。 用户可能需要确认输入张量tensor1的形状,特别是在调用squeeze(0)后的形状。如果仍有4维,可以尝试多次squeeze或使用view/reshape调整维度。 例如,假设输入张量形状是(1, 3, 224, 224),squeeze(0)后变成(3, 224, 224),正确。但如果输入是(1, 1, 3, 224, 224),squeeze(0)后是(1, 3, 224, 224),仍然4维,此时需要再squeeze(0)或者直接使用squeeze()去掉所有单维度。 因此,修改代码中的squeeze(0)为squeeze()可能解决问题。或者,如果输入张量存在批次维度,可能需要遍历每个图像进行处理,而不是一次性处理整个批次。 总结步骤: 1. 检查输入张量的形状,确定哪些维度需要压缩。 2. 使用squeeze()去除所有单维度,而不仅仅是第0维。 3. 确保转换到PIL前的张量是2D或3D。 4. 添加调试打印语句,查看处理后的张量形状。 另外,用户可能需要考虑是否错误地将批量数据传入该函数。例如,如果tensor1是批量数据,如(B, C, H, W),那么需要逐个处理每个图像,或者确保传入的是单个图像张量。 例如,如果函数color_hsv的输入是一个批量张量,那么应该先遍历每个样本,或者调整函数以处理批量输入。但根据当前代码,函数似乎设计处理单个图像,因此需要确保输入张量的维度正确。 可能的修改是将tensor1 = tensor1.squeeze(0)改为tensor1 = tensor1.squeeze(),或者多次squeeze,或者在转换前确保维度正确。例如: tensor1 = tensor1.squeeze() # 去除所有大小为1的维度 或者在调用squeeze之后,检查其维度,如果仍然是4维,可能需要进一步处理,例如: if tensor1.dim() == 4: tensor1 = tensor1.squeeze(0) 但具体处理方式需要根据输入形状决定。 总之,核心问题在于输入ToPILImage的张量维度不正确,需要调整为2或3维。解决方案是确保在squeeze之后张量的维度正确,可能需要更彻底的squeeze或者调整输入结构。</think>让我们逐步分析并解决这个报错: --- ### **错误原因分析** 报错`ValueError: pic should be 2/3 dimensional. Got 4 dimensions.`表明`ToPILImage()`接收到了4维张量,但该转换器只接受2D(灰度图)或3D(RGB图)张量。 --- ### **关键代码检查** 观察代码中的维度压缩操作: ```python tensor1 = tensor1.squeeze(0) # 试图删除第0维 ``` **可能问题场景**: 1. 原始输入张量形如 `[1, 1, C, H, W]`(例如从视频帧或多批次数据中提取) 2. 单次`squeeze(0)`只能删除第0维,得到 `[1, C, H, W]`(仍为4维) 3. 此时传递给`ToPILImage()`的仍是4维张量 --- ### **解决方案** #### 方法一:彻底压缩所有单维度 将`squeeze(0)`改为**无参数压缩**,删除所有大小为1的维度: ```python tensor1 = tensor1.squeeze() # 删除所有size=1的维度 ``` #### 方法二:显式指定目标维度 若已知需要3D张量(通道优先格式`[C, H, W]`),可强制降维: ```python tensor1 = tensor1[0] # 当第0维是冗余批次维度时 # 或 tensor1 = tensor1.view(-1, *tensor1.shape[-2:]) # 保持最后两维不变 ``` #### 方法三:添加维度验证 ```python tensor1 = tensor1.squeeze(0) assert tensor1.dim() in [2, 3], f"Invalid dimension {tensor1.dim()}, expect 2 or 3" ``` --- ### **改进后代码** ```python def color_hsv(tensor1): # 彻底压缩所有单维度 tensor1 = tensor1.squeeze() # 维度合法性检查 if tensor1.dim() not in [2, 3]: raise ValueError(f"输入张量需为2D或3D,当前维度:{tensor1.dim()}") transform_to_pil = transforms.ToPILImage() pil_image = transform_to_pil(tensor1) hsv_image = pil_image.convert('HSV') transform_to_tensor = transforms.ToTensor() tensor_hsv = transform_to_tensor(hsv_image) return tensor_hsv ``` --- ### **调试建议** 1. **打印输入张量形状**: ```python print("原始输入形状:", tensor1.shape) tensor1 = tensor1.squeeze() print("压缩后形状:", tensor1.shape) ``` 2. **典型输入验证**: - 合法输入:`[C, H, W]`(3D)或 `[H, W]`(2D) - 非法输入:`[B, C, H, W]`(需先拆分批次) --- ### **数学表达验证** 假设原始张量维度为: $$ \text{shape}(tensor1) = (1, 1, 3, 224, 224) $$ 经过`squeeze(0)`后: $$ \text{shape} \rightarrow (1, 3, 224, 224) \quad (\text{仍为4D}) $$ 改用`squeeze()`后: $$ \text{shape} \rightarrow (3, 224, 224) \quad (\text{合法3D}) $$ --- ### **总结** 通过彻底压缩冗余维度并添加维度校验,可确保输入`ToPILImage()`的张量符合2D/3D要求。建议在处理张量时始终伴随形状打印语句以快速定位维度问题。
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值