Deep SORT Multi-Object Tracking: PyTorch Implementation of Feature Training


flyfish

Reference:
https://github.com/layumi/Person_reID_baseline_pytorch

The Market-1501 data must first be processed with prepare.py so that it matches the layout PyTorch expects for loading:
the script creates per-identity subfolders and separates the training folder from the test folder.
The MARS dataset can be used as-is. A sketch of this regrouping is shown below, followed by the full training script.
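For reference, the per-identity regrouping looks roughly like the sketch below: Market-1501 files are named personID_camSeq_frame.jpg, so the person ID is the first underscore-separated field, and torchvision's ImageFolder wants one subfolder per identity. The source and destination paths here are illustrative assumptions, not the exact output of the reference repository's prepare.py.

import os
import shutil

src = "/home/santiago/dataset/Market-1501-v15.09.15/bounding_box_train"  # raw training images
dst = "/home/santiago/dataset/Market-1501-v15.09.15/pytorch/train"       # one subfolder per person ID

for name in os.listdir(src):
    if not name.endswith(".jpg"):
        continue
    pid = name.split("_")[0]  # person ID, e.g. "0002"
    os.makedirs(os.path.join(dst, pid), exist_ok=True)
    shutil.copyfile(os.path.join(src, name), os.path.join(dst, pid, name))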

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed May 29 10:32:21 2019

@author: santiago
"""

import os
import time
import torch
import torchvision
from CNNArchitecture import Net


MARS = False  # two datasets are supported here: MARS and Market-1501

if MARS:  # MARS
    root = "/home/santiago/dataset/MARS-v160809/"
    train_dir = os.path.join(root, "bbox_train")
    test_dir = os.path.join(root, "bbox_test")  # images are 256 high, 128 wide
else:  # Market-1501 (after running prepare.py)
    root = "/home/santiago/dataset/Market-1501-v15.09.15/pytorch/"
    train_dir = os.path.join(root, "train")
    test_dir = os.path.join(root, "query")

# Data augmentation
# torchvision.transforms.Compose(transforms) is a container that chains
# multiple transforms together, including:
# "Compose", "ToTensor", "ToPILImage", "Normalize", "Resize",
# "Scale", "CenterCrop", "Pad", "Lambda", "RandomCrop",
# "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop",
# "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
# "ColorJitter", "RandomRotation", "Grayscale", "RandomGrayscale"
transform_train = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop((128,64),padding=4),
    torchvision.transforms.RandomHorizontalFlip(),  # flip horizontally with probability 0.5
    torchvision.transforms.ToTensor(),              # convert the image to a tensor PyTorch can process
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.Resize((128,64)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# The mean [0.485, 0.456, 0.406] and std [0.229, 0.224, 0.225] are the official
# ImageNet statistics used by the torchvision pretrained models; source:
# https://pytorch.org/docs/stable/torchvision/models.html
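# As a quick sanity check of what Normalize computes, out = (in - mean) / std per
# channel; an illustrative snippet (not part of the training pipeline):
#   x = torch.zeros(3, 128, 64)           # an all-black image tensor
#   y = transform_test.transforms[-1](x)  # apply only the Normalize step
#   print(y[0, 0, 0])                     # (0 - 0.485) / 0.229 = -2.1179...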

# The same pipeline can also be built incrementally:
# transformList = []
# transformList.append(torchvision.transforms.RandomHorizontalFlip())
# transformList.append(torchvision.transforms.ToTensor())
# transformList.append(torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
# transform_train = torchvision.transforms.Compose(transformList)

# Data loading: ImageFolder reads one class per subdirectory
trainloader = torch.utils.data.DataLoader(
        torchvision.datasets.ImageFolder(train_dir, transform=transform_train),
        batch_size=64, shuffle=True
    )
testloader = torch.utils.data.DataLoader(
        torchvision.datasets.ImageFolder(test_dir, transform=transform_test),
        batch_size=64, shuffle=False  # no need to shuffle the test set
    )

 
num_classes = len(trainloader.dataset.classes)
print(num_classes)  # 625 classes for MARS, 751 for Market-1501
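# ImageFolder assigns one class index per subfolder, which is why the
# per-identity folders from prepare.py matter; an illustrative peek
# (actual folder names depend on the dataset):
#   trainloader.dataset.class_to_idx  # e.g. {'0002': 0, '0007': 1, ...}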
device = "cpu"#根据使用GPU 还是CPU 操作进行变更

# Network definition
start_epoch = 0  # the checkpoint records the epoch, so an interrupted run can resume where it left off
net = Net(num_classes=num_classes)

net.to(device)
total_epoch = 10000  # total number of training epochs


best_acc = 0.

# Resume from a previously saved checkpoint if one exists
# (best_acc is initialized above so the value restored from the
# checkpoint is not overwritten afterwards)
if os.path.isfile("./checkpoint/ckpt.pytorch"):
    print('Loading model')
    checkpoint = torch.load("./checkpoint/ckpt.pytorch")
    net.load_state_dict(checkpoint['net_dict'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
else:
    print('new model')




# Loss and optimizer
# After the network is defined we still need a loss function and an optimizer
# for its parameters: SGD with momentum, and cross-entropy as the loss
criterion = torch.nn.CrossEntropyLoss()
learning_rate = 0.1
optimizer = torch.optim.SGD(net.parameters(), learning_rate, momentum=0.9, weight_decay=5e-4)
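# CrossEntropyLoss combines LogSoftmax and NLLLoss; for logits z and target class t:
#   loss = -log( exp(z[t]) / sum_j exp(z[j]) )
# so the network should output raw scores; no softmax layer is needed before this loss.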

# train function for each epoch
def train(epoch):
    print("\nEpoch : %d"%(epoch+1))
    net.train()
    training_loss = 0.
    train_loss = 0.
    correct = 0
    total = 0
    interval = 20
    start = time.time()
    for idx, (inputs, labels) in enumerate(trainloader):
        # forward
        inputs,labels = inputs.to(device),labels.to(device)
        outputs = net(inputs)
        loss = criterion(outputs, labels)

        # backward
        # optimizer.zero_grad() resets the parameter gradients to zero; without it,
        # gradients accumulate across iterations and training does not converge.
        # loss.backward() traverses the computation graph and computes the gradient
        # at every node (the counterpart of the matrix computations in the network's
        # forward()).
        # optimizer.step() then updates each parameter using its computed gradient.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # accumulate statistics
        training_loss += loss.item()  # .item() extracts a plain float, so no graph (or GPU memory) is kept alive
        train_loss += loss.item()
        correct += outputs.max(dim=1)[1].eq(labels).sum().item()
        total += labels.size(0)

        # Print the loss periodically to confirm the model error keeps shrinking.
        # (In old PyTorch versions the loss was a Variable, placed in the
        # computation graph for forward/backward and autograd, and its value was
        # read with loss.data[0]; since PyTorch 0.4, loss.item() is the way.)
        if (idx+1)%interval == 0:
            end = time.time()
            print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format(
                100.*(idx+1)/len(trainloader), end-start, training_loss/interval, correct, total, 100.*correct/total
            ))
            training_loss = 0.
            start = time.time()
    
    return train_loss/len(trainloader), 1.- correct/total

def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0.
    correct = 0
    total = 0
    start = time.time()
    # No gradients are needed for evaluation; note that model.eval() (which puts
    # BatchNorm/Dropout into inference mode) is not the same as torch.no_grad()
    # (which disables autograd).
    with torch.no_grad():
        for idx, (inputs, labels) in enumerate(testloader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, labels)

            test_loss += loss.item()
            correct += outputs.max(dim=1)[1].eq(labels).sum().item()
            total += labels.size(0)
        
        print("Testing ...")
        end = time.time()
        print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format(
                100.*(idx+1)/len(testloader), end-start, test_loss/len(testloader), correct, total, 100.*correct/total
            ))

    # saving checkpoint
    acc = 100.*correct/total
    if acc >= best_acc:
        best_acc = acc
        print("Saving parameters to checkpoint/ckpt.pytorch")
        checkpoint = {
            'net_dict':net.state_dict(),
            'acc':acc,
            'epoch':epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(checkpoint, './checkpoint/ckpt.pytorch')

    return test_loss/len(testloader), 1.- correct/total




# Learning-rate decay, e.g.:
# - halve the learning rate every 20 epochs, or
# - multiply the learning rate by 0.1 every 20 epochs (the variant used here)
def learning_rate_decay():
    for params in optimizer.param_groups:
        params['lr'] *= 0.1
        lr = params['lr']
        print("Learning rate adjusted to {}".format(lr))

def main():
    for epoch in range(start_epoch, start_epoch+total_epoch):
        train_loss, train_err = train(epoch)
        test_loss, test_err = test(epoch)
        if (epoch+1)%20==0:
            learning_rate_decay()


if __name__ == '__main__':
    main()
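Once training has saved a checkpoint, it can be reloaded elsewhere using the same keys the script writes; a minimal sketch (num_classes must match the dataset the checkpoint was trained on):

checkpoint = torch.load("./checkpoint/ckpt.pytorch", map_location="cpu")
net = Net(num_classes=751)  # 751 for Market-1501, 625 for MARS (see the print above)
net.load_state_dict(checkpoint['net_dict'])
net.eval()
print("best acc: {:.3f}%, saved at epoch {}".format(checkpoint['acc'], checkpoint['epoch']))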

A typical training run looks like this:

Epoch : 4
[progress:10.5%]time:30.97s Loss:5.89092 Correct:71/1280 Acc:5.547%
[progress:20.9%]time:30.94s Loss:5.90198 Correct:129/2560 Acc:5.039%
[progress:31.4%]time:30.93s Loss:5.91590 Correct:178/3840 Acc:4.635%
[progress:41.9%]time:30.95s Loss:5.88206 Correct:235/5120 Acc:4.590%
[progress:52.4%]time:30.96s Loss:5.89019 Correct:300/6400 Acc:4.688%
[progress:62.8%]time:31.02s Loss:5.85304 Correct:370/7680 Acc:4.818%
[progress:73.3%]time:31.02s Loss:5.81869 Correct:442/8960 Acc:4.933%
[progress:83.8%]time:31.00s Loss:5.81698 Correct:514/10240 Acc:5.020%
[progress:94.2%]time:31.09s Loss:5.76503 Correct:596/11520 Acc:5.174%
Testing ...
[progress:100.0%]time:34.81s Loss:6.74831 Correct:6/3368 Acc:0.178%
Saving parameters to checkpoint/ckpt.pytorch