1. Define the class `LossHistory` (定义loss类 — the loss-history class)
import os
import numpy as np
import scipy.signal
from matplotlib import pyplot as plt
class LossHistory():
def __init__(self, log_dir):
    """Set up a timestamped output directory and empty metric histories.

    Creates ``<log_dir>/loss_<YYYY_MM_DD_HH_MM_SS>`` (including any missing
    parent directories) and initialises empty lists that ``append_loss``
    later extends with accuracy, loss and validation-loss values.

    Args:
        log_dir: Parent directory under which the run's folder is created.
    """
    import datetime

    # The timestamp names the run folder and is also reused in the log
    # file names that append_loss writes into it.
    self.time_str = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
    self.log_dir = log_dir
    self.save_path = os.path.join(self.log_dir, "loss_" + self.time_str)

    self.acc = []
    self.losses = []
    self.val_loss = []

    # exist_ok=True: the folder name only has one-second resolution, so two
    # instances created within the same second resolve to the same path;
    # without it the second os.makedirs call raises FileExistsError.
    os.makedirs(self.save_path, exist_ok=True)
def append_loss(self, acc, loss, val_loss):
self.acc.append(acc)
self.losses.append(loss)
self.val_loss.append(val_loss)
with open(os.path.join(self.save_path, "epoch_acc_" + str(self.time_str) + ".txt"), 'a') as f: