目录
Mask IoU
def _safe_divide(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 where the denominator is 0.

    The zero denominator is swapped for 1 *before* dividing, so neither
    ``torch.where`` branch ever contains NaN/inf (keeps gradients clean too).
    """
    nonzero = denominator != 0
    safe_denominator = torch.where(nonzero, denominator, torch.ones_like(denominator))
    return torch.where(nonzero, numerator / safe_denominator, torch.zeros_like(numerator))


def calculate_iou(mask1, mask2):
    """Compute two overlap scores between a pair of masks.

    Both masks are cast to float32, so boolean, integer and float (soft)
    masks are accepted. The masks are assumed to have the same shape; all
    sums run over every element.

    Args:
        mask1: First mask tensor.
        mask2: Second mask tensor.

    Returns:
        A tuple ``(iou, iou1)`` of 0-d float32 tensors:

        * ``iou``  -- intersection over union,
          ``|A & B| / (|A| + |B| - |A & B|)``.
        * ``iou1`` -- intersection over the *smaller* mask,
          ``|A & B| / min(|A|, |B|)``; for binary masks this is 1.0 when the
          smaller (non-empty) mask lies entirely inside the other.

        A score whose denominator is zero (e.g. an empty mask) is reported
        as 0.0 instead of NaN.
    """
    # Float arithmetic so bool/int masks multiply and divide correctly.
    mask1 = mask1.to(torch.float32)
    mask2 = mask2.to(torch.float32)

    # Compute each area once; both scores reuse them.
    area1 = mask1.sum()
    area2 = mask2.sum()
    intersection = (mask1 * mask2).sum()
    union = area1 + area2 - intersection

    iou = _safe_divide(intersection, union)
    iou1 = _safe_divide(intersection, torch.minimum(area1, area2))
    return iou, iou1