参考 RCAN 的程序结构,在 utils.py 中添加以下 checkpoint 类:
import torch
import math
import os
from functools import reduce
import numpy as np
import imageio as misc
import time
import datetime
import torch.optim as optim
# NOTE: the imports above were copied wholesale from RCAN; not all of them are used here.
class checkpoint():
    """Lightweight training logger modeled on RCAN's checkpoint class.

    Keeps an in-memory torch.Tensor log (``add_log``) and mirrors text
    messages to a log file under ``self.dir`` (``write_log``).
    """

    def __init__(self):
        # self.args = args
        self.ok = True
        self.log = torch.Tensor()  # accumulated numeric log entries
        # now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
        self.dir = r'D:\LY\DLmodify\3Deblur_cnn\Deblur_cvpr\document\3test\modelDict'
        # Store the log path once so refresh reopens the SAME file.
        self.log_path = self.dir + '/log_614.txt'

        # os.makedirs with exist_ok replaces the hand-rolled _make_dir helper.
        os.makedirs(self.dir, exist_ok=True)

        # Append if the log already exists, otherwise create it fresh.
        open_type = 'a' if os.path.exists(self.log_path) else 'w'
        self.log_file = open(self.log_path, open_type)

    def add_log(self, log):
        """Append a tensor row to the in-memory log."""
        self.log = torch.cat([self.log, log])

    def write_log(self, log, refresh=False):
        """Print *log* and append it to the log file.

        When *refresh* is True, close and reopen the file so buffered
        output is flushed to disk.
        """
        print(log)
        self.log_file.write(log + '\n')
        if refresh:
            # BUG FIX: the original reopened 'log.txt' instead of
            # 'log_614.txt', silently switching log files after the
            # first refresh. Reuse the stored path instead.
            self.log_file.close()
            self.log_file = open(self.log_path, 'a')
在 train.py 和 main.py 中进行如下改动。首先是 main.py:
from __future__ import print_function
import argparse
from math import log10
import os
from typing import Any
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm
#from dataset import MyDataset,dataset_split
from utils import *
from train import *
#from data import get_training_set
import pdb
import socket
import time
# ---- Training settings (command-line interface) ----
parser = argparse.ArgumentParser(description='PyTorch Super Res Example')
parser.add_argument('--batchSize', type=int, default=640, help='training batch size')
parser.add_argument('--nEpochs', type=int, default=1000, help='number of epochs to train for')
parser.add_argument('--snapshots', type=int, default=50, help='Snapshots')
parser.add_argument('--lr', type=float, default=1e-2, help='Learning Rate. Default=0.0001')
parser.add_argument('--threads', type=int, default=5, help='number of threads for data loader to use')  # NOTE(review): original comment said "changed from 1 to 0" but the default here is 5 — confirm intent
parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123')
parser.add_argument('--pretrained_sr', default='MIX2K_LR_aug_x4dl10DBPNITERtpami_epoch_399.pth', help='sr pretrained base model')
parser.add_argument('--pretrained', type=bool, default=False)
parser.add_argument('--model_type', type=str, default='MyCNN')
#parser.add_argument('--data_root', default=r'D:\LY\DLmodify\3Deblur_cnn\Deblur_cvpr\dataset', help='all dataset Location ')
#parser.add_argument('--train_path', default=r'D:\LY\DLmodify\3Deblur_cnn\Deblur_cvpr\dataset\train_small', help='dataset Location')
parser.add_argument('--train_path', default=r'C:\train_set', help='dataset Location')
#parser.add_argument('--model_save_path', default=r'D:\LY\DLmodify\3Deblur_cnn\Deblur_cvpr\document\3test\modelDict\test_73.pth', help='model_save_path')
parser.add_argument('--model_save', default=r'D:\LY\DLmodify\3Deblur_cnn\Deblur_cvpr\document\3test\modelDict', help='model_save')
parser.add_argument('--start_iter', type=int, default=1, help='Starting Epoch')
# Module-level logger instance (added per the article). NOTE(review): this
# rebinding shadows the `checkpoint` CLASS imported from utils, so the class
# name is no longer usable below this line; instantiation also opens the log
# file as an import-time side effect.
checkpoint = checkpoint()
opt = parser.parse_args()
#gpus_list = range(opt.gpus)
hostname = str(socket.gethostname())
cudnn.benchmark = True
print(opt)
def print_network(net):
    """Print the network architecture and log its total parameter count."""
    total_params = sum(p.numel() for p in net.parameters())
    print(net)
    # Route the count through the shared logger so it lands in the log file too.
    checkpoint.write_log('Total number of parameters: %d' % total_params)
