断点续训

本文详细介绍了使用TensorFlow实现神经网络断点续训的过程,包括前向传播、反向传播及模型验证,确保训练的连续性和高效性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

断点续训:在进行神经网络训练过程中由于一些因素导致训练无法进行,需要保存当前的训练结果下次接着训练
全连接反向传播神经网络中,训练过程的代码如下:

#coding:utf-8
#1前向传播过程
import tensorflow as tf

#网络输入节点为784个(代表每张输入图片的像素个数)
INPUT_NODE = 784
#输出节点为10个(表示输出为数字0-9的十分类)
OUTPUT_NODE = 10
#隐藏层节点500个
LAYER1_NODE = 500


def get_weight(shape, regularizer):
    #参数满足截断正态分布,并使用正则化,
    w = tf.Variable(tf.truncated_normal(shape,stddev=0.1))
    #w = tf.Variable(tf.random_normal(shape,stddev=0.1))
    #将每个参数的正则化损失加到总损失中
    if regularizer != None: tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w


def get_bias(shape):  
    #初始化的一维数组,初始化值为全 0
    b = tf.Variable(tf.zeros(shape))  
    return b
    
def forward(x, regularizer):
    #由输入层到隐藏层的参数w1形状为[784,500]
    w1 = get_weight([INPUT_NODE, LAYER1_NODE], regularizer)
    #由输入层到隐藏的偏置b1形状为长度500的一维数组,
    b1 = get_bias([LAYER1_NODE])
    #前向传播结构第一层为输入 x与参数 w1矩阵相乘加上偏置 b1 ,再经过relu函数 ,得到隐藏层输出 y1。
    y1 = tf.nn.relu(tf.matmul(x, w1) + b1)
    #由隐藏层到输出层的参数w2形状为[500,10]
    w2 = get_weight([LAYER1_NODE, OUTPUT_NODE], regularizer)
    #由隐藏层到输出的偏置b2形状为长度10的一维数组
    b2 = get_bias([OUTPUT_NODE])
    #前向传播结构第二层为隐藏输出 y1与参 数 w2 矩阵相乘加上偏置 矩阵相乘加上偏置 b2,得到输出 y。
    #由于输出 。由于输出 y要经过softmax oftmax 函数,使其符合概率分布,故输出y不经过 relu函数
    y = tf.matmul(y1, w2) + b2
    return y
#coding:utf-8
#2反向传播过程
#引入tensorflow、input_data、前向传播mnist_forward和os模块
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os

#每轮喂入神经网络的图片数
BATCH_SIZE = 200
#初始学习率
LEARNING_RATE_BASE = 0.1
#学习率衰减率
LEARNING_RATE_DECAY = 0.99
#正则化系数
REGULARIZER = 0.0001
#训练轮数
STEPS = 50000
#滑动平均衰减率
MOVING_AVERAGE_DECAY = 0.99
#模型保存路径
MODEL_SAVE_PATH="./model/"
#模型保存名称
MODEL_NAME="mnist_model"


def backward(mnist):
    #用placeholder给训练数据x和标签y_占位
    x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
    y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
    #调用mnist_forward文件中的前向传播过程forword()函数,并设置正则化,计算训练数据集上的预测结果y
    y = mnist_forward.forward(x, REGULARIZER)
    #当前计算轮数计数器赋值,设定为不可训练类型
    global_step = tf.Variable(0, trainable=False)

    #调用包含所有参数正则化损失的损失函数loss
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cem = tf.reduce_mean(ce)
    loss = cem + tf.add_n(tf.get_collection('losses'))
    #设定指数衰减学习率learning_rate
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples / BATCH_SIZE, 
        LEARNING_RATE_DECAY,
        staircase=True)

    #使用梯度衰减算法对模型优化,降低损失函数
    #train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    train_step = tf.train.MomentumOptimizer(learning_rate,0.9).minimize(loss, global_step=global_step)
    #train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)
    #定义参数的滑动平均
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    ema_op = ema.apply(tf.trainable_variables())
    #实例化可还原滑动平均的saver 
    #在模型训练时引入滑动平均可以使模型在测试数据上表现的更加健壮
    with tf.control_dependencies([train_step,ema_op]):
        train_op = tf.no_op(name='train')

    saver = tf.train.Saver()

    with tf.Session() as sess:
        #所有参数初始化
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
       
        #断点续训,加入ckpt操作
    ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        #每次喂入batch_size组(即200组)训练数据和对应标签,循环迭代steps轮
        for i in range(STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
            if i % 1000 == 0:
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
                #将当前会话加载到指定路径
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)


def main():
    #读入mnist
    mnist = input_data.read_data_sets("./data/", one_hot=True)
    #反向传播
    backward(mnist)

if __name__ == '__main__':
    main()
#coding:utf-8
#验证网络的准确性和泛化性
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import mnist_backward
#程序5秒的循环间隔时间
TEST_INTERVAL_SECS = 5

def test(mnist):
    #利用tf.Graph()复现之前定义的计算图
    with tf.Graph().as_default() as g:
        #利用placeholder给训练数据x和标签y_占位
        x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
        y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
        #调用mnist_forward文件中的前向传播过程forword()函数
        y = mnist_forward.forward(x, None)
        #实例化具有滑动平均的saver对象,从而在会话被加载时模型中的所有参数被赋值为各自的滑动平均值,增强模型的稳定性
        ema = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
        ema_restore = ema.variables_to_restore()
        saver = tf.train.Saver(ema_restore)
        #计算模型在测试集上的准确率
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        while True:
            with tf.Session() as sess:
                #加载指定路径下的ckpt
                ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
                #若模型存在,则加载出模型到当前对话,在测试数据集上进行准确率验证,并打印出当前轮数下的准确率
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    accuracy_score = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
                    print("After %s training step(s), test accuracy = %g" % (global_step, accuracy_score))
                #若模型不存在,则打印出模型不存在的提示,从而test()函数完成
                else:
                    print('No checkpoint file found')
                    return
            time.sleep(TEST_INTERVAL_SECS)

