28、Rust实战:网络请求与机器学习模型训练

Rust实战:网络请求与机器学习模型训练

1. 发送网络请求

在当今的应用程序开发中,网络请求已成为不可或缺的一部分。Rust虽然没有内置的网络请求模块,但有一些优秀的库可以帮助我们实现这一功能。下面主要介绍 surf reqwest 这两个库的使用。

1.1 准备工作

首先,我们需要创建一个新的Rust项目,并添加所需的依赖。具体步骤如下:
1. 打开终端,使用以下命令创建一个新的项目:

cargo new web-requests
  1. 使用VS Code打开项目目录。
  2. 编辑 Cargo.toml 文件,添加以下依赖:
[dependencies]
surf = "1.0"
reqwest = "0.9"
serde = "1"
serde_json = "1"
runtime = "0.3.0-alpha.6"

1.2 导入依赖和定义数据结构

src/main.rs 文件中,添加以下代码进行依赖导入和数据结构定义:

#[macro_use]
extern crate serde_json;
use surf::Exception;
use serde::Serialize;

#[derive(Serialize)]
struct MyGetParams {
    a: u64,
    b: String,
}

1.3 使用 surf 库进行网络请求

surf 是一个全异步的库,以下是使用它进行网络请求的示例代码:

async fn test_surf() -> Result<(), Exception> {
    println!("> surf ...");
    let client = surf::Client::new();
    let mut res = client
       .get("https://blog.x5ff.xyz/other/cookbook2018")
       .await?;
    assert_eq!(200, res.status());
    assert_eq!("Rust is awesome\n", res.body_string().await?);

    let form_values = vec![
        ("custname", "Rusty Crabbington"),
        ("comments", "Thank you"),
        ("custemail", "rusty@nope.com"),
        ("custtel", "+1 234 33456"),
        ("delivery", "25th floor below ground, no elevator. sorry"),
    ];
    let res_forms: serde_json::Value = client
       .post("https://httpbin.org/post")
       .body_form(&form_values)?
       .recv_json()
       .await?;
    for (name, value) in form_values.iter() {
        assert_eq!(res_forms["form"][name], *value);
    }

    let json_payload = json!({
        "book": "Rust 2018 Cookbook",
        "blog": "https://blog.x5ff.xyz",
    });
    let res_json: serde_json::Value = client
       .put("https://httpbin.org/anything")
       .body_json(&json_payload)?
       .recv_json()
       .await?;
    assert_eq!(res_json["json"], json_payload);

    let query_params = MyGetParams {
        a: 0x5ff,
        b: "https://blog.x5ff.xyz".into(),
    };
    let res_query: serde_json::Value = client
       .get("https://httpbin.org/get")
       .set_query(&query_params)?
       .recv_json()
       .await?;
    assert_eq!(res_query["args"]["a"], query_params.a.to_string());
    assert_eq!(res_query["args"]["b"], query_params.b);

    println!("> surf successful!");
    Ok(())
}

1.4 使用 reqwest 库进行网络请求

reqwest 是一个成熟的非异步库,以下是使用它进行网络请求的示例代码:

fn test_reqwest() -> Result<(), Exception> {
    println!("> reqwest ...");
    let client = reqwest::Client::new();
    let mut res = client
       .get("https://blog.x5ff.xyz/other/cookbook2018")
       .send()?;
    assert_eq!(200, res.status());
    assert_eq!("Rust is awesome\n", res.text()?);

    let form_values = vec![
        ("custname", "Rusty Crabbington"),
        ("comments", "Thank you"),
        ("custemail", "rusty@nope.com"),
        ("custtel", "+1 234 33456"),
        ("delivery", "25th floor below ground, no elevator. sorry"),
    ];
    let res_forms: serde_json::Value = client
       .post("https://httpbin.org/post")
       .form(&form_values)
       .send()?
       .json()?;
    for (name, value) in form_values.iter() {
        assert_eq!(res_forms["form"][name], *value);
    }

    let json_payload = json!({
        "book": "Rust 2018 Cookbook",
        "blog": "https://blog.x5ff.xyz",
    });
    let res_json: serde_json::Value = client
       .put("https://httpbin.org/anything")
       .json(&json_payload)
       .send()?
       .json()?;
    assert_eq!(res_json["json"], json_payload);

    let query_params = MyGetParams {
        a: 0x5ff,
        b: "https://blog.x5ff.xyz".into(),
    };
    let res_query: serde_json::Value = client
       .get("https://httpbin.org/get")
       .query(&query_params)
       .send()?
       .json()?;
    assert_eq!(res_query["args"]["a"], query_params.a.to_string());
    assert_eq!(res_query["args"]["b"], query_params.b);

    println!("> reqwest successful!");
    Ok(())
}

