参考博客:
pytorch系列10 — 如何自定义参数初始化方式 ,apply()
PyTorch中model.modules(), model.named_modules(), model.children(), model.named_children()的详解
Pytroch常见的模型参数初始化方法有apply和model.modules()。Pytroch会自动给模型进行初始化,当需要自己定义模型初始化时才需要这两个方法。
一、 使用apply函数进行初始化
举一个包含自己命名层名称和利用nn.Sequential实现的网络。
from torch import nn
import torch
class Net(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Linear(2, 4)
self.relu1 = nn.ReLU()
self.layer2 = nn.Sequential(
nn.Linear(4, 2),
nn.ReLU(),
nn.Linear(2, 2),
)
def forward(x):
output = self.layer2(self.relu1(self.layer1(x)))
return output
a