In deep learning, the learning curves can be used to assess the current training state:
- train loss keeps falling, test loss keeps falling: the network is still learning.
- train loss keeps falling, test loss plateaus: the network is overfitting.
- train loss plateaus, test loss plateaus: learning has hit a bottleneck; try reducing the learning rate or the batch size.
- train loss plateaus, test loss keeps falling: the dataset almost certainly has a problem.
- train loss keeps rising, test loss keeps rising (eventually NaN): likely caused by a poorly designed network architecture, badly chosen training hyperparameters, a bug in the code, etc.; further debugging is needed to locate the cause.
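As a minimal sketch, the checklist above can be encoded as a small lookup table. The helper names `trend` and `diagnose`, the tolerance, and the trend heuristic (comparing last value to first) are all illustrative assumptions, not from any library:

```python
def trend(losses, tol=1e-3):
    """Classify a loss sequence as 'down', 'up', or 'flat'.

    Assumed heuristic: compare the last value to the first within a tolerance.
    """
    delta = losses[-1] - losses[0]
    if delta < -tol:
        return 'down'
    if delta > tol:
        return 'up'
    return 'flat'


def diagnose(train_losses, test_losses):
    """Map (train, test) loss trends to the diagnoses listed above."""
    table = {
        ('down', 'down'): 'network is still learning',
        ('down', 'flat'): 'overfitting',
        ('flat', 'flat'): 'bottleneck: reduce learning rate or batch size',
        ('flat', 'down'): 'dataset problem',
        ('up', 'up'): 'bad architecture, bad hyperparameters, or a bug',
    }
    return table.get((trend(train_losses), trend(test_losses)), 'unclear')


print(diagnose([2.3, 1.1, 0.4], [2.2, 1.3, 0.9]))  # network is still learning
print(diagnose([2.3, 1.1, 0.4], [1.0, 1.0, 1.0]))  # overfitting
```

In practice the trends are judged by eye on the plotted curves; the table simply makes the mapping from trends to diagnoses explicit.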
MATLAB code on Linux:
% Shell command to extract loss values from the log file: cat train_log_file | grep "Train net output " | awk '{print $11}'
clear;
clc;
close all;
train_log_file = 'train.log';
train_interval = 100;
test_interval = 200;
[~, train_string_output] = system(['cat ', train_log_file, ' | grep ''Train net output #0'' | awk ''{print $11}''']);
train_loss = str2num(train_string_output);
n = 1 : length(train_loss);
idx_train = (n - 1) * train_interval;
[~, test_string_output] = system(['cat ', train_log_file, ' | grep ''Test net output #1'' | awk ''{print $11}''']);
test_loss = str2num(test_string_output);
m = 1 : length(test_loss);
idx_test = (m - 1) * test_interval;
figure;
plot(idx_train, train_loss);
hold on;
plot(idx_test, test_loss);
grid on;
legend('Train Loss', 'Test Loss');
xlabel('iterations');
ylabel('loss');
title('Train & Test Loss Curve');
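For reference, the extraction step can also be done in pure Python with `re` instead of grep/awk. This is a sketch only: the helper name `extract_losses` and the sample log lines are illustrative assumptions, though the regexes target the same Caffe "net output" lines as the MATLAB script above:

```python
import re

# Regexes matching the loss value in Caffe "net output" log lines.
TRAIN_RE = re.compile(r'Train net output #0: \S+ = ([.\deE+-]+)')
TEST_RE = re.compile(r'Test net output #1: \S+ = ([.\deE+-]+)')


def extract_losses(log_text, pattern):
    """Return all loss values matched by `pattern`, as floats."""
    return [float(m.group(1)) for m in pattern.finditer(log_text)]


# Hypothetical sample of two log lines, for illustration.
sample = ("I0101 12:00:00 solver.cpp] Train net output #0: loss = 2.3025\n"
          "I0101 12:00:05 solver.cpp] Test net output #1: loss = 2.2871\n")
print(extract_losses(sample, TRAIN_RE))  # [2.3025]
print(extract_losses(sample, TEST_RE))   # [2.2871]
```

As in the MATLAB script, multiplying each match's index by `train_interval` or `test_interval` recovers the iteration axis for plotting.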
Python 3 code on Windows (Anaconda3 + PyCharm):
"./bin/caffe.exe" train --solver=./examples/mnist/lenet_solver.prototxt >./examples/mnist/log/mnist_Lenet_train_test.log 2>&1
pause
The `>./examples/mnist/log/mnist_Lenet_train_test.log 2>&1` part redirects the training output (both stdout and stderr) to the log file.
parse_log.py and extract_seconds.py are used to parse the training log.
Source of parse_log.py:
import re
from examples.mnist.log.extract_seconds import *
import csv
from collections import OrderedDict
def parse_log(log_file_name):
    """
    Parse log file
    :param log_file_name: the name of log file
    :return: (train_dict_list, test_dict_list)
    """
    regex_iteration = re.compile(r'Iteration (\d+)')
    regex_train_output = re.compile(r'Train net output #(\d+): (\S+) = ([.\deE+-]+)')
    regex_test_output = re.compile(r'Test net output #(\d+): (\S+) = ([.\deE+-]+)')
    regex_learning_rate = re.compile(r'lr = ([-+]?[0-9]*\.?[0-9]+([eE]?[-+]?[0-9]+)?)')
    # Pick out lines of interest
    iteration = -1
    learning_rate = float('NaN')
    train_dict_list = []
    test_dict_list = []
    train_row = None
    test_row = None
    logfile_year = get_log_created_year(log_file_name)
    with open(log_file_name) as f:
        start_time = get_start_time(f, logfile_year)
        last_time = start_time
        for line in f:
            iteration_match = regex_iteration.search(line)
            if iteration_match: