模型相关的好用函数

使用模型库来定义模型

torchvision.models.get_model()

官方文档解释:给定 被注册过的模型名称 和 指定的模型参数 ,返回一个实例化的模型

        torchvision.models.get_model(name: str**config: Any)

eg:

model_teacher = torchvision.models.get_model("resnet18", num_classes=100)

返回一个分类的类别数为100的resnet18模型

torchvision.models.__dict__[]()

在[]内填入模型名称,在()内填入指定模型的参数,如下所示:

model_verifier = models.__dict__['mobilenet_v2'](num_classes=100)

 如何知晓支持的模型呢?

——我们可以通过以下代码来输出模型名称,从中选择即可

import torchvision.models as models
print(models.__dict__.keys())

#运行结果
#dict_keys(['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__path__', '__file__', '__cached__', '__builtins__', '_api', '_meta', '_utils', 'alexnet', 'AlexNet', 'AlexNet_Weights', 'convnext', 'ConvNeXt', 'ConvNeXt_Tiny_Weights', 'ConvNeXt_Small_Weights', 'ConvNeXt_Base_Weights', 'ConvNeXt_Large_Weights', 'convnext_tiny', 'convnext_small', 'convnext_base', 'convnext_large', 'densenet', 'DenseNet', 'DenseNet121_Weights', 'DenseNet161_Weights', 'DenseNet169_Weights', 'DenseNet201_Weights', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'efficientnet', 'EfficientNet', 'EfficientNet_B0_Weights', 'EfficientNet_B1_Weights', 'EfficientNet_B2_Weights', 'EfficientNet_B3_Weights', 'EfficientNet_B4_Weights', 'EfficientNet_B5_Weights', 'EfficientNet_B6_Weights', 'EfficientNet_B7_Weights', 'EfficientNet_V2_S_Weights', 'EfficientNet_V2_M_Weights', 'EfficientNet_V2_L_Weights', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_v2_s', 'efficientnet_v2_m', 'efficientnet_v2_l', 'googlenet', 'GoogLeNet', 'GoogLeNetOutputs', '_GoogLeNetOutputs', 'GoogLeNet_Weights', 'inception', 'Inception3', 'InceptionOutputs', '_InceptionOutputs', 'Inception_V3_Weights', 'inception_v3', 'mnasnet', 'MNASNet', 'MNASNet0_5_Weights', 'MNASNet0_75_Weights', 'MNASNet1_0_Weights', 'MNASNet1_3_Weights', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3', 'mobilenetv2', 'mobilenetv3', 'mobilenet', 'MobileNetV2', 'MobileNet_V2_Weights', 'mobilenet_v2', 'MobileNetV3', 'MobileNet_V3_Large_Weights', 'MobileNet_V3_Small_Weights', 'mobilenet_v3_large', 'mobilenet_v3_small', 'regnet', 'RegNet', 'RegNet_Y_400MF_Weights', 'RegNet_Y_800MF_Weights', 'RegNet_Y_1_6GF_Weights', 'RegNet_Y_3_2GF_Weights', 'RegNet_Y_8GF_Weights', 'RegNet_Y_16GF_Weights', 'RegNet_Y_32GF_Weights', 'RegNet_Y_128GF_Weights', 'RegNet_X_400MF_Weights', 'RegNet_X_800MF_Weights', 'RegNet_X_1_6GF_Weights', 'RegNet_X_3_2GF_Weights', 'RegNet_X_8GF_Weights', 'RegNet_X_16GF_Weights', 'RegNet_X_32GF_Weights', 'regnet_y_400mf', 'regnet_y_800mf', 'regnet_y_1_6gf', 'regnet_y_3_2gf', 'regnet_y_8gf', 'regnet_y_16gf', 'regnet_y_32gf', 'regnet_y_128gf', 'regnet_x_400mf', 'regnet_x_800mf', 'regnet_x_1_6gf', 'regnet_x_3_2gf', 'regnet_x_8gf', 'regnet_x_16gf', 'regnet_x_32gf', 'resnet', 'ResNet', 'ResNet18_Weights', 'ResNet34_Weights', 'ResNet50_Weights', 'ResNet101_Weights', 'ResNet152_Weights', 'ResNeXt50_32X4D_Weights', 'ResNeXt101_32X8D_Weights', 'ResNeXt101_64X4D_Weights', 'Wide_ResNet50_2_Weights', 'Wide_ResNet101_2_Weights', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'resnext101_64x4d', 'wide_resnet50_2', 'wide_resnet101_2', 'shufflenetv2', 'ShuffleNetV2', 'ShuffleNet_V2_X0_5_Weights', 'ShuffleNet_V2_X1_0_Weights', 'ShuffleNet_V2_X1_5_Weights', 'ShuffleNet_V2_X2_0_Weights', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0', 'squeezenet', 'SqueezeNet', 'SqueezeNet1_0_Weights', 'SqueezeNet1_1_Weights', 'squeezenet1_0', 'squeezenet1_1', 'vgg', 'VGG', 'VGG11_Weights', 'VGG11_BN_Weights', 'VGG13_Weights', 'VGG13_BN_Weights', 'VGG16_Weights', 'VGG16_BN_Weights', 'VGG19_Weights', 'VGG19_BN_Weights', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'vision_transformer', 'VisionTransformer', 'ViT_B_16_Weights', 'ViT_B_32_Weights', 'ViT_L_16_Weights', 'ViT_L_32_Weights', 'ViT_H_14_Weights', 'vit_b_16', 'vit_b_32', 'vit_l_16', 'vit_l_32', 'vit_h_14', 'swin_transformer', 'SwinTransformer', 'Swin_T_Weights', 'Swin_S_Weights', 'Swin_B_Weights', 'Swin_V2_T_Weights', 'Swin_V2_S_Weights', 'Swin_V2_B_Weights', 'swin_t', 'swin_s', 'swin_b', 'swin_v2_t', 'swin_v2_s', 'swin_v2_b', 'maxvit', 'MaxVit', 'MaxVit_T_Weights', 'maxvit_t', 'detection', 'optical_flow', 'quantization', 'segmentation', 'video', 'get_model', 'get_model_builder', 'get_model_weights', 'get_weight', 'list_models'])

