def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
labels=()):
"""Runs Non-Maximum Suppression (NMS) on inference results
Returns:
list of detections, one (n,6) tensor per image [xyxy, conf, cls]
"""
nc = prediction.shape[2] - 5 # 分类数
# 第四个值框置信度大于conf_thres的为True,否则为False
xc = prediction[..., 4] > conf_thres # candidates
# Settings
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
max_det = 300 # maximum number of detections per image
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
time_limit = 10.0 # seconds to quit after
redundant = True # require redundant detections
# 是否属于多分类问题
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
# 默认是关闭的,使用的话需要修改为True
merge = False # use merge-NMS
t = time.time()
# 创建一个存储容器
output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
# xi表示某张图片的索引
# x表示某张图片的tensor张量数据,x中包含数条预测框数据
for xi, x in enumerate(prediction): # image index, image inference
# xc[xi]:表示筛选某张图片框置信度大于阈值的所有数据,这里的表示形式是True或False,而不是数据
# x表示所有的candidates 为True的数据,这里表示筛选出所有置信度大于conf_thres的框的数据
x = x[xc[xi]] # confidence
# Cat apriori labels if autolabelling
# 好像没用到,暂时不管了
if labels and len(labels[xi]):
l = labels[xi]
v = torch.zeros((len(l), nc + 5), device=x.device)
v[:, :4] = l[:, 1:5] # box
v[:, 4] = 1.0 # conf
v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls
x = torch.cat((x, v), 0)
# If none remain process next image
if not x.shape[0]:
continue
# Compute conf
# 见2讲解
# 计算最后六个分类置信度信息,将他们分别乘上第四个数据(框置信度)
x<
YOLOv5——NMS
最新推荐文章于 2025-06-12 14:11:13 发布