PyTorch + Ray Tune 调参

参考了PyTorch官方文档Ray Tune官方文档

1、HYPERPARAMETER TUNING WITH RAY TUNE

2、How to use Tune with PyTorch

以PyTorch中的CIFAR 10图片分类为例,示范如何将Ray Tune融入PyTorch模型训练过程中。

其中,要求我们对原PyTorch程序做一些小的修改,包括:

  • 将数据加载和训练过程封装到函数中;
  • 使一些网络参数可配置;
  • 增加检查点(可选);
  • 定义用于模型调参的搜索空间。

下面以示例代码解析的形式介绍Ray Tune具体如何操作:

from functools import partial
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms

from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler


# 定义神经网络模型
class Net(nn.Module):
    def __init__(self, l1=120, l2=84):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)        # 参数待指定
        self.fc2 = nn.Linear(l1, l2)        # 参数待指定
        self.fc3 = nn.Linear(l2, 10)        # 参数待指定

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# 封装数据加载过程,传递全局数据路径,以保证不同实验间共享数据路径
def load_data(data_dir="/home/taoshouzheng/Local_Connection/Algorithms/ray/"):
    transform = transforms.Compose([
        transforms.ToTensor(),
 
评论 10
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值