深度学习八股文
学习记录:一些深度学习相关代码的积累。
main.py
设置各种参数,执行数据加载、数据增广、优化器、训练流程
-> 初始化参数(gpu等)
-> 初始化dataset, 定义data_loader:torch.utils.data.DataLoader()
-> 定义modal(考虑是否并行:数据太多时使用、checkpoint)
-> 定义优化算法:optim.Adam(modal.parameters(), lr=args.d_lr, betas=args.beta)
-> 定义损失函数/目标函数criterion
(nn.CrossEntropyLoss()、nn.BCELoss()、nn.L1Loss()、nn.MSELoss())
-> 开始epoch循环训练。
start = time.time()
argparse
argparse是python自带的命令行参数解析包,可以用来方便地读取命令行参数。官方文档
argparse.ArgumentParser()
创建新对象
add_argument()
添加参数
parse_args()
读取参数
parse_known_args()
读取参数并将将未知参数以列表形式返回
import argparse
parse = argparse.ArgumentParser()
parse.add_argument('--flag_int', type=int, default=2, help='flag_int')
# opt1 = parse.parse_args()
# print(opt1)
# 将所有可用的trainer选项添加到argparse
parser = Trainer.add_argparse_args(parser)
opt2 = parse.parse_known_args()
print(opt2)
>>> python test.py --flag_int 3
>>> (Namespace(flag_int=3), [])
Trainer.add_argparse_args(parser)
(lightning)将所有可用的trainer选项添加到argparse
加载数据
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
dataset1 = datasets.MNIST('../data', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST('../data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
训练流程
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range