Windows+Python3下绘制Caffe训练日志中的Loss和Accuracy曲线图

在深度学习中,可以通过学习曲线评估当前训练状态:

  1. train loss 不断下降,test loss 不断下降,说明网络仍然在认真学习中。
  2. train loss 不断下降,test loss 趋于不变,说明网络过拟合。
  3. train loss 趋于不变,test loss 趋于不变,说明学习遇到瓶颈,需减小学习速率或者批量数据尺寸。
  4. train loss 趋于不变,test loss 不断下降,说明数据集 100% 有问题。
  5. train loss 不断上升,test loss不断上升(最终为NaN),可能网络结构设计不当、训练超参数设置不当、程序bug等某个问题引起,需要进一步定位。

Linux下的MATLAB代码:
// 提取log文件中的loss值shell命令:cat train_log_file | grep ”Train net output ” | awk ‘{print $11}’

clear;
clc;
close all;
train_log_file = 'train.log';
train_interval = 100;
test_interval = 200;
[~, train_string_output] = dos(['cat ', train_log_file, ' | grep ''Train net output #0'' | awk ''{print $11}''']);
train_loss = str2num(train_string_output);
n = 1 : length(train_loss);
idx_train = (n - 1) * train_interval;
[~, test_string_output] = dos(['cat ', train_log_file, ' | grep ''Test net output #1'' | awk ''{print $11}''']);
test_loss = str2num(test_string_output);
m = 1 : length(test_loss);
idx_test = (m - 1) * test_interval;
figure;
plot(idx_train, train_loss);
hold on;
plot(idx_test, test_loss);

grid on;
legend('Train Loss', 'Test Loss');
xlabel('iterations');
ylabel('loss');
title(' Train & Test Loss Curve');

Window下的Python3(Anaconda3+Pycharm)代码:

"./bin/caffe.exe" train --solver=./examples/mnist/lenet_solver.prototxt >./examples/mnist/log/mnist_Lenet_train_test.log 2>&1
pause

命令>./examples/mnist/log/mnist_Lenet_train_test.log 2>&1表示训练日志的输出。
parse_log.py和extract_seconds.py文件用于解析训练日志:
parse_log.py源码:

import re
from examples.mnist.log.extract_seconds import *
import csv
from collections import OrderedDict

def parse_log(log_file_name):
    """
    Parse log file
    :param log_file_name: the name of log file
    :return: (train_dict_list, test_dict_list)
    """
    regex_iteration = re.compile('Iteration (\d+)')
    regex_train_output = re.compile('Train net output #(\d+): (\S+) = ([.\deE+-]+)')
    regex_test_output = re.compile('Test net output #(\d+): (\S+) = ([.\deE+-]+)')
    regex_learning_rate = re.compile('lr = ([-+]?[0-9]*\.?[0-9]+([eE]?[-+]?[0-9]+)?)')

    # Pick out lines of interest
    iteration = -1
    learning_rate = float('NaN')
    train_dict_list = []
    test_dict_list = []
    train_row = None
    test_row = None

    logfile_year = get_log_created_year(log_file_name)
    with open(log_file_name) as f:
        start_time = get_start_time(f, logfile_year)
        last_time = start_time

        for line in f:
            iteration_match = regex_iteration.search(line)
            if iteration_m
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值