def main():
    #加载指定路径下的测试数据集
    mnist = input_data.read_data_sets("./data/", one_hot=True)
    test(mnist)

if __name__ == '__main__':
    main()

关键处理:加入 ckpt 操作:
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
1、注解:
1)tf.train.get_checkpoint_state(checkpoint_dir,latest_filename=None)
该函数表示如果断点文件夹中包含有效断点状态文件,则返回该文件。
参数说明:checkpoint_dir:表示存储断点文件的目录
latest_filename=None:断点文件的可选名称,默认为“checkpoint”
2)saver.restore(sess, ckpt.model_checkpoint_path)
该函数表示恢复当前会话,将 ckpt 中的值赋给 w 和 b。
参数说明:sess:表示当前会话,之前保存的结果将被加载入这个会话
ckpt.model_checkpoint_path:表示模型存储的位置,不需要提供模
型的名字,它会去查看 checkpoint 文件,看看最新的是谁,叫做什么。

 

注:本文章通过观看北京大学曹健老师的Tensorflow视频,笔记总结而来的。

 

转载于:https://www.cnblogs.com/fcfc940503/p/11001924.html

<think>好的,我现在需要帮助用户解决如何在YOLO模型训练中实现断点续训的问题。首先,我得回想一下之前学过的相关知识,以及用户提供的参考引用内容。 用户提到三个引用,都是关于YOLO的train文件、断点续训和自动训练工具的。引用1和3直接提到了断点续训的方法,尤其是引用3给出了两种方法。用户的问题是如何实现断点续训,所以需要结合这些信息来回答。 首先,我应该明确断点续训的基本概念,就是在训练中断后能够从上次停止的地方继续训练,而不是从头开始。这在长时间的训练任务中非常重要,可以节省时间和计算资源。 接下来,根据引用3的内容,方法一是通过修改代码,加载最新的权重和优化器状态,并调整起始的epoch数。而方法二则是使用resume参数,可能在训练脚本中已经内置了这个功能,用户只需要设置参数即可。比如在YOLOv5或YOLOv8中,常用的命令行参数是--resume,这样脚本会自动处理断点恢复。 然后,我需要将这些方法具体化,给出步骤。比如对于方法一,可能需要用户找到保存的last.pt文件,包含模型权重、优化器状态等信息,然后在代码中加载这些参数,并设置start_epoch为上次的epoch+1。同时,调整学习率调度器,确保学习率正确恢复。 对于方法二,用户只需在启动训练时添加--resume参数,脚本会自动检测最新的检查点并恢复训练。这里可能需要给出具体的命令示例,例如在YOLOv5中的使用方式。 另外,注意事项也很重要,比如需要确保每次训练保存检查点,检查文件命名和路径是否正确,环境是否一致,以及验证恢复后的训练是否正常。这些能帮助用户避免常见的问题。 用户提供的引用1和2可能还涉及train文件的结构和自动训练工具的使用,虽然问题主要聚焦在断点续训,但或许可以提到自动工具如何简化这个过程,比如引用2中的自动训练工具可能内置了断点恢复功能,用户不需要手动处理。 最后,生成的相关问题需要考虑用户可能的后续疑问,比如不同YOLO版本的差异、检查点保存频率、恢复训练后的性能影响等。这些问题有助于用户进一步理解和应用断点续训功能。 总结下来,回答的结构应该是先解释断点续训的意义,然后分方法详细说明,接着注意事项,最后相关问题。需要确保使用正确的中文术语,并引用提供的参考文献,比如引用3提到的方法一和方法二,引用1和2可能作为补充信息。</think>### YOLO模型断点续训实现方法 断点续训功能可通过以下两种方式实现(以YOLOv5/YOLOv8为例): #### 方法一:手动加载检查点 1. **定位检查点文件** 在训练输出目录中找到最新保存的`last.pt`文件,该文件包含: - 模型权重 - 优化器状态 - 当前epoch数 - 训练参数配置 2. **修改训练代码** ```python # 加载检查点时指定resume参数 model = YOLO('yolov8n.yaml').load('last.pt') # 加载模型架构和权重 results = model.train(resume=True) # 自动继承上次训练参数[^3] ``` #### 方法二:命令行参数启动 直接通过`--resume`参数启动训练: ```bash python train.py --resume # YOLOv5/v8通用格式[^1] ``` #### 关键参数说明 | 参数 | 作用 | 示例值 | |---------------|-----------------------------|-------------| | `--weights` | 指定初始权重文件 | last.pt | | `--resume` | 自动检测最新检查点 | True/False | | `--epochs` | 总训练轮次(需包含已完成轮次) | 300 | | `--cfg` | 保持与上次相同的配置文件 | yolov8n.yaml| ### 注意事项 1. **检查点保存机制** - 默认每1-10个epoch自动保存(不同版本有差异) - 确保磁盘有足够空间(单个.pt文件约100-300MB) 2. **版本一致性** - 必须使用相同框架版本恢复训练 - 配置文件不可修改网络结构 3. **异常处理** ```bash # 检查GPU内存分配(常见中断原因) nvidia-smi -l 1 # 实时监控显存使用[^2] ``` ### 训练恢复验证技巧 1. 观察初始loss值是否与中断前最后记录相符 2. 检查学习率变化曲线是否连续 3. 验证验证集指标是否正常延续
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值