Rust实战:网络请求与机器学习模型训练
1. 发送网络请求
在当今的应用程序开发中,网络请求已成为不可或缺的一部分。Rust虽然没有内置的网络请求模块,但有一些优秀的库可以帮助我们实现这一功能。下面主要介绍
surf
和
reqwest
这两个库的使用。
1.1 准备工作
首先,我们需要创建一个新的Rust项目,并添加所需的依赖。具体步骤如下:
1. 打开终端,使用以下命令创建一个新的项目:
cargo new web-requests
- 使用VS Code打开项目目录。
-
编辑
Cargo.toml文件,添加以下依赖:
[dependencies]
surf = "1.0"
reqwest = "0.9"
serde = "1"
serde_json = "1"
runtime = "0.3.0-alpha.6"
1.2 导入依赖和定义数据结构
在
src/main.rs
文件中,添加以下代码进行依赖导入和数据结构定义:
#[macro_use]
extern crate serde_json;
use surf::Exception;
use serde::Serialize;
#[derive(Serialize)]
struct MyGetParams {
a: u64,
b: String,
}
1.3 使用
surf
库进行网络请求
surf
是一个全异步的库,以下是使用它进行网络请求的示例代码:
async fn test_surf() -> Result<(), Exception> {
println!("> surf ...");
let client = surf::Client::new();
let mut res = client
.get("https://blog.x5ff.xyz/other/cookbook2018")
.await?;
assert_eq!(200, res.status());
assert_eq!("Rust is awesome\n", res.body_string().await?);
let form_values = vec![
("custname", "Rusty Crabbington"),
("comments", "Thank you"),
("custemail", "rusty@nope.com"),
("custtel", "+1 234 33456"),
("delivery", "25th floor below ground, no elevator. sorry"),
];
let res_forms: serde_json::Value = client
.post("https://httpbin.org/post")
.body_form(&form_values)?
.recv_json()
.await?;
for (name, value) in form_values.iter() {
assert_eq!(res_forms["form"][name], *value);
}
let json_payload = json!({
"book": "Rust 2018 Cookbook",
"blog": "https://blog.x5ff.xyz",
});
let res_json: serde_json::Value = client
.put("https://httpbin.org/anything")
.body_json(&json_payload)?
.recv_json()
.await?;
assert_eq!(res_json["json"], json_payload);
let query_params = MyGetParams {
a: 0x5ff,
b: "https://blog.x5ff.xyz".into(),
};
let res_query: serde_json::Value = client
.get("https://httpbin.org/get")
.set_query(&query_params)?
.recv_json()
.await?;
assert_eq!(res_query["args"]["a"], query_params.a.to_string());
assert_eq!(res_query["args"]["b"], query_params.b);
println!("> surf successful!");
Ok(())
}
1.4 使用
reqwest
库进行网络请求
reqwest
是一个成熟的非异步库,以下是使用它进行网络请求的示例代码:
fn test_reqwest() -> Result<(), Exception> {
println!("> reqwest ...");
let client = reqwest::Client::new();
let mut res = client
.get("https://blog.x5ff.xyz/other/cookbook2018")
.send()?;
assert_eq!(200, res.status());
assert_eq!("Rust is awesome\n", res.text()?);
let form_values = vec![
("custname", "Rusty Crabbington"),
("comments", "Thank you"),
("custemail", "rusty@nope.com"),
("custtel", "+1 234 33456"),
("delivery", "25th floor below ground, no elevator. sorry"),
];
let res_forms: serde_json::Value = client
.post("https://httpbin.org/post")
.form(&form_values)
.send()?
.json()?;
for (name, value) in form_values.iter() {
assert_eq!(res_forms["form"][name], *value);
}
let json_payload = json!({
"book": "Rust 2018 Cookbook",
"blog": "https://blog.x5ff.xyz",
});
let res_json: serde_json::Value = client
.put("https://httpbin.org/anything")
.json(&json_payload)
.send()?
.json()?;
assert_eq!(res_json["json"], json_payload);
let query_params = MyGetParams {
a: 0x5ff,
b: "https://blog.x5ff.xyz".into(),
};
let res_query: serde_json::Value = client
.get("https://httpbin.org/get")
.query(&query_params)
.send()?
.json()?;
assert_eq!(res_query["args"]["a"], query_params.a.to_string());
assert_eq!(res_query["args"]["b"], query_params.b);
println!("> reqwest successful!");
Ok(())
}
1.5 主函数和运行命令
在
src/main.rs
文件中添加主函数:
#[runtime::main]
async fn main() -> Result<(), Exception> {
println!("Running some tests");
test_reqwest()?;
test_surf().await?;
Ok(())
}
运行以下命令来测试网络请求:
cargo +nightly run
1.6 工作原理
surf
和
reqwest
这两个库遵循了构建器模式和装饰器模式。
surf
是全异步的,使用
await
等待请求完成;
reqwest
则使用
send()
方法执行请求。它们都可以方便地处理不同类型的请求,如GET、POST、PUT等,并能处理不同的数据格式,如表单数据和JSON数据。
以下是一个简单的流程图,展示了使用这两个库进行网络请求的基本流程:
graph TD;
A[创建项目并添加依赖] --> B[导入依赖和定义数据结构];
B --> C[使用surf进行请求];
B --> D[使用reqwest进行请求];
C --> E[主函数调用测试函数];
D --> E;
E --> F[运行测试命令];
2. 运行机器学习模型
机器学习,特别是深度学习,近年来发展迅速。虽然Rust在机器学习领域的库支持相对较少,但社区提供了与流行深度学习框架的绑定,让我们可以使用Rust进行一些实验。这里将使用Rust的PyTorch绑定
tch-rs
来训练一个模型,识别时尚MNIST数据集中的服装物品。
2.1 准备工作
在开始之前,需要对神经网络有一定的了解。如果不太熟悉,可以先参加一些在线课程。准备工作步骤如下:
1. 打开终端,创建一个新的Rust项目:
cargo new rusty-ml
cd rusty-ml
mkdir models
- 克隆或下载Zalando Research的时尚MNIST数据集:
git clone https://github.com/zalandoresearch/fashion-mnist
-
解压数据文件:
-
在Linux/macOS上,进入
fashion-mnist/data/fashion目录,运行gunzip *.gz。 - 在Windows上,使用解压工具解压。
-
在Linux/macOS上,进入
时尚MNIST数据集是原始MNIST数据集的替代品,包含10个类别的服装物品,任务是正确分类这些物品。
2.2 添加依赖和导入代码
打开
Cargo.toml
文件,添加以下依赖:
[dependencies]
tch = "0.1"
failure = "0.1"
在
src/main.rs
文件中添加导入代码:
use std::io::{Error, ErrorKind};
use std::path::Path;
use std::time::Instant;
use tch::{nn, nn::ModuleT, nn::OptimizerConfig, Device, Tensor};
2.3 定义卷积神经网络
定义一个卷积神经网络
ConvNet
:
#[derive(Debug)]
struct ConvNet {
conv1: nn::Conv2D,
conv2: nn::Conv2D,
fc1: nn::Linear,
fc2: nn::Linear,
}
impl ConvNet {
fn new(vs: &nn::Path, labels: i64) -> ConvNet {
ConvNet {
conv1: nn::conv2d(vs, 1, 32, 5, Default::default()),
conv2: nn::conv2d(vs, 32, 64, 5, Default::default()),
fc1: nn::linear(vs, 1024, 512, Default::default()),
fc2: nn::linear(vs, 512, labels, Default::default()),
}
}
}
impl nn::ModuleT for ConvNet {
fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
xs.view([-1, 1, 28, 28])
.apply(&self.conv1)
.relu()
.max_pool2d_default(2)
.apply(&self.conv2)
.relu()
.max_pool2d_default(2)
.view([-1, 1024])
.apply(&self.fc1)
.relu()
.dropout_(0.5, train)
.apply(&self.fc2)
}
}
2.4 训练神经网络
添加训练函数
train_from_scratch
:
fn train_from_scratch(learning_rate: f64, batch_size: i64, epochs: usize) -> failure::Fallible<()> {
let data_path = Path::new("fashion-mnist/data/fashion");
let model_path = Path::new("models/best.ot");
if !data_path.exists() {
println!(
"Data not found at '{}'. Did you run 'git submodule update --init'?",
data_path.to_string_lossy()
);
return Err(Error::from(ErrorKind::NotFound).into());
}
println!("Loading data from '{}'", data_path.to_string_lossy());
let m = tch::vision::mnist::load_dir(data_path)?;
let vs = nn::VarStore::new(Device::cuda_if_available());
let net = ConvNet::new(&vs.root(), 10);
let opt = nn::Adam::default().build(&vs, learning_rate)?;
println!(
"Starting training, saving model to '{}'",
model_path.to_string_lossy()
);
let mut min_loss = ::std::f32::INFINITY;
for epoch in 1..=epochs {
let start = Instant::now();
let mut losses = vec![];
for (image_batch, label_batch) in m.train_iter(batch_size).shuffle().to_device(vs.device()) {
let loss = net
.forward_t(&image_batch, true)
.cross_entropy_for_logits(&label_batch);
opt.backward_step(&loss);
losses.push(f32::from(loss));
}
let total_loss = losses.iter().sum::<f32>() / (losses.len() as f32);
let test_accuracy = net
.forward_t(&m.test_images, false)
.accuracy_for_logits(&m.test_labels);
if total_loss <= min_loss {
vs.save(model_path)?;
min_loss = total_loss;
}
println!(
"{:4} | train loss: {:7.4} | test acc: {:5.2}% | duration: {}s",
epoch,
&total_loss,
100. * f64::from(&test_accuracy),
start.elapsed().as_secs()
);
}
println!(
"Done! The best model was saved to '{}'",
model_path.to_string_lossy()
);
Ok(())
}
2.5 进行推理
训练好模型后,可以进行推理。添加推理函数
predict_from_best
:
fn predict_from_best() -> failure::Fallible<()> {
let data_path = Path::new("fashion-mnist/data/fashion");
let model_weights_path = Path::new("models/best.ot");
let m = tch::vision::mnist::load_dir(data_path)?;
let mut vs = nn::VarStore::new(Device::cuda_if_available());
let net = ConvNet::new(&vs.root(), 10);
println!(
"Loading model weights from '{}'",
model_weights_path.to_string_lossy()
);
vs.load(model_weights_path)?;
println!("Probabilities and predictions for 10 random images in the test set");
for (image_batch, label_batch) in m.test_iter(1).shuffle().to_device(vs.device()).take(10) {
let raw_tensor = net
.forward_t(&image_batch, false)
.softmax(-1)
.view(m.labels);
let predicted_index: Vec<i64> = raw_tensor.argmax(0, false).into();
let probabilities: Vec<f64> = raw_tensor.into();
print!("[ ");
for p in probabilities {
print!("{:.4} ", p);
}
let label: Vec<i64> = label_batch.into();
println!("] predicted {}, was {}", predicted_index[0], label[0]);
}
Ok(())
}
2.6 主函数和运行命令
在
src/main.rs
文件中添加主函数:
fn main() -> failure::Fallible<()> {
train_from_scratch(1e-2, 1024, 5)?;
predict_from_best()?;
Ok(())
}
运行以下命令来训练模型和进行推理:
cargo run
2.7 工作原理
训练过程包括加载数据、初始化模型和优化器、迭代训练数据计算损失并进行反向传播。推理过程则是加载训练好的模型权重,对测试数据进行预测。
以下是训练和推理的流程图:
graph TD;
A[准备工作] --> B[添加依赖和导入代码];
B --> C[定义卷积神经网络];
C --> D[训练神经网络];
D --> E[保存最佳模型];
E --> F[进行推理];
F --> G[输出预测结果];
通过以上步骤,我们可以使用Rust进行网络请求和机器学习模型的训练与推理。Rust的高性能和安全性使其在这些领域具有很大的潜力。
超级会员免费看
913

被折叠的 条评论
为什么被折叠?



