【自用】PyTorch实现神经网络模型的代码(小技巧

以自己写的vgg模型为例:
import torch
import torch.nn as nn
import numpy as np
import collections

def conv(in_channels, out_channels, kernel_size, stride=1, padding=0, bn=0, relu=1, pooling=0):
    modules = [nn.Conv2d(in_channels,out_channels,kernel_size,stride=stride,padding=padding)]
    if bn:
        modules.append(nn.BatchNorm2d(out_channels))
    if relu:
        modules.append(nn.ReLU(inplace=True))
    if pooling:
        modules.append(nn.MaxPool2d(2,2))
    return nn.Sequential(*modules)

def fc(in_channels, out_channels,relu=0):
    modules = [nn.Linear(in_channels,out_channels)]
    if relu:
        modules.append(nn.ReLU(inplace=True))
    return nn.Sequential(*modules)

def conv_dw(in_channels,out_channels):
    """
    深度可分解卷积
    """
    return nn.Sequential(
        # point-wise
        nn.Conv2d(in_channels,in_channels,3,padding=1,groups=in_channels),
        nn.BatchNorm2d(in_channels),
        nn.ReLU(inplace=True),
        
        # depth-wise
        nn.Conv2d(in_channels,out_channels,1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )

def load_weight(model, path):
    checkpoint = torch.load(path)
    load_state = checkpoint['state_dict'] # 这一步还是需要先去看一下预训练模型中key长什么样
    model_state = model.state_dict()
    new_state = collections.OrderedDict() # 生成一个有顺序的字典
    for k in model_state.keys():
        if k in load_state and load_state[k].size() == model_state[k].size():
            new_state[k] = load_state[k]
    model.load_state_dict(new_state)
        
class vgg19(nn.Module):
    def __init__(self,classes_num):
        super(vgg19,self).__init__()
        self.classes_num = classes_num
        self.seq = nn.Sequential(
            conv(3,64,3),
            conv(64,64,3,pooling=1),
            conv(64,128,3),
            conv(128,128,3,pooling=1),
            conv(128,256,3),
            conv(256,256,3),
            conv(256,256,3),
            conv(256,256,3,pooling=1),
            conv(256,512,3),
            conv(512,512,3),
            conv(512,512,3),
            conv(512,512,3,pooling=1),
        )
        self.fc = nn.Sequential(
            fc(512*7*7,4096),
            fc(4096,4096),
            fc(4096,self.classes_num,relu=0)
        )
    def forward(self, x):
        x = self.seq(x)
        x = x.reshape(-1,512*7*7)
        x = self.fc(x)    
        return x

if __name__ == '__main__':
    image = torch.randn(1,3,224,224)
    model = vgg19(1000)
    out = model(image)
    print(out.shape)
一些必要的解释:
  1. nn.Sequential() 可以把若干个操作包括在一个容器内,在forward()函数里就不用再挨个把对输入的操作复述一遍。
  2. 通常来说,每个3x3的conv2d后面都会加上bn和relu,因此可以先自定义一个conv函数,把bn和relu都加进去,能够有效减少代码重复量。
  3. python中,带×号(星号)的意思表示将参数的所有内容依次作为形参传入,带××(两个星号)的意思是将字典型的所有内容的数值作为形参传入。
  4. nn.ReLU(inplace=True)的意思是直接改变输入,而不另创一个新的输出。可以节省空间和时间。
  5. 调用collections.OrderedDict()函数可以生成一个有顺序的字典。一般的字典是无序的,输出是顺序不一定。
### 太原理工大学信息安全技术 #### 防火墙的优点 防火墙具有多种优点,能够有效提升网络安全水平。具体而言: - **防止非授权访问**:防火墙可以阻止未经授权的外部用户进入内部网络环境,从而保护内网资源的安全[^1]。 - **缓解IP地址短缺**:通过采用NAT(Network Address Translation)技术,防火墙有助于解决IPv4地址不足的问题,使得多个设备可以通过少量公共IP地址访问互联网。 - **监控与警报功能**:防火墙具备方便地监视网络活动的能力,并能在检测到异常行为时发出警告通知管理员采取相应措施。 #### Linux系统中的关键文件路径及其意义 对于Linux操作系统来说,理解各个重要目录的作用至关重要: - `/var`:此目录包含了系统运行过程中经常变化的数据文件,如日志记录、邮件队列等[^2]。 - `/tmp`:这是一个临时存储区域,默认情况下任何用户都可以在此创建文件;然而,由于权限设置不当等原因,该位置也可能成为恶意软件植入的目标之一。 - `/etc`:这里存放着众多用于管理和配置系统的文件以及子目录,例如网络接口设定、服务启动脚本等等。 - `/bin`:主要用来保存那些对系统至关重要的命令工具二进制文件,在单用户模式下也能正常使用的程序通常会放在这个位置。 #### 用户ID(UID) 在Linux系统中,每个账户都会被分配一个唯一的整数编号即用户ID(UID),它用于区分不同身份的操作者。其中: - `0`表示超级用户root; - `1~99`区间内的数值一般是留给特定用途的服务进程所用; - 而普通用户的起始范围是从`100`开始计数,在某些发行版里可能是从`500`起步。 ```bash # 查询当前登录用户的UID id -u ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值