29、Rust实践:深度学习与日志配置

Rust实践:深度学习与日志配置

1. Rust中的深度学习

在Rust中进行深度学习是一个有趣的探索。tch-rs(https://github.com/LaurentMazare/tch-rs )是一个很好的框架,如果你熟悉PyTorch,那么可以立即上手。不过,对于机器学习新手来说,建议先从Python和PyTorch开始,以适应所需的思维方式。

tch-rs基于Python版本的C++基础,并在其创建的绑定上提供了一个薄包装。这意味着Python版本的大多数概念也适用于tch-rs,但大量使用C++可能会带来安全问题。由于增加了抽象层和宿主语言编程范式的改变,使用绑定的包装代码更有可能导致内存泄漏,这在机器学习应用中影响更大,因为机器学习通常需要数十甚至数百GB的内存。

以下是深度学习模型训练的具体步骤:
1. 设置依赖 :设置tch依赖以解决后续步骤中的导入问题。
2. 导入相关内容 :为后续使用做准备。
3. 定义模型架构 :深度学习本质上是一系列矩阵乘法,输入和输出维度必须匹配。由于PyTorch是底层框架,需要手动设置各个层并匹配其维度。在这个例子中,使用两层二维卷积层和两层全连接层。在 new() 函数中初始化网络时,为实例化函数( nn::conv2d nn::linear )分配输入大小、神经元/滤波器数量和输出/层数。各层之间的数字要匹配,以便能够连接它们,最后一层输出的类别数量为10。

| 层类型       | 输入维度 | 输出维度 |
| ------------ | -------- | -------- |
| 卷积层1      | -        | -        |
| 卷积层2      | -        | -        |
| 全连接层1    | -        | -        |
| 全连接层2    | -        | 10       |
  1. 实现前向传播过程 :实现 nn::ModuleT 特征提供的前向传播过程。与 nn::Module 的区别在于 forward_t() 函数中的 train 参数,它指示这次运行是否用于训练。该函数的另一个参数是表示为 nn::Tensor 引用的实际数据。由于处理的是(灰度)图像,将其表示为4维张量,各维度含义如下:
    • 第一维是批次,包含0到批次大小数量的图像。
    • 第二维表示图像的通道数,灰度图像为1,RGB图像为3。
    • 最后两维存储实际图像,即图像的宽度和高度。

调用张量实例的 .view() 函数将其解释为这些维度,-1表示适合的任意值(通常用于批次大小)。然后将一批28 x 28 x 1的图像输入到第一个卷积层,并对结果应用修正线性单元(ReLU)函数。接着是一个二维最大池化层,第二个卷积层重复此模式。第二次最大池化后,将输出向量展平,并依次应用全连接层,中间使用ReLU函数。最后一层的原始输出作为张量返回。

graph TD;
    A[输入图像] --> B[卷积层1];
    B --> C[ReLU];
    C --> D[最大池化层1];
    D --> E[卷积层2];
    E --> F[ReLU];
    F --> G[最大池化层2];
    G --> H[展平];
    H --> I[全连接层1];
    I --> J[ReLU];
    J --> K[全连接层2];
    K --> L[输出];
  1. 训练循环

    • 读取数据 :使用预定义的数据集函数从磁盘读取数据,这里使用MNIST数据。数据已分为训练集和测试集,是一个带有一些实用函数的迭代器。
    • 创建 nn::VarStore :用于存储模型权重,并将其传递给模型架构结构体 ConvNet 和优化器(这里使用Adam优化器)。由于PyTorch允许在设备(CPU或GPU)之间移动数据,因此需要为权重和数据分配设备。
    • 设置学习率 :学习率表示优化器向最佳解决方案跳转的步长,通常非常小(例如1e - 2),因为过大的值可能会超出目标并使解决方案变差,过小的值可能导致无法收敛。
    • 实现训练循环 :循环运行多个周期,周期数越高通常意味着收敛性越好,但这里选择的5个周期是为了快速完成训练并获得可感知的结果。在每个周期内,遍历打乱的批次,进行前向传播并计算每个批次的损失。损失函数使用交叉熵,它返回一个数字,让我们知道预测的偏差程度,这对反向传播很重要。这里选择一次处理1024张图像的大批次大小,每个周期需要运行59次循环,这样可以在不影响训练质量的情况下加快训练过程。
    • 记录损失 :创建一个简单的向量来存储每个批次的平均损失。绘制每个周期的损失曲线,通常会看到损失逐渐趋于零。
    • 测试模型 :使用测试集在不进行反向传播的情况下直接计算准确率,以检查模型是否真正得到了改进,而不仅仅是记住了训练数据。一般建议将数据分为训练集、测试集和验证集,验证集用于确保模型在真实数据上的表现,且在训练过程中不能用于更改任何参数。
    • 保存最佳模型 :采用检查点策略,当模型产生的损失低于之前的损失时,将最佳模型保存到磁盘。
  2. 加载模型进行预测 :重复部分数据加载的设置过程,但不进行模型训练,而是从磁盘加载网络的权重。为了说明预测过程,使用测试集(实际应用中应避免),随机选取10张图像(10个大小为1的批次),进行前向传播,然后使用 softmax 函数从网络的原始输出中得出概率。应用 .view() 函数使数据与标签对齐后,将概率打印到命令行。概率最高的索引即为网络的预测结果。

  3. 调用函数 :按顺序调用函数,我们可以看到训练过程和预测结果。训练完成后,使用最佳模型的权重进行推理并打印概率矩阵。
2. 配置和使用日志

在Rust中,将调试和其他信息输出到控制台虽然简单方便,但当复杂度增加时,可能会变得混乱。缺乏标准化的日期/时间、来源类或不一致的格式,会使跟踪系统的执行路径变得困难。现代系统将日志作为额外的信息来源,例如统计每小时服务的用户数量、用户来源以及95%分位数响应时间等。

以下是创建和使用自定义日志记录器的步骤:
1. 创建新项目 :打开终端,使用 cargo new logging 创建一个新项目,然后使用VS Code打开项目目录。
2. 添加依赖 :在 Cargo.toml 中添加以下依赖:

[dependencies]
log = "0.4"
log4rs = "0.8.3"
time = "0.1"
  1. 导入宏 :在 src/main.rs 中导入所需的宏:
use log::{debug, error, info, trace, warn};
  1. 使用宏记录日志 :添加一个函数来展示如何使用刚刚导入的宏:
fn log_some_stuff() {
    let a = 100;
    trace!("TRACE: Called log_some_stuff()");
    debug!("DEBUG: a = {} ", a);
    info!("INFO: The weather is fine");
    warn!("WARNING, stuff is breaking down");
    warn!(target: "special-target", "WARNING, stuff is breaking down");
    error!("ERROR: stopping ...");
}
  1. 设置日志框架 :在 main 函数中设置日志框架。可以选择使用自定义日志记录器或 log4rs
const USE_CUSTOM: bool = false;
fn main() {
    if USE_CUSTOM {
        log::set_logger(&LOGGER)
           .map(|()| log::set_max_level(log::LevelFilter::Trace))
           .unwrap();
    } else {
        log4rs::init_file("log4rs.yml", Default::default()).unwrap();
    }
    log_some_stuff();
}
  1. 创建配置文件 :创建 log4rs.yml 配置文件:
refresh_rate: 30 seconds
appenders:
  stdout:
    kind: console
  outfile:
    kind: file
    path: "outfile.log"
    encoder:
      pattern: "{d} - {m}{n}"
root:
  level: trace
  appenders:
    - stdout
loggers:
  special-target:
    level: info
    appenders:
      - outfile
  1. 创建自定义日志记录器 :在 src/main.rs 中创建一个嵌套模块并实现自定义日志记录器:
mod custom {
    pub use log::Level;
    use log::{Metadata, Record};
    pub struct EmojiLogger {
        pub level: Level,
    }
    impl log::Log for EmojiLogger {
        fn flush(&self) {}
        fn enabled(&self, metadata: &Metadata) -> bool {
            metadata.level() <= self.level
        }
        fn log(&self, record: &Record) {
            if self.enabled(record.metadata()) {
                let level = match record.level() {
                    Level::Warn => "WARNING-SIGN",
                    Level::Info => "INFO-SIGN",
                    Level::Debug => "CATERPILLAR",
                    Level::Trace => "LIGHTBULB",
                    Level::Error => "NUCLEAR",
                };
                let utc = time::now_utc();
                println!("{} | [{}] | {:<5}",
                         utc.rfc3339(), record.target(), level);
                println!("{:21} {}", "", record.args());
            }
        }
    }
}
  1. 实例化日志记录器 :使用 static 关键字实例化并声明日志记录器变量:
static LOGGER: custom::EmojiLogger = custom::EmojiLogger {
    level: log::Level::Trace,
};
  1. 运行程序
    • USE_CUSTOM false 时,程序读取并使用 log4rs.yaml 配置:
$ cargo run
Finished dev [unoptimized + debuginfo] target(s) in 0.04s
    Running `target/debug/logging`
2019-09-01T12:42:18.056681073+02:00 TRACE logging - TRACE: Called log_some_stuff()
2019-09-01T12:42:18.056764247+02:00 DEBUG logging - DEBUG: a = 100
2019-09-01T12:42:18.056791639+02:00 INFO logging - INFO: The weather is fine
2019-09-01T12:42:18.056816420+02:00 WARN logging - WARNING, stuff is breaking down
2019-09-01T12:42:18.056881011+02:00 ERROR logging - ERROR: stopping ...
- 当`USE_CUSTOM`为`true`时,使用自定义日志记录器:
$ cargo run
   Compiling logging v0.1.0 (Rust-Cookbook/Chapter10/logging)
    Finished dev [unoptimized + debuginfo] target(s) in 0.94s
     Running `target/debug/logging`
2019-09-01T10:46:43Z | [logging] | LIGHTBULB
                      TRACE: Called log_some_stuff()
2019-09-01T10:46:43Z | [logging] | CATERPILLAR
                      DEBUG: a = 100
2019-09-01T10:46:43Z | [logging] | INFO-SIGN
                      INFO: The weather is fine
2019-09-01T10:46:43Z | [logging] | WARNING-SIGN
                      WARNING, stuff is breaking down
2019-09-01T10:46:43Z | [special-target] | WARNING-SIGN
                      WARNING, stuff is breaking down
2019-09-01T10:46:43Z | [logging] | NUCLEAR
                      ERROR: stopping ...

Rust的日志基础设施主要由两部分组成:
- log crate(https://github.com/rust-lang-nursery/log ):提供日志宏的接口。
- 日志实现器,如 log4rs (https://github.com/sfackler/log4rs )、 env_logger (https://github.com/sebasmagri/env_logger/ )等。

在完成初始设置后,只需导入 log crate提供的宏即可。在步骤4中创建的函数涵盖了大多数日志记录的用例。步骤5设置了 log4rs 日志框架,它模仿了Java世界的事实标准 log4j ,提供了出色的灵活性,可以在运行时更改日志级别和输出格式。步骤6的配置文件定义了刷新速率、输出目标(appenders)和日志级别等信息。

Rust实践:深度学习与日志配置

3. 深度学习与日志的综合应用思考

在实际开发中,深度学习和日志配置往往是相辅相成的。深度学习模型的训练和预测过程中会产生大量的信息,如损失值、准确率、模型参数等,合理的日志配置可以帮助我们更好地监控和分析这些信息,从而优化模型。

3.1 深度学习中的日志应用

在深度学习的训练过程中,我们可以利用日志记录以下关键信息:
- 损失值和准确率 :在每个周期或批次的训练后,记录损失值和准确率,通过绘制曲线可以直观地观察模型的收敛情况。
- 模型参数 :记录模型的参数更新情况,有助于分析模型的学习过程,判断是否存在过拟合或欠拟合的问题。
- 异常情况 :当训练过程中出现异常,如梯度爆炸、内存溢出等,及时记录相关信息,方便后续排查问题。

以下是一个简单的示例,在深度学习训练循环中添加日志记录:

// 在训练循环中添加日志记录
for epoch in 0..epochs {
    let mut total_loss = 0.0;
    let mut correct = 0;
    let mut total = 0;
    for (data, target) in train_data.iter() {
        // 前向传播
        let output = model.forward_t(&data, true);
        // 计算损失
        let loss = criterion.forward(&output, &target);
        total_loss += loss.item::<f32>();
        // 计算准确率
        let predicted = output.argmax(1, true);
        correct += (predicted.eq(&target)).sum::<i64>().item::<i64>();
        total += target.size()[0] as i64;
        // 反向传播
        opt.backward_step(&loss);
    }
    let accuracy = correct as f32 / total as f32;
    info!("Epoch {}: Loss = {:.4}, Accuracy = {:.4}", epoch, total_loss / train_data.len() as f32, accuracy);
}
3.2 日志对模型优化的帮助

通过对日志信息的分析,我们可以采取以下措施来优化深度学习模型:
- 调整学习率 :如果损失曲线在训练初期下降缓慢,可能是学习率过小;如果损失曲线出现震荡或上升,可能是学习率过大。根据日志中的损失值变化,调整学习率可以提高模型的收敛速度。
- 增加或减少训练周期 :如果准确率在训练后期不再提高,可能是模型已经过拟合,此时可以减少训练周期;如果准确率仍然有上升趋势,可以增加训练周期。
- 调整模型架构 :通过分析模型参数的更新情况和准确率的变化,判断模型架构是否合理。如果某些层的参数更新缓慢或没有更新,可能需要调整该层的神经元数量或激活函数。

graph LR;
    A[日志记录] --> B[分析损失值和准确率];
    B --> C[调整学习率];
    B --> D[调整训练周期];
    B --> E[调整模型架构];
    C --> F[优化模型];
    D --> F;
    E --> F;
4. 总结与实践建议

通过前面的介绍,我们了解了在Rust中进行深度学习和日志配置的方法。以下是一些总结和实践建议:

4.1 深度学习方面
  • 选择合适的框架 :tch-rs是一个不错的选择,如果你熟悉PyTorch,可以快速上手。但对于机器学习新手,建议先从Python和PyTorch开始,掌握基本的机器学习思维后再使用Rust。
  • 注意内存管理 :深度学习模型通常需要大量的内存,使用tch-rs时要注意内存泄漏的问题。合理设置批次大小和优化算法,可以减少内存的使用。
  • 多进行实验 :通过调整学习率、训练周期、模型架构等参数,进行多次实验,找到最优的模型配置。
4.2 日志配置方面
  • 使用标准化的日志框架 :如log4rs,它提供了一致且可配置的输出,方便后续的分析和处理。
  • 定义合理的日志级别 :根据不同的需求,定义不同的日志级别,如TRACE、DEBUG、INFO、WARN、ERROR,避免输出过多无用的信息。
  • 结合日志分析工具 :使用日志分析工具,如ELK Stack(Elasticsearch、Logstash、Kibana),可以更方便地对日志进行存储、搜索和可视化分析。
4.3 综合实践建议
  • 在深度学习中合理使用日志 :在深度学习的训练和预测过程中,合理记录关键信息,通过日志分析优化模型。
  • 定期清理日志文件 :随着时间的推移,日志文件会越来越大,定期清理无用的日志文件可以节省磁盘空间。

总之,在Rust中进行深度学习和日志配置需要我们不断地实践和探索,通过合理的配置和优化,提高模型的性能和开发效率。希望本文的内容对你有所帮助,让你在Rust的开发中更加得心应手。

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符  | 博主筛选后可见
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值