前言:
1-下载SAM算法文件(因为是github,所以需要外网)
重点:segment_anything调用时,需要torch,和torchvision库,这种库读者可以直接通过pip或者conda提前安装好。
2-文件安装到虚拟环境,并创建项目工程。
3-文件目录中重点文件解读
①:setup.py文件是segment_anything的安装文件,具体操作看前言2。
②:我自己本地下载的sam的已经训练好的官方的两个模型参数文件,但github源文件没有,可能是因为模型参数太大,所以没放。也可能是可以在线需要本地参数文件
内容:
一、蒙版mask + 原图image = 结果图imageResult
mask蒙版颜色博主这里设置为了绿色【0,255,0】
def show_mask(image, mask):
# mask 的形状是(1 X 图像高 X 图像宽)
color_ = [ 0, 255, 0] # 把mask做成 绿色蒙版
color = np.array(color_) # 把mask做成 绿色蒙版
h, w = mask.shape[-2:] # 从mask中 取出蒙版的 高 与 宽
imageResult = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) # 把mask蒙版形状改为 (图像高 X 图像宽 X 3)
# 便利原图像素点,将蒙版点更新到原图的
for i in range(0, h ):
for j in range(0, w ):
# 判断mask蒙版中的rbg是否为绿色
if color_[0] == imageResult[i][j][0] and color_[1] == imageResult[i][j][1] and color_[2] == imageResult[i][j][2]:
image[i][j] = color_ # 将image原图中对应mask蒙版的点,改为绿色
return image
二、给模型输入提示点input_point,给予提示信息
# input_point 要与 input_label 数量对应
input_point = np.array([[500,200], [840,210], [970,320], [904,178]]) # 给模型输入的提示点
input_label = np.array([1, 0, 0, 0]) # 1表示该坐标点为你想要分割出来的(增加);0表示该坐标点内的图像不是你想要的图像(抑制)
input_label表示的是,input_point给的提示信息就要找该点图像内容分割1,还是该点图像不是被分割的对象0。
举例:
如果你想分割出原图image的猫的话,你要给出图像中猫的一个或者几个坐标点,并且把该坐标点的label设置为1。如果分割出结果带有其他背景图像,你可以把对应的非猫的图像坐标点标记出来,label设置为0。
比如这里,图像原大小为 :(1044,598)
我提供的猫的像素点为:(814,280),(835,503),(854,162)
我提供的除猫外的不需要分割的背景图为:(186,318),(1005,384)
所以我这里的input_point 和input_label因为为:
# input_point 要与 input_label 数量对应
input_point = np.array([[814,280], [835,503], [854,162], [186,318], [1005,384]]) # 给模型输入的提示点
input_label = np.array([1, 1, 1, 0, 0]) # 1表示该坐标点为你想要分割出来的(增加);0表示该坐标点内的图像不是你想要的图像(抑制)
最后运行结果如下:
三、具体代码如下,每段代码我都写好了注释了。
import time
import cv2
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# segment_anything 是SAM算法的依赖库
from segment_anything import sam_model_registry, SamPredictor
# 屏蔽后台警告
import warnings
warnings.filterwarnings("ignore")
# 输入: 原图 + 蒙版 -》 输出: 图像最终结果(数组形式)
def show_mask(image, mask):
# mask 的形状是(1 X 图像高 X 图像宽)
color_ = [ 0, 255, 0] # 把mask做成 绿色蒙版
color = np.array(color_) # 把mask做成 绿色蒙版
h, w = mask.shape[-2:] # 从mask中 取出蒙版的 高 与 宽
imageResult = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) # 把mask蒙版形状改为 (图像高 X 图像宽 X 3)
# 便利原图像素点,将蒙版点更新到原图的
for i in range(0, h ):
for j in range(0, w ):
# 判断mask蒙版中的rbg是否为绿色
if color_[0] == imageResult[i][j][0] and color_[1] == imageResult[i][j][1] and color_[2] == imageResult[i][j][2]:
image[i][j] = color_ # 将image原图中对应mask蒙版的点,改为绿色
return image
sam_checkpoint = "sam_vit_h_4b8939.pth" # 定义模型路径
model_type = "vit_h" # 定义模型类型
device = "cpu" # "cpu" or "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device = device) # 定义模型参数
predictor = SamPredictor(sam) # 调用预测模型
start_time = time.time() # 记录代码运行时间:开始
photoPath = "TestDemo.png" # 读片位置
image = cv2.imread(photoPath) # opencv 方式读取图片
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 将图像从BGR颜色空间转换为RGB颜色空间
predictor.set_image(image) # 预测图像
# input_point 要与 input_label 数量对应
input_point = np.array([[814,280], [835,503], [854,162], [186,318], [1005,384]]) # 给模型输入的提示点
input_label = np.array([1, 1, 1, 0, 0]) # 1表示该坐标点为你想要分割出来的(增加);0表示该坐标点内的图像不是你想要的图像(抑制)
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=False, # 为True时会有三个mask结果,以及对应的mask得分;False时只有一个mask和其对应的一个得分
)
imageResult = show_mask(image, masks) # imageResult是最后的结果(图片高 X 图片款 X 3),
plt.imshow(imageResult) # 用matplotlib中plt的方法读取最后的结果,准备导出为图片
plt.imsave("xx.png", imageResult) # 用matplotlib中plt的方法,将结果存为 xx.png图片
end_time = time.time()# 记录代码运行时间:结束
execution_time = end_time - start_time # 计算运行时间
print("Socre: ", scores, " ",execution_time, "s") #后台输出结果
总结:
本次提供的是 点 方式的输入,从而得到分割结果。
Pycharm程序后台
①:main.py (新建项目带的,本次程序中没用到)。
②:photoTest.py(主要程序文件)。
③:sam_vit_h_4b8939.pth(SAM模型参数文件,但无此文件,程序运行时也可以通过联网的方式获取)。
④:TestDemo.png:(image原图)。
⑤:xx.png:(imageResult最终结果图)。