1.5 主函数和运行命令

src/main.rs 文件中添加主函数:

#[runtime::main]
async fn main() -> Result<(), Exception> {
    println!("Running some tests");
    test_reqwest()?;
    test_surf().await?;
    Ok(())
}

运行以下命令来测试网络请求:

cargo +nightly run

1.6 工作原理

surf reqwest 这两个库遵循了构建器模式和装饰器模式。 surf 是全异步的,使用 await 等待请求完成; reqwest 则使用 send() 方法执行请求。它们都可以方便地处理不同类型的请求,如GET、POST、PUT等,并能处理不同的数据格式,如表单数据和JSON数据。

以下是一个简单的流程图,展示了使用这两个库进行网络请求的基本流程:

graph TD;
    A[创建项目并添加依赖] --> B[导入依赖和定义数据结构];
    B --> C[使用surf进行请求];
    B --> D[使用reqwest进行请求];
    C --> E[主函数调用测试函数];
    D --> E;
    E --> F[运行测试命令];

2. 运行机器学习模型

机器学习,特别是深度学习,近年来发展迅速。虽然Rust在机器学习领域的库支持相对较少,但社区提供了与流行深度学习框架的绑定,让我们可以使用Rust进行一些实验。这里将使用Rust的PyTorch绑定 tch-rs 来训练一个模型,识别时尚MNIST数据集中的服装物品。

2.1 准备工作

在开始之前,需要对神经网络有一定的了解。如果不太熟悉,可以先参加一些在线课程。准备工作步骤如下:
1. 打开终端,创建一个新的Rust项目:

cargo new rusty-ml
cd rusty-ml
mkdir models
  1. 克隆或下载Zalando Research的时尚MNIST数据集:
git clone https://github.com/zalandoresearch/fashion-mnist
  1. 解压数据文件:
    • 在Linux/macOS上,进入 fashion-mnist/data/fashion 目录,运行 gunzip *.gz
    • 在Windows上,使用解压工具解压。

时尚MNIST数据集是原始MNIST数据集的替代品,包含10个类别的服装物品,任务是正确分类这些物品。

2.2 添加依赖和导入代码

打开 Cargo.toml 文件,添加以下依赖:

[dependencies]
tch = "0.1"
failure = "0.1"

src/main.rs 文件中添加导入代码:

use std::io::{Error, ErrorKind};
use std::path::Path;
use std::time::Instant;
use tch::{nn, nn::ModuleT, nn::OptimizerConfig, Device, Tensor};

2.3 定义卷积神经网络

定义一个卷积神经网络 ConvNet

#[derive(Debug)]
struct ConvNet {
    conv1: nn::Conv2D,
    conv2: nn::Conv2D,
    fc1: nn::Linear,
    fc2: nn::Linear,
}

impl ConvNet {
    fn new(vs: &nn::Path, labels: i64) -> ConvNet {
        ConvNet {
            conv1: nn::conv2d(vs, 1, 32, 5, Default::default()),
            conv2: nn::conv2d(vs, 32, 64, 5, Default::default()),
            fc1: nn::linear(vs, 1024, 512, Default::default()),
            fc2: nn::linear(vs, 512, labels, Default::default()),
        }
    }
}

impl nn::ModuleT for ConvNet {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
        xs.view([-1, 1, 28, 28])
           .apply(&self.conv1)
           .relu()
           .max_pool2d_default(2)
           .apply(&self.conv2)
           .relu()
           .max_pool2d_default(2)
           .view([-1, 1024])
           .apply(&self.fc1)
           .relu()
           .dropout_(0.5, train)
           .apply(&self.fc2)
    }
}

2.4 训练神经网络

添加训练函数 train_from_scratch