def main():
    """Build the data loaders, model, and optimizer, then run the training loop.

    Uses the module-level ``opt`` (parsed CLI args) and ``checkpoint`` logger.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    checkpoint.write_log('===> Loading datasets')
    train_ds = MyDataset(opt.train_path)
    # 80/20 split into training and validation subsets.
    new_train_ds, validate_ds = dataset_split(train_ds, 0.8)
    # BUG FIX: the original also built a loader over the FULL dataset
    # (train_loader) that was never used — removed.
    new_train_loader = torch.utils.data.DataLoader(
        new_train_ds, batch_size=opt.batchSize, shuffle=True,
        pin_memory=True, num_workers=3)
    validate_loader = torch.utils.data.DataLoader(
        validate_ds, batch_size=opt.batchSize, shuffle=True,
        pin_memory=True, num_workers=3)

    print('===> Building model ', opt.model_type)
    if opt.model_type == 'MyCNN':
        net = MyCNN()

    criterion = torch.nn.CrossEntropyLoss()
    # criterion = torch.nn.NLLLoss()

    print('---------- Networks architecture -------------')
    print_network(net)
    print('----------------------------------------------')

    optimizer = optim.SGD(net.parameters(), lr=opt.lr, momentum=0.8)
    # optimizer = optim.Adam(net.parameters(), lr=opt.lr)

    for epoch in range(opt.start_iter, opt.nEpochs + 1):
        train(epoch, new_train_loader, device, net, criterion, optimizer, checkpoint)
        # BUG FIX: validate() was called twice per epoch (the first result
        # was discarded), doubling validation cost. Run it once and log it.
        acc = validate(validate_loader, device, net, criterion)
        checkpoint.write_log("validate acc:{}".format(acc))

        # Decay the learning rate by 5x every 70 epochs.
        if (epoch + 1) % 70 == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] /= 5.0
            checkpoint.write_log('Learning rate decay: lr={}'.format(optimizer.param_groups[0]['lr']))

        # Save a full-model snapshot every opt.snapshots epochs.
        if (epoch + 1) % opt.snapshots == 0:
            model_save_path = opt.model_save + r'\classify73_{}.pth'.format(epoch)
            torch.save(net, model_save_path)
# Script entry point: only run training when executed directly, not on import.
if __name__ == '__main__':
    main()
然后在 train.py 里修改 train 函数:
from dataset import MyDataset,dataset_split
#from config import config as C
from model import MyCNN
import torch.optim as optim
from utils import *
import time
# Change 1: add a ckp (checkpoint logger) parameter to train()
def train(epoch, train_loader, device, model, criterion, optimizer, ckp):
    """Run one training epoch and log per-batch loss/accuracy via *ckp*.

    Args:
        epoch: current epoch number (logging only).
        train_loader: DataLoader yielding (inputs, labels) batches.
        device: torch.device to run on.
        model: network to train; moved to *device* here.
        criterion: loss function; moved to *device* here.
        optimizer: optimizer stepping *model*'s parameters.
        ckp: checkpoint-style logger exposing ``write_log``.
    """
    # BUG FIX: removed the no-op `ckp = ckp` self-assignment.
    model = model.to(device)
    criterion = criterion.to(device)
    model.train()

    top1 = AvgrageMeter()   # running top-1 accuracy over the epoch
    train_loss = 0.0        # running loss sum (plain Python float)
    t2 = time.time()        # timestamp of the previous iteration's end

    ckp.write_log('Learning rate : lr={}'.format(optimizer.param_groups[0]['lr']))

    for i, data in enumerate(train_loader, 0):
        t3 = time.time()
        inputs, labels = data[0].to(device), data[1].to(device)

        # Clear gradients accumulated from the previous batch.
        optimizer.zero_grad()

        t0 = time.time()
        outputs = torch.squeeze(model(inputs))
        loss = criterion(outputs, labels)
        t1 = time.time()

        loss.backward()
        optimizer.step()

        prec1, prec2 = accuracy(outputs, labels, topk=(1, 2))
        n = inputs.size(0)
        top1.update(prec1.item(), n)
        # BUG FIX: accumulate loss.item() (a float) instead of loss.data
        # (a tensor); the original turned train_loss into a device tensor.
        train_loss += loss.item()

        ckp.write_log(
            "===> Epoch[{}]({}/{}): Loss: {:.4f} ||train_acc:{:.4f}%||train_loader Timer: {:.4f} sec || Timer: {:.4f} sec.".format(
                epoch, i, len(train_loader), train_loss / (i + 1), top1.avg, (t3 - t2), (t1 - t0)))
        t2 = time.time()

    ckp.write_log('Finished Training')
本文介绍了如何在PyTorch项目中应用RCAN算法,并对训练流程进行优化,包括使用checkpoint类管理模型状态、学习率调整、日志记录等。通过修改`train.py`和`main.py`文件,展示了如何跟踪参数数量、监控训练过程和保存模型进度。
843

被折叠的评论
为什么被折叠?



