用训练得到的模型预测-单张图片
载入配置文件和模型
import numpy as np
import matplotlib.pyplot as plt
from mmseg.apis import init_model, inference_model, show_result_pyplot
import mmcv
import cv2
from mmengine import Config
# Load the training config and the trained checkpoint, then build the
# inference model on the first GPU.
cfg = Config.fromfile('../Zihao-Configs/ZihaoDataset_Segformer_20230712.py')
checkpoint_path = '../work_dirs/ZihaoDataset-Segformer/iter_40000.pth'
model = init_model(cfg, checkpoint_path, 'cuda:0')  # NOTE: requires CUDA device 0
载入图片进行语义分割
# Run inference on a single validation image and save the raw class-index mask.
img_path = '../Watermelon87_Semantic_Seg_Mask/img_dir/val/01bd15599c606aa801201794e1fa30.jpg'
img_bgr = cv2.imread(img_path)  # OpenCV loads images as BGR
result = inference_model(model, img_bgr)
# (H, W) array of per-pixel predicted class indices.
pred_mask = result.pred_sem_seg.data[0].cpu().numpy()
plt.figure(figsize=(8, 8))
plt.imshow(pred_mask)
plt.savefig('../outputs/K1-0.jpg')
分割结果

设置高亮区域透明度
# Overlay the semi-transparent predicted mask on the original image.
plt.figure(figsize=(10, 8))
plt.imshow(img_bgr[:,:,::-1])  # BGR -> RGB for matplotlib
plt.imshow(pred_mask, alpha=0.55)  # mask drawn on top with transparency
plt.axis('off')
plt.savefig('../outputs/K1-1.jpg')
plt.show()

将分割结果与原图并排显示
# Original image (left) and image-with-mask overlay (right), side by side.
plt.figure(figsize=(14, 8))
plt.subplot(1,2,1)
plt.imshow(img_bgr[:,:,::-1])  # BGR -> RGB for matplotlib
plt.axis('off')
plt.subplot(1,2,2)
plt.imshow(img_bgr[:,:,::-1])
plt.imshow(pred_mask, alpha=0.6)  # semi-transparent mask overlay
plt.axis('off')
plt.savefig('../outputs/K1-2.jpg')

按配色方案叠加在原图上显示
# Visualization color scheme: one (name, BGR color) pair per class index.
palette = [
    ['background', [127,127,127]],
    ['red', [0,0,200]],
    ['green', [0,200,0]],
    ['white', [144,238,144]],
    ['seed-black', [30,30,30]],
    ['seed-white', [8,189,251]]
]
# Map class index -> BGR color, in palette order.
palette_dict = {class_idx: bgr for class_idx, (_, bgr) in enumerate(palette)}
opacity = 0.3  # weight of the original image in the blend
# Paint every pixel of each predicted class with that class's BGR color.
pred_mask_bgr = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3))
for class_idx, bgr in palette_dict.items():
    pred_mask_bgr[pred_mask == class_idx] = bgr
pred_mask_bgr = pred_mask_bgr.astype('uint8')
# Blend: 30% original image + 70% colored mask.
pred_viz = cv2.addWeighted(img_bgr, opacity, pred_mask_bgr, 1 - opacity, 0)
cv2.imwrite('../outputs/K1-3.jpg', pred_viz)

加上图例
# Overlay the prediction on the image with a class-name legend, using the
# dataset's own class names and palette.
from mmseg.datasets import ZihaoDataset
import numpy as np
import mmcv
from PIL import Image
classes = ZihaoDataset.METAINFO['classes']
palette = ZihaoDataset.METAINFO['palette']
opacity = 0.15  # weight of the original image in the blend
seg_map = pred_mask.astype('uint8')
# Render the index mask through the dataset palette ('P' = 8-bit paletted image).
seg_img = Image.fromarray(seg_map).convert('P')
seg_img.putpalette(np.array(palette, dtype=np.uint8))
from matplotlib import pyplot as plt
import matplotlib.patches as mpatches
plt.figure(figsize=(14, 8))
# NOTE(review): seg_img.convert('RGB') is RGB but mmcv.imread returns BGR by
# default, so the two terms mix channel orders — confirm this is intended.
img_plot = ((np.array(seg_img.convert('RGB')))*(1-opacity) + mmcv.imread(img_path)*opacity) / 255
im = plt.imshow(img_plot)
# One legend patch per class, colored from the dataset palette (scaled to 0-1).
patches = [mpatches.Patch(color=np.array(palette[i])/255., label=classes[i]) for i in range(len(classes))]
plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., fontsize='large')
plt.savefig('../outputs/K1-6.jpg')

