代码写的特别好,很有借鉴意义。链接如下:
[github]: https://github.com/yjxiong/tsn-pytorch
1. @property
class VideoRecord(object):
def __init__(self, row):
self._data = row
@property #@prproperty的用法 将一个类方法转变成一个类属性
#试图将该属性设为其他值,我们会引发一个AttributeError错误
def path(self):
return self._data[0]
@property
def num_frames(self):
return int(self._data[1])
@property
def label(self):
return int(self._data[2])
2. np.multiply
if average_duration > 0: # np.multiply 的使用 np.multiply([0, 1, 2], x) = [0, x, 2x]
offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration, size=self.num_segments)
3. topk
# 计算top1 top5正确率
# topk=(1,5)
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
4. getattr & setattr
if 'resnet' in base_model or 'vgg' in base_model:
#getattr() 函数用于返回一个对象属性值。
#torchvision.model pytorch 自带的复现模型
#这句话相当于 models.resnet101(pretrained=True)
self.base_model = getattr(torchvision.models, base_model)(True)
5. Freeze BN layers
if self._enable_pbn:
print("Freezing BatchNorm2D except the first one.")
for m in self.base_model.modules():
if isinstance(m, nn.BatchNorm2d):
count += 1
if count >= (2 if self._enable_pbn else 1):
m.eval()
# shutdown update in frozen mode
m.weight.requires_grad = False
m.bias.requires_grad = False

本文深入探讨了使用PyTorch进行视频理解的技术细节,包括@property装饰器的应用、numpy的multiply函数、topk函数的使用场景,以及如何利用getattr和setattr进行模型加载。此外,还介绍了如何冻结BN层,对理解视频数据集和模型训练提供了实用指导。
4378

被折叠的 条评论
为什么被折叠?



