caffe 训练网络模型总结

本文详细记录了使用Caffe进行模型训练的过程,包括模型结构配置、数据集准备、训练日志分析与绘图,以及参数调整技巧。作者分享了在训练过程中遇到的问题与解决方案,为读者提供了宝贵的实践经验。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

这三四个月一直在用caffe训练,出现各种问题,所以写这篇文来记录一下这段时间的辛苦泪水。写的不好,请多多包涵。

 

目录

一、caffe的模型结构

1.网络结构配置文件:

2.网络权重参数文件

3.训练超参数文件

二、数据集

1.下载数据集

2.制作标签文件train.txt 和 val.txt

3.用脚本生成lmdb数据集

三、训练

1.创建train.sh 脚本,并保存训练日志

2.调参

3.matlab绘制caffe训练后的log 中的loss曲线和accucy曲线


一、caffe的模型结构

1.网络结构配置文件:

  训练网络结构 train.prototxt,

  测试网络结构 test.prototxt,

  实施网络结构 deploy.prototxt

2.网络权重参数文件

  该文件用于保存网络各层的权重值,是以*.caffemodel格式存储的文件。而在运行中,其将以Blob数据形式存入内存中

3.训练超参数文件

  用来控制网络训练及测试阶段的超参数,比如测试网络结构配置文件,梯度下降法中的批量、学习率、遗忘因子等参数,测试的间隔迭代次数等等,其同样是用*.prototxt的文件格式(比如solver.prototxt)

二、数据集

  这里以lfw-aligned为例制作deepid2  lmdb数据集

1.下载数据集

  去官网下载。

  Download  -> All images aligned with deep funneling

2.制作标签文件train.txt 和 val.txt

  注意:train. txt  和 val.txt  中数据不能出现重叠,否则会出现过拟合现象

  这里采用python来制作标签文件

  最后将train. txt  和 val.txt 放到 lfw-aligned文件夹下

3.用脚本生成lmdb数据集

  在caffe下创建run.sh,并运行 

注意:数据集必须打乱 即: --shuffle=true

EXAMPLE=examples/deepid2
DATA=data/lfw-aligned/
TOOLS=build/tools

RESIZE_HEIGHT=55
RESIZE_WIDTH=47

echo "creating lmdb..."

rm -rf $EXAMPLE/DeepID2_train_lmdb

rm -rf $EXAMPLE/DeepID2_test_lmdb


$TOOLS/convert_imageset --shuffle=true \
    --resize_height=$RESIZE_HEIGHT \
    --resize_width=$RESIZE_WIDTH \
    $DATA \
    $DATA/train.txt \
    $EXAMPLE/DeepID2_train_lmdb

$TOOLS/convert_imageset --shuffle=true \
    --resize_height=$RESIZE_HEIGHT \
    --resize_width=$RESIZE_WIDTH \
    $DATA \
    $DATA/val.txt \
    $EXAMPLE/DeepID2_test_lmdb

echo "compute image mean..."

$TOOLS/compute_image_mean -backend=lmdb $EXAMPLE/DeepID2_train_lmdb \
  $EXAMPLE/DeepID2_mean.proto

echo "done..."
sh run.sh

三、训练

1.创建train.sh 脚本,并保存训练日志

#!/usr/bin/env sh
set -e

TOOLS=./build/tools
GLOG_logtostderr=0
GLOG_log_dir='./trainLog/deepID2/train.log' 

$TOOLS/caffe train --solver=examples/deepid2/DeepID2_solver.prototxt $@ 2>&1 | tee $GLOG_log_dir

2.调参

  在solver.prototxt 文件中

test interval* train batch size 应该>=train image 总数

test iter * test batch size应该>=test image 总数

  否则容易出现训练曲线震荡问题

  如果loss不收敛,甚至出现 loss = nan 时,base_lr 调小一点,比如将  base_lr = 0.001 改为 base_lr = 0.0001

3.matlab绘制caffe训练后的log 中的loss曲线和accucy曲线

clc;
clear;
% load the log file of caffe model 
fid = fopen('train.log', 'r'); 
tline = fgetl(fid); %get arrays to draw figures 
accuracyIter = [0];%accuracy横坐标 
accuracyArray = [];%accuracy纵坐标 
lossIter = [];%loss横坐标 
lossArray = [];%loss纵坐标 
%record the last line 
lastLine = ''; %read line 
LLastLine = '';
while ischar(tline) 
    %%%%%%%%%%%%%% the accuracy line %%%%%%%%%%%%%% 
    k = strfind(tline, 'Test net output #0'); 
    if (k) 
        k = strfind(tline, 'accuracy1'); 
        if (k) % If the string contain test and accuracy at the same time % The bias from 'accuracy' to the float number 
            indexStart = k + 12; 
            indexEnd = size(tline); 
            str = tline(indexStart : indexEnd(2)); 
            accuracyArray = [accuracyArray, str2num(str)]; 
        end
        % Get the number of index 
        k =strfind(lastLine, 'Restarting');
        if (k)
            lastLine = LLastLine;
        end
        k = strfind(lastLine, 'Iteration'); 
        if (k) 
            indexStart = k + 10; 
            indexEnd = strfind(lastLine, ','); 
            str2 = lastLine(indexStart : indexEnd - 1); 
            accuracyIter = [accuracyIter, str2num(str2)]; 
        end
        % Concatenation of two string 
        res_str = strcat(str2, '/', str); 
    end
    %%%%%%%%%%%%%% the loss line %%%%%%%%%%%%%% 
    k1 = strfind(tline, 'Iteration'); 
    if (k1) 
        k2 = strfind(tline, 'loss'); 
        if (k2) 
            indexStart = k2 + 7; 
            indexEnd = size(tline); 
            str1 = tline(indexStart:indexEnd(2)); 
            indexStart = k1 + 10; 
            indexEnd = strfind(tline, '(') - 1; 
            str2 = tline(indexStart:indexEnd); 
            res_str1 = strcat(str2, '/', str1); 
            lossIter = [lossIter, str2num(str2)]; 
            lossArray = [lossArray, str2num(str1)]; 
        end
    end
    LLastLine = lastLine;
    lastLine = tline; 
    tline = fgetl(fid); 
end
%draw figure 
figure;
plot(accuracyIter, accuracyArray);
%title('iteration vs accurancy');	%绘制accuracy曲线 
hold on
plot(lossIter, lossArray);
%title('iteration vs loss');	%绘制loss曲线

结果如下:(橘红色为loss曲线,蓝色为accuracy曲线,横坐标为iteration。训练的效果不好)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值