PyTorch训练脚本,测试GPU是否正确加载

        代码示例是一个简单的PyTorch训练脚本,它包含了命令行参数解析、设备选择(GPU或CPU)、模型定义、数据加载、训练循环等关键部分。

import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


# 假设的模型和数据(在实际应用中,您会有自己的模型和数据)
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

    # 生成一些随机数据作为示例


data = torch.randn(100, 10)
targets = torch.randn(100, 1)
dataset = TensorDataset(data, targets)
dataloader = DataLoader(dataset, batch_size=10)

# 解析命令行参数  
parser = argparse.ArgumentParser(description='Model training script')
parser.add_argument('--device', default='cuda:0' if torch.cuda.is_available() else 'cpu',
                    help='cuda device, i.e. cuda:0, cuda:1, ... or cpu. Default is cuda:0 if available.')
opt = parser.parse_args()


# 列出所有可用的 GPU 及其名称
def list_gpus():
    if torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        print(f"Number of available GPUs: {num_gpus}")
        for i in range(num_gpus):
            gpu_name = torch.cuda.get_device_name(i)
            print(f"GPU {i}: {gpu_name}")
    else:
        print("CUDA is not available. Only CPU will be used.")

    # 检查设备字符串并设置设备


# 注意:这里我们假设用户输入的设备字符串是有效的,如果不有效,PyTorch 会在后续操作中报错
device = torch.device(opt.device)

# 打印设备信息  
print(f'Using device: {device}')

# 列出所有 GPU(在训练之前)  
list_gpus()

# 初始化模型、损失函数和优化器  
model = SimpleModel().to(device)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练循环  
for epoch in range(10):  # 假设训练 10 个 epoch  
    model.train()  # 设置模型为训练模式  
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

# 训练完成后可以再次列出 GPU(可选)  
# list_gpus()  # 这通常不是必需的,因为 GPU 的状态在训练过程中不会改变(除非有其他进程在使用它们)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值