torchvision.models.resnet18()

这种方式适合加载模型库的给定的权重,易于训练
model_teacher = torchvision.models.resnet18(pretrained=True, num_classes=100) #res18

查看模型网络结构

print()

会输出每个层的模块及参数。()内的是定义的模块名字,冒号后边是模块函数及参数

给定输入,确定输出的形状

torch.randn()

通过给出输入的形状,创建符合正态分布的符合要求形状的张量。但常用4D即传入4个数字

a=torch.randn(3,24,24)

print(a.shape)
# torch.Size([3, 24, 24])
print(a)
'''tensor([[[-0.9060,  0.1075, -1.1254,  ...,  0.1310,  0.2196,  0.4092],
         [ 0.4919,  0.6309, -0.0298,  ..., -0.0880, -0.2851, -0.0318],
         [-0.0073,  0.8881, -0.1518,  ...,  0.8028, -0.5416, -2.5675],
         ...,
         [-0.4601, -0.3924, -1.9049,  ...,  1.4624,  0.8603, -0.7537],
         [ 1.4616, -0.0962,  0.7303,  ...,  0.0832,  0.7557,  0.4849],
         [-1.5253, -0.0582, -1.6721,  ..., -0.5345, -0.9942,  0.0203]],

        [[ 0.8230,  0.9749, -1.1581,  ...,  1.2500, -1.2332,  0.3555],
         [-0.1002, -0.9513, -0.2972,  ..., -0.1737,  1.8249, -0.9832],
         [-1.8356,  0.4404, -0.7004,  ...,  0.2948,  0.6764,  0.7496],
         ...,
         [-1.0121, -1.6816, -1.0084,  ..., -1.0904, -0.4518, -0.0906],
         [-0.8290,  0.3990, -1.9928,  ..., -0.0240, -0.7755,  1.6265],
         [ 0.1753,  1.6268,  0.7902,  ..., -0.0964, -0.1302,  1.7886]],

        [[-1.3082,  0.2769, -0.6148,  ..., -1.0288,  0.2551, -0.6365],
         [ 0.1632,  1.1878,  1.0580,  ..., -2.2719, -0.3561,  0.5158],
         [-0.4988, -1.0627,  1.3055,  ...,  1.4051,  1.1880, -0.0343],
         ...,
         [ 0.1920,  0.1599, -0.8272,  ...,  0.5240,  1.2998,  0.0988],
         [-1.3782, -1.2923,  0.3729,  ..., -0.3355, -1.1872,  0.7259],
         [-3.3038,  1.1082, -1.2282,  ..., -1.5526, -0.0775,  1.7922]]])'''

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值