绘制混淆矩阵
# Build a confusion matrix between the ground-truth mask and the prediction.
label_path = '../Watermelon87_Semantic_Seg_Mask/ann_dir/val/01bd15599c606aa801201794e1fa30.png'
label = cv2.imread(label_path)
# cv2.imread returns 3 identical channels for a grayscale PNG; keep one plane.
label_mask = label[:,:,0]
from sklearn.metrics import confusion_matrix
# Flatten both masks so each pixel counts as one sample.
confusion_matrix_model = confusion_matrix(label_mask.flatten(), pred_mask.flatten())
import itertools
def cnf_matrix_plotter(cm, classes, cmap=plt.cm.Blues):
    """Draw a confusion-matrix heatmap with per-cell counts and save it.

    Args:
        cm: square confusion matrix (rows = true labels, cols = predictions).
        classes: class names used as axis tick labels.
        cmap: matplotlib colormap for the heatmap background.
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    ticks = np.arange(len(classes))
    plt.title('Confusion Matrix', fontsize=30)
    plt.xlabel('Pred', fontsize=25, c='r')
    plt.ylabel('True', fontsize=25, c='r')
    plt.tick_params(labelsize=16)
    plt.xticks(ticks, classes, rotation=90)
    plt.yticks(ticks, classes)
    # Cells darker than half the max count get white text for contrast.
    thresh = cm.max() / 2.
    for row in range(cm.shape[0]):
        for col in range(cm.shape[1]):
            plt.text(col, row, cm[row, col],
                     horizontalalignment="center",
                     color="white" if cm[row, col] > thresh else "black",
                     fontsize=12)
    plt.tight_layout()
    plt.savefig('../outputs/K1-混淆矩阵.jpg', dpi=300)
    plt.show()

cnf_matrix_plotter(confusion_matrix_model, classes, cmap='Blues')

用训练得到的模型预测-视频
import numpy as np
import matplotlib.pyplot as plt
from mmseg.apis import init_model, inference_model, show_result_pyplot
import mmcv
import cv2
import tqdm
import time
from mmengine import Config
# Reload the config and checkpoint for the video-inference section and build
# the model on the first GPU.
cfg = Config.fromfile('../Zihao-Configs/ZihaoDataset_Segformer_20230712.py')
checkpoint_path = '../work_dirs/ZihaoDataset-Segformer/iter_40000.pth'
model = init_model(cfg, checkpoint_path, 'cuda:0')  # NOTE: requires CUDA device 0
# Visualization color scheme: one (name, BGR color) pair per class index.
palette = [
    ['background', [127,127,127]],
    ['red', [0,0,200]],
    ['green', [0,200,0]],
    ['white', [144,238,144]],
    ['seed-black', [30,30,30]],
    ['seed-white', [8,189,251]]
]
# Map class index -> BGR color, in palette order.
palette_dict = {class_idx: bgr for class_idx, (_, bgr) in enumerate(palette)}
opacity = 0.3  # weight of the original frame in the blend
def process_frame(img_bgr):
    """Segment one BGR frame and return it blended with the colored mask.

    Args:
        img_bgr: a single video frame as an OpenCV BGR uint8 array.

    Returns:
        BGR uint8 array: weighted blend of the original frame (weight
        ``opacity``) and the per-class colored mask (weight ``1 - opacity``).
    """
    # Fix: dropped unused `start_time = time.time()` — leftover timing code
    # whose value was never read.
    result = inference_model(model, img_bgr)
    pred_mask = result.pred_sem_seg.data[0].cpu().numpy()
    # Paint every pixel of each predicted class with that class's BGR color.
    pred_mask_bgr = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3))
    for idx, bgr in palette_dict.items():
        pred_mask_bgr[pred_mask == idx] = bgr
    pred_mask_bgr = pred_mask_bgr.astype('uint8')
    return cv2.addWeighted(img_bgr, opacity, pred_mask_bgr, 1 - opacity, 0)
def generate_video(input_path='videos/robot.mp4'):
    """Run per-frame segmentation over a video and write the result next to cwd.

    Args:
        input_path: path to the input video file.

    Side effects:
        Writes ``out-<basename>`` as an mp4v-encoded video and prints progress.
    """
    filehead = input_path.split('/')[-1]
    output_path = "out-" + filehead
    print('视频开始处理', input_path)
    # First pass: count frames (the counter increments once past the final
    # read, so the real frame total is frame_count - 1).
    cap = cv2.VideoCapture(input_path)
    frame_count = 0
    while cap.isOpened():
        success, frame = cap.read()
        frame_count += 1
        if not success:
            break
    cap.release()
    print('视频总帧数为', frame_count)
    # Second pass: process and write each frame.
    cap = cv2.VideoCapture(input_path)
    frame_size = (cap.get(cv2.CAP_PROP_FRAME_WIDTH), cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    fps = cap.get(cv2.CAP_PROP_FPS)
    out = cv2.VideoWriter(output_path, fourcc, fps, (int(frame_size[0]), int(frame_size[1])))
    # BUG FIX: the file does `import tqdm`, so the module itself is not
    # callable — the progress-bar class is tqdm.tqdm.
    with tqdm.tqdm(total=frame_count - 1) as pbar:
        try:
            while cap.isOpened():
                success, frame = cap.read()
                if not success:
                    break
                try:
                    frame = process_frame(frame)
                except Exception:
                    # Best-effort: keep the raw frame if segmentation fails.
                    pass
                out.write(frame)
                pbar.update(1)
        except Exception:
            # Narrowed from a bare `except:` so Ctrl-C still interrupts.
            print('中途中断')
    cv2.destroyAllWindows()
    out.release()
    cap.release()
    print('视频已保存', output_path)

generate_video(input_path='../data/video_watermelon_3.mov')
本地摄像头实时预测
载入模型并且设置帧处理函数
import numpy as np
import matplotlib.pyplot as plt
from mmseg.apis import init_model, inference_model, show_result_pyplot
import mmcv
import cv2
import tqdm
import time
from mmengine import Config
# Reload the config and checkpoint for the webcam section and build the model
# on the first GPU.
cfg = Config.fromfile('../Zihao-Configs/ZihaoDataset_Segformer_20230712.py')
checkpoint_path = '../work_dirs/ZihaoDataset-Segformer/iter_40000.pth'
model = init_model(cfg, checkpoint_path, 'cuda:0')  # NOTE: requires CUDA device 0
# Visualization color scheme: one (name, BGR color) pair per class index.
palette = [
    ['background', [127,127,127]],
    ['red', [0,0,200]],
    ['green', [0,200,0]],
    ['white', [144,238,144]],
    ['seed-black', [30,30,30]],
    ['seed-white', [8,189,251]]
]
# Map class index -> BGR color, in palette order.
palette_dict = {class_idx: bgr for class_idx, (_, bgr) in enumerate(palette)}
opacity = 0.3  # weight of the original frame in the blend
def process_frame(img_bgr):
    """Segment one BGR frame and return it blended with the colored mask.

    Args:
        img_bgr: a single camera frame as an OpenCV BGR uint8 array.

    Returns:
        BGR uint8 array: weighted blend of the original frame (weight
        ``opacity``) and the per-class colored mask (weight ``1 - opacity``).
    """
    result = inference_model(model, img_bgr)
    pred_mask = result.pred_sem_seg.data[0].cpu().numpy()
    # Paint every pixel of each predicted class with that class's BGR color.
    pred_mask_bgr = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3))
    for idx, bgr in palette_dict.items():
        pred_mask_bgr[pred_mask == idx] = bgr
    pred_mask_bgr = pred_mask_bgr.astype('uint8')
    pred_viz = cv2.addWeighted(img_bgr, opacity, pred_mask_bgr, 1 - opacity, 0)
    # BUG FIX: the original never returned, so the realtime loop received
    # None and cv2.imshow crashed on it.
    return pred_viz
单张摄像拍照预测
# Grab a single frame from the default camera (device 0), then release it.
cap = cv2.VideoCapture(0)
time.sleep(1)  # give the camera a moment to initialize before reading
success, frame = cap.read()
cap.release()
cv2.destroyAllWindows()
实时预测
# Realtime webcam segmentation: read frames in a loop, overlay the predicted
# mask via process_frame, and display until 'q' or Esc is pressed.
import cv2
import time

cap = cv2.VideoCapture(1)
# NOTE(review): opens device 1 then immediately re-opens device 0 — the
# effective capture device is 0; confirm the intended camera index.
cap.open(0)
while cap.isOpened():
    ok, frame_bgr = cap.read()
    if not ok:
        print('获取画面不成功,退出')
        break
    frame_bgr = process_frame(frame_bgr)
    cv2.imshow('my_window', frame_bgr)
    key = cv2.waitKey(60)  # ~60 ms between frames
    if key in (ord('q'), 27):  # 'q' or Esc quits
        break
cap.release()
cv2.destroyAllWindows()