Jetbot中基于PyTorch的避障训练及部署代码分析
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import cv2
import numpy as np
# Build the image dataset from the 'dataset' directory. ImageFolder derives
# each sample's class label from the sub-directory the image sits in (see the
# torchvision documentation), and runs every image through the preprocessing
# pipeline below each time it is loaded.
dataset = datasets.ImageFolder(
    'dataset',
    transforms.Compose([
        # Random brightness / contrast / saturation / hue jitter (0.1 each).
        # NOTE(review): being part of the shared pipeline, this augmentation
        # applies to every sample drawn from `dataset`.
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
        # Scale every image to a fixed 224x224 size.
        transforms.Resize((224, 224)),
        # Convert the PIL image to a float tensor.
        transforms.ToTensor(),
        # Per-channel mean / std normalisation (these values are the commonly
        # used ImageNet statistics).
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
)
# Split the dataset into a training subset and a held-out test subset.
# The test subset has a fixed size; everything else is used for training.
NUM_TEST_SAMPLES = 50

# Without this guard, `len(dataset) - NUM_TEST_SAMPLES` would be zero or
# negative for a small dataset, which is not a meaningful training split.
if len(dataset) <= NUM_TEST_SAMPLES:
    raise ValueError(
        f"dataset has {len(dataset)} samples; more than {NUM_TEST_SAMPLES} "
        "are required to hold out a test set"
    )

# NOTE(review): no generator/seed is passed here, so the random assignment
# presumably differs between runs unless a global seed is set elsewhere.
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [len(dataset) - NUM_TEST_SAMPLES, NUM_TEST_SAMPLES],
)
# Loader that serves the training subset in batches of 8 images.
# shuffle=True reshuffles the samples every epoch; num_workers=0 keeps data
# loading in the main process (no worker subprocesses).
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0,
)
test_loader = torch.utils.data