参考文章 https://www.jianshu.com/p/42939bf83b8a
import torch
import numpy as np
import matplotlib.pyplot as plt
def fast_hist(a, b, n):
"""
生成混淆矩阵
a 是形状为(HxW,)的预测值
b 是形状为(HxW,)的真实值
n 是类别数
"""
# 确保a和b在0~n-1的范围内,k是(HxW,)的True和False数列
a = a.numpy()
b = b.numpy()
k = (a >= 0) & (a < n)
# 横坐标是预测的类别,纵坐标是真实的类别
return np.bincount(a[k].astype(int) + n*b[k], minlength=n ** 2).reshape(n, n)
def per_class_iou(hist):
"""
hist传入混淆矩阵(n, n)
"""
# 因为下面有除法,防止分母为0的情况报错
np.seterr(divide="ignore", invalid="ignore")
# 交集:np.diag取hist的对角线元素
# 并集:hist.sum(1)和hist.sum(0)分别按两个维度相加,而对角线元素加了两次,因此减一次
iou = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))