1. Usage of *args and **kwargs in Python
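*args and **kwargs are often used in function definitions to accept a variable number of arguments. *args collects any number of unnamed (positional) arguments into a tuple, and **kwargs collects any number of keyword arguments into a dict. The names args and kwargs are only a convention; what matters is the * and ** prefix.
When *args and **kwargs are used together, *args must come before **kwargs in the parameter list.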
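Both claims can be checked directly. In this small sketch (the helper name inspect_params is hypothetical), the function reports the types it receives, and compile() confirms that declaring **kwargs before *args is rejected when the code is parsed:

```python
# Hypothetical helper: report the types Python uses for *args and **kwargs.
def inspect_params(*args, **kwargs):
    # args collects extra positional arguments, kwargs collects extra keyword arguments
    return type(args).__name__, type(kwargs).__name__


print(inspect_params(1, 2, a=3))  # ('tuple', 'dict')

# Putting **kwargs before *args is a syntax error, caught by the parser itself.
try:
    compile("def bad(**kwargs, *args): pass", "<example>", "exec")
except SyntaxError as exc:
    print("SyntaxError:", exc.msg)
```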
The following test function prints what it receives for three different kinds of calls:
```python
# Test *args and **kwargs
def func(*args, **kwargs):
    print("args =", args)
    print("kwargs =", kwargs)
    print("*" * 20)


if __name__ == "__main__":
    func(1, 2, 3, 4, 5, 6)
    func(a=1, b=2, c=3, d=4)
    func(1, 2, 3, a=4, b=5, c=6)
```

Output:

```
args = (1, 2, 3, 4, 5, 6)
kwargs = {}
********************
args = ()
kwargs = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
********************
args = (1, 2, 3)
kwargs = {'a': 4, 'b': 5, 'c': 6}
********************
```
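Analysis: in func(1, 2, 3, 4, 5, 6), all six arguments are unnamed (positional), so they are collected into args. In func(a=1, b=2, c=3, d=4), every argument is a keyword argument, so they are collected into kwargs. In func(1, 2, 3, a=4, b=5, c=6), the unnamed arguments 1, 2, 3 go into args and the keyword arguments a=4, b=5, c=6 go into kwargs.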
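When ordinary named parameters are combined with *args and **kwargs, Python fills the named parameters first, and only the leftover arguments end up in args and kwargs. A small sketch (describe is a hypothetical helper, not from the original post):

```python
# Hypothetical helper: x is a normal parameter, the rest are variable-length.
def describe(x, *args, **kwargs):
    # x binds the first positional argument (or x=...);
    # remaining positionals go to args, remaining keywords go to kwargs
    return x, args, kwargs


print(describe(1, 2, 3, a=4))  # (1, (2, 3), {'a': 4})
print(describe(x=1, a=2))      # (1, (), {'a': 2})
```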
2. Use cases for *args and **kwargs
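I first noticed **kwargs while using the AlexNet model that ships with torchvision.models. In def alexnet(pretrained=False, progress=True, **kwargs):, the **kwargs parameter is simply passed on to the AlexNet class, so to change the number of output classes you pass num_classes=X in its place, e.g. alexnet(num_classes=X). The alexnet source code below makes this clear: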
```python
import torch
import torch.nn as nn
# torchvision's own file uses `from .utils import load_state_dict_from_url`,
# which only works inside the package; torch.hub provides the same function.
from torch.hub import load_state_dict_from_url

__all__ = ['AlexNet', 'alexnet']

model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        # convolutional feature extractor
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        # fully connected classifier; num_classes sets the size of the last layer
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


def alexnet(pretrained=False, progress=True, **kwargs):
    r"""AlexNet model architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # every extra keyword argument (e.g. num_classes=10) is forwarded to AlexNet
    model = AlexNet(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['alexnet'],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model
```
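The same forwarding pattern can be reproduced without PyTorch. In the sketch below, TinyNet and build_tinynet are hypothetical stand-ins for AlexNet and alexnet that only record num_classes. It also shows that keyword options can be kept in a dict and unpacked with ** at the call site, and that because the wrapper passes kwargs on without checking them, a misspelled name such as num_class only fails once it reaches the class constructor. With torchvision installed, the corresponding real call would be alexnet(num_classes=10).

```python
# Hypothetical stand-in for AlexNet: it only stores the value it would
# use to size its last Linear layer.
class TinyNet:
    def __init__(self, num_classes=1000):
        self.num_classes = num_classes


# Hypothetical builder mirroring alexnet(): its own options are named
# parameters, and every other keyword argument is forwarded unchanged.
def build_tinynet(pretrained=False, progress=True, **kwargs):
    if pretrained:
        # weight downloading/loading is deliberately left out of this sketch
        raise NotImplementedError("pretrained weights are not part of this sketch")
    return TinyNet(**kwargs)


print(build_tinynet().num_classes)                # 1000 (TinyNet's default)
print(build_tinynet(num_classes=10).num_classes)  # 10

# Keyword options can also live in a dict and be unpacked with ** at the call site.
config = {"num_classes": 5}
print(build_tinynet(**config).num_classes)        # 5

# build_tinynet does not validate kwargs, so a typo is only caught when
# TinyNet.__init__ receives an argument it does not declare.
try:
    build_tinynet(num_class=10)
except TypeError as exc:
    print("TypeError:", exc)
```

Forwarding **kwargs keeps the wrapper's signature short and lets the class's own defaults apply, at the cost of deferring argument checking to the class.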