下面的代码打印所选 torchvision 模型(默认 resnet50)中每个参数的名称和形状:
import torch.nn as nn
import torchvision.models as models
import argparse
# Module-level parser. The '--arch' option is registered in the script entry
# point below, once the list of selectable architecture names is known.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

if __name__ == '__main__':
    # Selectable names: every lowercase, non-dunder, callable attribute
    # exported by torchvision.models.
    # NOTE(review): on newer torchvision releases this filter presumably also
    # matches helper functions (e.g. `get_model`, `list_models`) that are not
    # zero-argument model constructors -- confirm against the installed version.
    model_names = sorted(
        name for name, obj in models.__dict__.items()
        if name.islower() and not name.startswith("__") and callable(obj)
    )

    parser.add_argument(
        '-a', '--arch', metavar='ARCH', default='resnet50',
        choices=model_names,
        help='model architecture: ' +
             ' | '.join(model_names) +
             ' (default: resnet50)',
    )
    args = parser.parse_args()

    # Build the selected architecture by calling its constructor with no
    # arguments.
    model = models.__dict__[args.arch]()

    # Print each parameter's qualified name together with its tensor size.
    for name, param in model.named_parameters():
        print(name, param.size())
resnet50 的参数名称和形状输出如下: