PyTorch Learning (13): Using *args and **kwargs in Python

1. How *args and **kwargs work in Python

*args and **kwargs are commonly used in function definitions to accept a variable number of arguments. *args collects any number of positional arguments into a tuple, while **kwargs collects keyword arguments into a dict. When both are used together, *args must come before **kwargs in the parameter list.
For example:

# Test *args and **kwargs
def func(*args, **kwargs):
    print("args = ", args)
    print("kwargs = ", kwargs)
    print("*" * 20)
    
if __name__ == "__main__":
    func(1, 2, 3, 4, 5, 6)
    func(a=1, b=2, c=3, d=4)
    func(1, 2, 3, a=4, b=5, c=6)

# Output:
args =  (1, 2, 3, 4, 5, 6)
kwargs =  {}
********************
args =  ()
kwargs =  {'a': 1, 'b': 2, 'c': 3, 'd': 4}
********************
args =  (1, 2, 3)
kwargs =  {'a': 4, 'b': 5, 'c': 6}
********************

Analysis: in func(1, 2, 3, 4, 5, 6) all six arguments are positional, so they all land in *args. In func(a=1, b=2, c=3, d=4) every argument is a keyword argument, so they all land in **kwargs. In func(1, 2, 3, a=4, b=5, c=6), the arguments 1, 2, 3 are positional and a=4, b=5, c=6 are keyword arguments.
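The same * and ** syntax also works in the opposite direction at the call site: it unpacks a sequence into positional arguments and a dict into keyword arguments. A small sketch:

```python
def func(*args, **kwargs):
    print("args = ", args)
    print("kwargs = ", kwargs)

# * unpacks the tuple into positional arguments,
# ** unpacks the dict into keyword arguments
positional = (1, 2, 3)
keyword = {'a': 4, 'b': 5}
func(*positional, **keyword)
# args =  (1, 2, 3)
# kwargs =  {'a': 4, 'b': 5}
```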

2. When *args and **kwargs are useful

I first paid attention to **kwargs while using the AlexNet model that ships with torchvision.models. In def alexnet(pretrained=False, progress=True, **kwargs):, the **kwargs are simply forwarded to the AlexNet class, so to change the number of output classes you pass num_classes=X through the kwargs. The torchvision source for alexnet below makes this clear:

import torch
import torch.nn as nn
from .utils import load_state_dict_from_url


__all__ = ['AlexNet', 'alexnet']


model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}


class AlexNet(nn.Module):

    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


def alexnet(pretrained=False, progress=True, **kwargs):
    r"""AlexNet model architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = AlexNet(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['alexnet'],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model
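The forwarding pattern above can be distilled into a minimal, torchvision-free sketch; the names SimpleNet and build_net are hypothetical, chosen only to mirror the roles of AlexNet and alexnet():

```python
class SimpleNet:
    """Stands in for AlexNet: a class with a keyword parameter."""
    def __init__(self, num_classes=1000):
        self.num_classes = num_classes

def build_net(pretrained=False, **kwargs):
    # **kwargs is passed through untouched to the constructor,
    # exactly as alexnet() passes it to AlexNet(**kwargs)
    return SimpleNet(**kwargs)

net = build_net(num_classes=10)
print(net.num_classes)  # 10
```

Accordingly, with torchvision installed, calling alexnet(num_classes=10) makes num_classes arrive in AlexNet.__init__ and resize the final nn.Linear layer.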