fn train_from_scratch(learning_rate: f64, batch_size: i64, epochs: usize) -> failure::Fallible<()> {
    let data_path = Path::new("fashion-mnist/data/fashion");
    let model_path = Path::new("models/best.ot");
    if !data_path.exists() {
        println!(
            "Data not found at '{}'. Did you run 'git submodule update --init'?",
            data_path.to_string_lossy()
        );
        return Err(Error::from(ErrorKind::NotFound).into());
    }
    println!("Loading data from '{}'", data_path.to_string_lossy());
    let m = tch::vision::mnist::load_dir(data_path)?;

    let vs = nn::VarStore::new(Device::cuda_if_available());
    let net = ConvNet::new(&vs.root(), 10);
    let opt = nn::Adam::default().build(&vs, learning_rate)?;
    println!(
        "Starting training, saving model to '{}'",
        model_path.to_string_lossy()
    );

    let mut min_loss = ::std::f32::INFINITY;
    for epoch in 1..=epochs {
        let start = Instant::now();
        let mut losses = vec![];
        for (image_batch, label_batch) in m.train_iter(batch_size).shuffle().to_device(vs.device()) {
            let loss = net
               .forward_t(&image_batch, true)
               .cross_entropy_for_logits(&label_batch);
            opt.backward_step(&loss);
            losses.push(f32::from(loss));
        }
        let total_loss = losses.iter().sum::<f32>() / (losses.len() as f32);

        let test_accuracy = net
           .forward_t(&m.test_images, false)
           .accuracy_for_logits(&m.test_labels);

        if total_loss <= min_loss {
            vs.save(model_path)?;
            min_loss = total_loss;
        }

        println!(
            "{:4} | train loss: {:7.4} | test acc: {:5.2}% | duration: {}s",
            epoch,
            &total_loss,
            100. * f64::from(&test_accuracy),
            start.elapsed().as_secs()
        );
    }
    println!(
        "Done! The best model was saved to '{}'",
        model_path.to_string_lossy()
    );
    Ok(())
}

2.5 进行推理

训练好模型后,可以进行推理。添加推理函数 predict_from_best

fn predict_from_best() -> failure::Fallible<()> {
    let data_path = Path::new("fashion-mnist/data/fashion");
    let model_weights_path = Path::new("models/best.ot");
    let m = tch::vision::mnist::load_dir(data_path)?;
    let mut vs = nn::VarStore::new(Device::cuda_if_available());
    let net = ConvNet::new(&vs.root(), 10);
    println!(
        "Loading model weights from '{}'",
        model_weights_path.to_string_lossy()
    );
    vs.load(model_weights_path)?;

    println!("Probabilities and predictions for 10 random images in the test set");
    for (image_batch, label_batch) in m.test_iter(1).shuffle().to_device(vs.device()).take(10) {
        let raw_tensor = net
           .forward_t(&image_batch, false)
           .softmax(-1)
           .view(m.labels);
        let predicted_index: Vec<i64> = raw_tensor.argmax(0, false).into();
        let probabilities: Vec<f64> = raw_tensor.into();
        print!("[ ");
        for p in probabilities {
            print!("{:.4} ", p);
        }
        let label: Vec<i64> = label_batch.into();
        println!("] predicted {}, was {}", predicted_index[0], label[0]);
    }
    Ok(())
}

2.6 主函数和运行命令

src/main.rs 文件中添加主函数:

fn main() -> failure::Fallible<()> {
    train_from_scratch(1e-2, 1024, 5)?;
    predict_from_best()?;
    Ok(())
}

运行以下命令来训练模型和进行推理:

cargo run

2.7 工作原理

训练过程包括加载数据、初始化模型和优化器、迭代训练数据计算损失并进行反向传播。推理过程则是加载训练好的模型权重,对测试数据进行预测。

以下是训练和推理的流程图:

graph TD;
    A[准备工作] --> B[添加依赖和导入代码];
    B --> C[定义卷积神经网络];
    C --> D[训练神经网络];
    D --> E[保存最佳模型];
    E --> F[进行推理];
    F --> G[输出预测结果];

通过以上步骤,我们可以使用Rust进行网络请求和机器学习模型的训练与推理。Rust的高性能和安全性使其在这些领域具有很大的潜力。

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符  | 博主筛选后可见
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值