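While reading the FCN code, I found that a whole set of transforms has to be rewritten, because each image and its segmentation label (mask) must be processed together.
I followed 霹导's code: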
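import random

import torchvision.transforms as T
import torchvision.transforms.functional as F


class RandomResize(object):
    def __init__(self, min_size, max_size=None):
        self.min_size = min_size
        if max_size is None:
            max_size = min_size
        self.max_size = max_size

    def __call__(self, image, target):
        # Draw one size per call and apply it to both image and target
        size = random.randint(self.min_size, self.max_size)
        # With an int size, F.resize scales the shorter side to `size` and keeps the aspect ratio
        image = F.resize(image, size)
        # NEAREST keeps the target's class ids from being blended together
        target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST)
        return image, target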
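After reading through the original code once, I tried writing it myself. Along the way I used transforms.Resize(), but its only required argument is size, fixed when the object is created, and its call takes a single image, which did not fit a size drawn at random and applied to both image and target. So I went back to the original code and saw that it uses transforms.functional instead. Opening the source of transforms.Resize() showed that its forward method also just calls resize from functional.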
# Source of transforms.Resize() (excerpt from torchvision)
class Resize(torch.nn.Module):
    def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=None):
        super().__init__()
        _log_api_usage_once(self)
        if not isinstance(size, (int, Sequence)):
            raise TypeError(f"Size should be int or sequence. Got {type(size)}")
        if isinstance(size, Sequence) and len(size) not in (1, 2):
            raise ValueError("If size is a sequence, it should have 1 or 2 values")
        self.size = size
        self.max_size = max_size
        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
                "Please use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)
        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img):
        return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
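Conclusion: when assembling a pipeline, for example with Compose, use the classes in transforms; when you need to write your own transform, call the functions in transforms.functional, which is where transforms actually implements each operation.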