Rust实践:深度学习与日志配置
1. Rust中的深度学习
在Rust中进行深度学习是一个有趣的探索。tch-rs(https://github.com/LaurentMazare/tch-rs )是一个很好的框架,如果你熟悉PyTorch,那么可以立即上手。不过,对于机器学习新手来说,建议先从Python和PyTorch开始,以适应所需的思维方式。
tch-rs基于Python版本的C++基础,并在其创建的绑定上提供了一个薄包装。这意味着Python版本的大多数概念也适用于tch-rs,但大量使用C++可能会带来安全问题。由于增加了抽象层和宿主语言编程范式的改变,使用绑定的包装代码更有可能导致内存泄漏,这在机器学习应用中影响更大,因为机器学习通常需要数十甚至数百GB的内存。
以下是深度学习模型训练的具体步骤:
1.
设置依赖
:设置tch依赖以解决后续步骤中的导入问题。
2.
导入相关内容
:为后续使用做准备。
3.
定义模型架构
:深度学习本质上是一系列矩阵乘法,输入和输出维度必须匹配。由于PyTorch是底层框架,需要手动设置各个层并匹配其维度。在这个例子中,使用两层二维卷积层和两层全连接层。在
new()
函数中初始化网络时,为实例化函数(
nn::conv2d
和
nn::linear
)分配输入大小、神经元/滤波器数量和输出/层数。各层之间的数字要匹配,以便能够连接它们,最后一层输出的类别数量为10。
| 层类型 | 输入维度 | 输出维度 |
| ------------ | -------- | -------- |
| 卷积层1 | - | - |
| 卷积层2 | - | - |
| 全连接层1 | - | - |
| 全连接层2 | - | 10 |
-
实现前向传播过程
:实现
nn::ModuleT特征提供的前向传播过程。与nn::Module的区别在于forward_t()函数中的train参数,它指示这次运行是否用于训练。该函数的另一个参数是表示为nn::Tensor引用的实际数据。由于处理的是(灰度)图像,将其表示为4维张量,各维度含义如下:- 第一维是批次,包含0到批次大小数量的图像。
- 第二维表示图像的通道数,灰度图像为1,RGB图像为3。
- 最后两维存储实际图像,即图像的宽度和高度。
调用张量实例的
.view()
函数将其解释为这些维度,-1表示适合的任意值(通常用于批次大小)。然后将一批28 x 28 x 1的图像输入到第一个卷积层,并对结果应用修正线性单元(ReLU)函数。接着是一个二维最大池化层,第二个卷积层重复此模式。第二次最大池化后,将输出向量展平,并依次应用全连接层,中间使用ReLU函数。最后一层的原始输出作为张量返回。
graph TD;
A[输入图像] --> B[卷积层1];
B --> C[ReLU];
C --> D[最大池化层1];
D --> E[卷积层2];
E --> F[ReLU];
F --> G[最大池化层2];
G --> H[展平];
H --> I[全连接层1];
I --> J[ReLU];
J --> K[全连接层2];
K --> L[输出];
-
训练循环 :
- 读取数据 :使用预定义的数据集函数从磁盘读取数据,这里使用MNIST数据。数据已分为训练集和测试集,是一个带有一些实用函数的迭代器。
-
创建
nn::VarStore:用于存储模型权重,并将其传递给模型架构结构体ConvNet和优化器(这里使用Adam优化器)。由于PyTorch允许在设备(CPU或GPU)之间移动数据,因此需要为权重和数据分配设备。 - 设置学习率 :学习率表示优化器向最佳解决方案跳转的步长,通常非常小(例如1e - 2),因为过大的值可能会超出目标并使解决方案变差,过小的值可能导致无法收敛。
- 实现训练循环 :循环运行多个周期,周期数越高通常意味着收敛性越好,但这里选择的5个周期是为了快速完成训练并获得可感知的结果。在每个周期内,遍历打乱的批次,进行前向传播并计算每个批次的损失。损失函数使用交叉熵,它返回一个数字,让我们知道预测的偏差程度,这对反向传播很重要。这里选择一次处理1024张图像的大批次大小,每个周期需要运行59次循环,这样可以在不影响训练质量的情况下加快训练过程。
- 记录损失 :创建一个简单的向量来存储每个批次的平均损失。绘制每个周期的损失曲线,通常会看到损失逐渐趋于零。
- 测试模型 :使用测试集在不进行反向传播的情况下直接计算准确率,以检查模型是否真正得到了改进,而不仅仅是记住了训练数据。一般建议将数据分为训练集、测试集和验证集,验证集用于确保模型在真实数据上的表现,且在训练过程中不能用于更改任何参数。
- 保存最佳模型 :采用检查点策略,当模型产生的损失低于之前的损失时,将最佳模型保存到磁盘。
-
加载模型进行预测 :重复部分数据加载的设置过程,但不进行模型训练,而是从磁盘加载网络的权重。为了说明预测过程,使用测试集(实际应用中应避免),随机选取10张图像(10个大小为1的批次),进行前向传播,然后使用
softmax函数从网络的原始输出中得出概率。应用.view()函数使数据与标签对齐后,将概率打印到命令行。概率最高的索引即为网络的预测结果。 - 调用函数 :按顺序调用函数,我们可以看到训练过程和预测结果。训练完成后,使用最佳模型的权重进行推理并打印概率矩阵。
2. 配置和使用日志
在Rust中,将调试和其他信息输出到控制台虽然简单方便,但当复杂度增加时,可能会变得混乱。缺乏标准化的日期/时间、来源类或不一致的格式,会使跟踪系统的执行路径变得困难。现代系统将日志作为额外的信息来源,例如统计每小时服务的用户数量、用户来源以及95%分位数响应时间等。
以下是创建和使用自定义日志记录器的步骤:
1.
创建新项目
:打开终端,使用
cargo new logging
创建一个新项目,然后使用VS Code打开项目目录。
2.
添加依赖
:在
Cargo.toml
中添加以下依赖:
[dependencies]
log = "0.4"
log4rs = "0.8.3"
time = "0.1"
-
导入宏
:在
src/main.rs中导入所需的宏:
use log::{debug, error, info, trace, warn};
- 使用宏记录日志 :添加一个函数来展示如何使用刚刚导入的宏:
fn log_some_stuff() {
let a = 100;
trace!("TRACE: Called log_some_stuff()");
debug!("DEBUG: a = {} ", a);
info!("INFO: The weather is fine");
warn!("WARNING, stuff is breaking down");
warn!(target: "special-target", "WARNING, stuff is breaking down");
error!("ERROR: stopping ...");
}
-
设置日志框架
:在
main函数中设置日志框架。可以选择使用自定义日志记录器或log4rs:
const USE_CUSTOM: bool = false;
fn main() {
if USE_CUSTOM {
log::set_logger(&LOGGER)
.map(|()| log::set_max_level(log::LevelFilter::Trace))
.unwrap();
} else {
log4rs::init_file("log4rs.yml", Default::default()).unwrap();
}
log_some_stuff();
}
-
创建配置文件
:创建
log4rs.yml配置文件:
refresh_rate: 30 seconds
appenders:
stdout:
kind: console
outfile:
kind: file
path: "outfile.log"
encoder:
pattern: "{d} - {m}{n}"
root:
level: trace
appenders:
- stdout
loggers:
special-target:
level: info
appenders:
- outfile
-
创建自定义日志记录器
:在
src/main.rs中创建一个嵌套模块并实现自定义日志记录器:
mod custom {
pub use log::Level;
use log::{Metadata, Record};
pub struct EmojiLogger {
pub level: Level,
}
impl log::Log for EmojiLogger {
fn flush(&self) {}
fn enabled(&self, metadata: &Metadata) -> bool {
metadata.level() <= self.level
}
fn log(&self, record: &Record) {
if self.enabled(record.metadata()) {
let level = match record.level() {
Level::Warn => "WARNING-SIGN",
Level::Info => "INFO-SIGN",
Level::Debug => "CATERPILLAR",
Level::Trace => "LIGHTBULB",
Level::Error => "NUCLEAR",
};
let utc = time::now_utc();
println!("{} | [{}] | {:<5}",
utc.rfc3339(), record.target(), level);
println!("{:21} {}", "", record.args());
}
}
}
}
-
实例化日志记录器
:使用
static关键字实例化并声明日志记录器变量:
static LOGGER: custom::EmojiLogger = custom::EmojiLogger {
level: log::Level::Trace,
};
-
运行程序
:
-
当
USE_CUSTOM为false时,程序读取并使用log4rs.yaml配置:
-
当
$ cargo run
Finished dev [unoptimized + debuginfo] target(s) in 0.04s
Running `target/debug/logging`
2019-09-01T12:42:18.056681073+02:00 TRACE logging - TRACE: Called log_some_stuff()
2019-09-01T12:42:18.056764247+02:00 DEBUG logging - DEBUG: a = 100
2019-09-01T12:42:18.056791639+02:00 INFO logging - INFO: The weather is fine
2019-09-01T12:42:18.056816420+02:00 WARN logging - WARNING, stuff is breaking down
2019-09-01T12:42:18.056881011+02:00 ERROR logging - ERROR: stopping ...
- 当`USE_CUSTOM`为`true`时,使用自定义日志记录器:
$ cargo run
Compiling logging v0.1.0 (Rust-Cookbook/Chapter10/logging)
Finished dev [unoptimized + debuginfo] target(s) in 0.94s
Running `target/debug/logging`
2019-09-01T10:46:43Z | [logging] | LIGHTBULB
TRACE: Called log_some_stuff()
2019-09-01T10:46:43Z | [logging] | CATERPILLAR
DEBUG: a = 100
2019-09-01T10:46:43Z | [logging] | INFO-SIGN
INFO: The weather is fine
2019-09-01T10:46:43Z | [logging] | WARNING-SIGN
WARNING, stuff is breaking down
2019-09-01T10:46:43Z | [special-target] | WARNING-SIGN
WARNING, stuff is breaking down
2019-09-01T10:46:43Z | [logging] | NUCLEAR
ERROR: stopping ...
Rust的日志基础设施主要由两部分组成:
-
log
crate(https://github.com/rust-lang-nursery/log ):提供日志宏的接口。
- 日志实现器,如
log4rs
(https://github.com/sfackler/log4rs )、
env_logger
(https://github.com/sebasmagri/env_logger/ )等。
在完成初始设置后,只需导入
log
crate提供的宏即可。在步骤4中创建的函数涵盖了大多数日志记录的用例。步骤5设置了
log4rs
日志框架,它模仿了Java世界的事实标准
log4j
,提供了出色的灵活性,可以在运行时更改日志级别和输出格式。步骤6的配置文件定义了刷新速率、输出目标(appenders)和日志级别等信息。
Rust实践:深度学习与日志配置
3. 深度学习与日志的综合应用思考
在实际开发中,深度学习和日志配置往往是相辅相成的。深度学习模型的训练和预测过程中会产生大量的信息,如损失值、准确率、模型参数等,合理的日志配置可以帮助我们更好地监控和分析这些信息,从而优化模型。
3.1 深度学习中的日志应用
在深度学习的训练过程中,我们可以利用日志记录以下关键信息:
-
损失值和准确率
:在每个周期或批次的训练后,记录损失值和准确率,通过绘制曲线可以直观地观察模型的收敛情况。
-
模型参数
:记录模型的参数更新情况,有助于分析模型的学习过程,判断是否存在过拟合或欠拟合的问题。
-
异常情况
:当训练过程中出现异常,如梯度爆炸、内存溢出等,及时记录相关信息,方便后续排查问题。
以下是一个简单的示例,在深度学习训练循环中添加日志记录:
// 在训练循环中添加日志记录
for epoch in 0..epochs {
let mut total_loss = 0.0;
let mut correct = 0;
let mut total = 0;
for (data, target) in train_data.iter() {
// 前向传播
let output = model.forward_t(&data, true);
// 计算损失
let loss = criterion.forward(&output, &target);
total_loss += loss.item::<f32>();
// 计算准确率
let predicted = output.argmax(1, true);
correct += (predicted.eq(&target)).sum::<i64>().item::<i64>();
total += target.size()[0] as i64;
// 反向传播
opt.backward_step(&loss);
}
let accuracy = correct as f32 / total as f32;
info!("Epoch {}: Loss = {:.4}, Accuracy = {:.4}", epoch, total_loss / train_data.len() as f32, accuracy);
}
3.2 日志对模型优化的帮助
通过对日志信息的分析,我们可以采取以下措施来优化深度学习模型:
-
调整学习率
:如果损失曲线在训练初期下降缓慢,可能是学习率过小;如果损失曲线出现震荡或上升,可能是学习率过大。根据日志中的损失值变化,调整学习率可以提高模型的收敛速度。
-
增加或减少训练周期
:如果准确率在训练后期不再提高,可能是模型已经过拟合,此时可以减少训练周期;如果准确率仍然有上升趋势,可以增加训练周期。
-
调整模型架构
:通过分析模型参数的更新情况和准确率的变化,判断模型架构是否合理。如果某些层的参数更新缓慢或没有更新,可能需要调整该层的神经元数量或激活函数。
graph LR;
A[日志记录] --> B[分析损失值和准确率];
B --> C[调整学习率];
B --> D[调整训练周期];
B --> E[调整模型架构];
C --> F[优化模型];
D --> F;
E --> F;
4. 总结与实践建议
通过前面的介绍,我们了解了在Rust中进行深度学习和日志配置的方法。以下是一些总结和实践建议:
4.1 深度学习方面
- 选择合适的框架 :tch-rs是一个不错的选择,如果你熟悉PyTorch,可以快速上手。但对于机器学习新手,建议先从Python和PyTorch开始,掌握基本的机器学习思维后再使用Rust。
- 注意内存管理 :深度学习模型通常需要大量的内存,使用tch-rs时要注意内存泄漏的问题。合理设置批次大小和优化算法,可以减少内存的使用。
- 多进行实验 :通过调整学习率、训练周期、模型架构等参数,进行多次实验,找到最优的模型配置。
4.2 日志配置方面
- 使用标准化的日志框架 :如log4rs,它提供了一致且可配置的输出,方便后续的分析和处理。
- 定义合理的日志级别 :根据不同的需求,定义不同的日志级别,如TRACE、DEBUG、INFO、WARN、ERROR,避免输出过多无用的信息。
- 结合日志分析工具 :使用日志分析工具,如ELK Stack(Elasticsearch、Logstash、Kibana),可以更方便地对日志进行存储、搜索和可视化分析。
4.3 综合实践建议
- 在深度学习中合理使用日志 :在深度学习的训练和预测过程中,合理记录关键信息,通过日志分析优化模型。
- 定期清理日志文件 :随着时间的推移,日志文件会越来越大,定期清理无用的日志文件可以节省磁盘空间。
总之,在Rust中进行深度学习和日志配置需要我们不断地实践和探索,通过合理的配置和优化,提高模型的性能和开发效率。希望本文的内容对你有所帮助,让你在Rust的开发中更加得心应手。
超级会员免费看
899

被折叠的 条评论
为什么被折叠?



