# Training setup: learn theta -> grid -> img_out (5x5).
# The loss is MSE.
# NOTE(review): t_loss is presumably the best (lowest) loss seen so far,
# starting at +inf so any real loss beats it -- confirm against the training loop.
t_loss = np.inf
learning_rate = 5 * 1e-4
loss_fn = torch.nn.MSELoss(reduction='mean')
# Path of the loss log file.
log = "./loss_log.txt"
# Alternative model constructions, left disabled:
# model = STN(0.5,6,img_size)
# model = STN(0.76,18,img_size,1,16*16*20,img_size*img_size)
model = Net_9x9_theta_learn_grid()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Set to True to start from previously saved state instead of a fresh model.
pre_train = False
if pre_train:
    # Restore the model weights AND the optimizer state: Adam keeps
    # per-parameter running statistics, so reloading only the weights would
    # restart the optimizer from scratch.
    model.load_state_dict(torch.load('./min_9x9_model_img_out.pkl'))
    optimizer.load_state_dict(torch.load("./latest_optimizer.pth"))
with open(
保存网络模型和优化器使得下次训练从断开的epoch地方开始继续训练
最新推荐文章于 2023-05-29 10:20:36 发布