yolov3后处理,包括网络输出、阈值过滤、多类NMS

  • pytorch版本:1.7.1+cu101
from net.yolov3 import Yolo
from PIL import Image, ImageDraw
import numpy as np
import torch, copy, time

def sigmoid(x):
	temporary = 1 + torch.exp(-x)
	return 1.0 / temporary

def get_boxes(output, anchors):
	h = output.size(2)
	w = output.size(3)
	output = output.view(3, 85, h, w).permute(0,2,3,1)
	tc = torch.sigmoid(output[..., 4])    # 3*h*w
	cl = torch.sigmoid(output[..., 5:])   # 3*h*w*80
	clv, cli = torch.max(cl,-1)
	mask = tc * clv > 0.9
	# print(torch.where(mask))
	cli = cli[mask].unsqueeze(-1)
	# 3*h*w
	tx = torch.sigmoid(output[..., 0][mask])
	ty = torch.sigmoid(output[..., 1][mask])
	tw = torch.exp(output[..., 2][mask])
	th = torch.exp(output[..., 3][mask])
	# grid
	FloatTensor = torch.cuda.FloatTensor if tx.is_cuda else torch.FloatTensor
	grid_x, grid_y = torch.meshgrid(torch.linspace(0,w-1,w), torch.linspace(0,h-1,h))
	grid_x = grid_x.repeat(3,1,1)[mask].type(FloatTensor)
	grid_y = grid_y.repeat(3,1,1)[mask].type(FloatTensor)
	tx = ((tx+grid_y) / w).unsqueeze(-1)
	ty = ((ty+grid_x) / h).unsqueeze(-1)
	# anchor
	aw = torch.Tensor(anchors[0::2]).view(3,1).repeat(1,h*w).view(3,h,w)[mask].type(FloatTensor)
	ah = torch.Tensor(anchors[1::2]).view(3,1
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

刀么克瑟拉莫

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值