torchvision transforms自动clip图像数值范围坑

本文探讨了如何使用PyTorch和OpenCV读取.exr格式的深度图,并解决将其转换为PIL.Image和torch.Tensor时的值范围问题。作者分享了针对深度图特性的最佳实践,确保在后续处理和模型预测过程中数值范围不变。
部署运行你感兴趣的模型镜像

用 [1] 在自己数据集跑实验,需要读 .exr 后缀的 depth map(参照 [2] 配置好 opencv),期望是保持 depth map 原始的值不变。由于 torchvison.transforms 接受 PIL.Image 或 torch.Tensor 做输入,而 [1] 原本的代码是用 PIL.Image 读的,于是一开头用 cv2 读完之后转成 PIL.Image 套用原代码。但发现这样操作,depth map 的范围会自动被其 transforms clip 成 [0, 255],导致训练不了。

对比:

import cv2
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
import traceback

os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

depth = cv2.imread("027.jpg.geometric.exr", cv2.IMREAD_UNCHANGED)
# 原始范围
print("original:", depth.shape, depth.min(), depth.max(), depth.dtype)
# 转 PIL.Image -> 值没变
depth_pil = Image.fromarray(depth)
print("PIL:", depth_pil.size, np.min(depth_pil), np.max(depth_pil))#, depth_pil.dtype)
# 转 torch.Tensor -> 值没变
depth_t = torch.from_numpy(depth)
print("torch:", depth_t.size(), depth_t.min(), depth_t.max(), depth_t.dtype)

# 情况 1:对 PIL.Image 用,会(先 clip 到 [0, 255] 再)归一化到 [0, 1]
trfm1 = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # 会 clip
    transforms.ToTensor()  # 会归一化到 [0, 1]
])
dmap1 = trfm1(depth_pil)
print("dmap1:", dmap1.min(), dmap1.max())

# 情况 2:对 torch.Tensor 用,报错
try:
    dmap2 = trfm1(depth_t.unsqueeze(0).tile((3, 1, 1)))
    print("dmap2:", dmap2.min(), dmap2.max())
except:
    print('\n' + '<' * 7)
    print(traceback.format_exc(), end="")
    print('>' * 7 + '\n')

# 情况 3:对 PIL.Image 用,去掉 `ToTensor`,只 clip 没归一化
trfm2 = transforms.Grayscale(num_output_channels=1)
dmap3 = trfm2(depth_pil)
print("dmap3:", np.min(dmap3), np.max(dmap3))

# 情况 4:对 torch.Tensor 用(要变换成 [3, H, W] 先),保持原值
dmap4 = trfm2(depth_t.unsqueeze(0).tile((3, 1, 1)))
print("dmap4:", dmap4.min(), dmap4.max())
  • 输出
original: (1276, 717) 0.0 5640.0 float32
PIL: (717, 1276) 0.0 5640.0
torch: torch.Size([1276, 717]) tensor(0.) tensor(5640.) torch.float32
dmap1: tensor(0.) tensor(1.)

<<<<<<<
Traceback (most recent call last):
  File "test.py", line 69, in <module>
    dmap2 = trfm1(depth_t.unsqueeze(0).tile((3, 1, 1)))
  File "/home/tyliang/miniconda3/envs/pt110/lib/python3.8/site-packages/torchvision/transforms/transforms.py", line 61, in __call__
    img = t(img)
  File "/home/tyliang/miniconda3/envs/pt110/lib/python3.8/site-packages/torchvision/transforms/transforms.py", line 98, in __call__
    return F.to_tensor(pic)
  File "/home/tyliang/miniconda3/envs/pt110/lib/python3.8/site-packages/torchvision/transforms/functional.py", line 114, in to_tensor
    raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>
>>>>>>>

dmap3: 0 255
dmap4: tensor(0.) tensor(5639.4360)

所以如果是 depth map 这种非 RGB 图像(数值范围不是 [0, 255]),但当成图像读、transform 时,考虑转成 torch.Tensor 再传进 torchvision transforms,而且要验证在 data loader 一套操作之后、传给模型 inference 之前,数值范围对不对。

References

  1. antocad/FocusOnDepth
  2. Open EXR files, how to enable? #21928

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

#作业2-2 import torch import torchvision import numpy as np import matplotlib.pyplot as plt from sklearn.cluster import KMeans import torch import torchvision import numpy as np import matplotlib.pyplot as plt from sklearn.cluster import KMeans # === 添加中文支持 === plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans', 'Arial Unicode MS'] plt.rcParams['axes.unicode_minus'] = False # ================== transform = torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), ]) cifar10_train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) dataloader = torch.utils.data.DataLoader(cifar10_train, batch_size=50, shuffle=False) images, labels = next(iter(dataloader)) flattened_images = images.view(50, -1).numpy() flattened_images = flattened_images / 255.0 def plot_cluster_centers(k): kmeans = KMeans(n_clusters=k, random_state=42) kmeans.fit(flattened_images) centers = kmeans.cluster_centers_.reshape(k, 3, 32, 32) centers_img = np.transpose(centers, (0, 2, 3, 1)) plt.figure(figsize=(10, 4)) for i in range(k): plt.subplot(2, (k+1)//2, i+1) plt.imshow(np.clip(centers_img[i], 0, 1)) plt.title(f'聚类中心 {i}') plt.axis('off') plt.suptitle(f'CIFAR10 聚类中心可视化(K={k})') # 现在可以正常显示中文 plt.tight_layout() plt.show() plot_cluster_centers(5) plot_cluster_centers(10) # 1. 加载 CIFAR10 前 50 张图像(仅训练集前 50) transform = torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), ]) cifar10_train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) dataloader = torch.utils.data.DataLoader(cifar10_train, batch_size=50, shuffle=False) images, labels = next(iter(dataloader)) # 取出第一批(50张) # 2. 展平 + 归一化 flattened_images = images.view(50, -1).numpy() # (50, 3072) flattened_images = flattened_images / 255.0 # 归一化到 [0,1] # 3. 分别使用 K=5 和 K=10 进行聚类 def plot_cluster_centers(k): kmeans = KMeans(n_clusters=k, random_state=42) kmeans.fit(flattened_images) centers = kmeans.cluster_centers_ # (k, 3072) # 恢复为 (k, 3, 32, 32),再转为 (k, 32, 32, 3) 可视化 centers_img = centers.reshape(k, 3, 32, 32) centers_img = np.transpose(centers_img, (0, 2, 3, 1)) # CHW -> HWC # 显示聚类中心图像 plt.figure(figsize=(10, 4)) for i in range(k): plt.subplot(2, (k+1)//2, i+1) plt.imshow(np.clip(centers_img[i], 0, 1)) # clip 处理可能超出范围的值 plt.title(f'聚类中心 {i}') plt.axis('off') plt.suptitle(f'CIFAR10 聚类中心可视化(K={k})') plt.tight_layout() plt.show() # 显示 K=5 和 K=10 的结果 plot_cluster_centers(5) plot_cluster_centers(10) ,根据上面给出的图修改此代码为什么是这样的
最新发布
12-24
评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值