Neural Network Intelligence (NNI) | MNIST on the GPU in PyTorch: CNN and LSTM versions

This post covers using the NNI library together with PyTorch to build GPU-accelerated convolutional (CNN) and long short-term memory (LSTM) models for the MNIST dataset, and highlights the practical difficulties encountered while learning hands-on.


Environment:

conda install pytorch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 cudatoolkit=10.1 -c pytorch
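
Before training, it is worth confirming that this install actually exposes CUDA (a quick sanity check, not part of the original post):

import torch
print(torch.__version__)            # expect 1.7.0 for the install command above
print(torch.cuda.is_available())    # True means the cudatoolkit 10.1 build can see a GPU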

It always looks simple when reading someone else's code; only when you write it yourself do you find out what you actually don't know.


import os
import argparse
import sys
# import logging
# import nni
import torch
import time
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# from nni.utils import merge_parameter
from torchvision import datasets, transforms

class LSTMs(nn.Module):
    def __init__(self, hidden_dim, n_layer):
        super(LSTMs, self).__init__()
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        # treat each 28x28 image as a sequence of 28 rows, each row a 28-dim input vector
        self.lstm = nn.LSTM(28, hidden_dim, n_layer, batch_first=True)
        self.linear = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        # x: (batch, 28, 28) -- the channel dim is squeezed away in train()/test()
        out, _ = self.lstm(x)
        out = out[:, -1, :]          # keep only the last timestep's hidden state
        out = self.linear(out)       # project to the 10 MNIST classes
        return out

def train(args, net, device, train_loader, optimizer, epoch):
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (data, label) in enumerate(train_loader):
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        # MNIST batches arrive as (N, 1, 28, 28); squeezing away the channel dim
        # gives (N, 28, 28) = (batch, seq_len=rows, input_size=columns) for the LSTM
        data = data.squeeze()
        out = net(data)
        # be sure to write it this way: cross_entropy takes the raw logits and integer labels
        loss = torch.nn.functional.cross_entropy(out, label)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        predicted = torch.max(out, 1)[1].cpu().numpy()
        total += label.size(0)
        correct += float((predicted == label.cpu().numpy()).astype(int).sum())

        progress_bar(batch_idx, len(train_loader), 'Loss: %.4f | Acc: %.4f%% (%d/%d)'
            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))


def test(args, net, device, test_loader):
    net.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for index, (data, label) in enumerate(test_loader):
            data, label = data.to(device), label.to(device)
            data = data.squeeze()
            out = net(data)
            # test_loss += F.nll_loss(out, label, reduction='sum').item()
            test_loss += torch.nn.functional.cross_entropy(out, label, reduction='sum').item()
            pred = out.argmax(dim=1, keepdim=True)
            correct += pred.eq(label.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100.*correct/len(test_loader.dataset)
    print("Test: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
        test_loss, correct, len(test_loader.dataset), accuracy))

    return accuracy

# console progress-bar helpers
term_width = 200
try:
    term_width = os.get_terminal_size().columns
except Exception:
    pass
term_width = int(term_width)

def format_time(seconds):
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)
    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f

TOTAL_BAR_LENGTH = 95.
last_time = time.time()
begin_time = last_time

def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # reset timers for a new bar
    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
    # the original snippet is cut off here; a compact completion follows, drawing the
    # bar, the running total time and the caller's message, then rewinding with '\r'
    last_time = time.time()
    bar = '=' * cur_len + '>' + '.' * rest_len
    line = ' [%s] %d/%d | Tot: %s' % (bar, current + 1, total, format_time(last_time - begin_time))
    if msg:
        line += ' | ' + msg
    sys.stdout.write(line[:term_width - 1])
    sys.stdout.write('\r' if current < total - 1 else '\n')
    sys.stdout.flush()
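
The listing above stops before any driver code, so here is a minimal sketch of how the pieces could be wired together for the LSTM version. The main function, hyper-parameters and data-loading choices below are my own assumptions for illustration, not part of the original post:

def main():
    # hypothetical driver -- all defaults here are illustrative assumptions
    parser = argparse.ArgumentParser(description='MNIST LSTM example')
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--lr', type=float, default=0.01)
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.1307,), (0.3081,))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=transform),
        batch_size=args.batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=transform),
        batch_size=1000, shuffle=False)

    net = LSTMs(hidden_dim=128, n_layer=2).to(device)
    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9)

    for epoch in range(1, args.epochs + 1):
        train(args, net, device, train_loader, optimizer, epoch)
        test(args, net, device, test_loader)

if __name__ == '__main__':
    main()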
 