参考了PyTorch官方文档和Ray Tune官方文档
1、HYPERPARAMETER TUNING WITH RAY TUNE
2、How to use Tune with PyTorch
以PyTorch中的CIFAR 10图片分类为例,示范如何将Ray Tune融入PyTorch模型训练过程中。
其中,要求我们对原PyTorch程序做一些小的修改,包括:
- 将数据加载和训练过程封装到函数中;
- 使一些网络参数可配置;
- 增加检查点(可选);
- 定义用于模型调参的搜索空间。
下面以示例代码解析的形式介绍Ray Tune具体如何操作:
from functools import partial
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
# 定义神经网络模型
class Net(nn.Module):
def __init__(self, l1=120, l2=84):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, l1) # 参数待指定
self.fc2 = nn.Linear(l1, l2) # 参数待指定
self.fc3 = nn.Linear(l2, 10) # 参数待指定
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 封装数据加载过程,传递全局数据路径,以保证不同实验间共享数据路径
def load_data(data_dir="/home/taoshouzheng/Local_Connection/Algorithms/ray/"):
transform = transforms.Compose([
transforms.ToTensor(),

最低0.47元/天 解锁文章
1136





