1 SAM概念
###1.1 SAM定义
Segment Anything Model(SAM)是一种基于深度学习的图像分割模型,其主要特点包括:
- 高质量的图像分割:SAM可以从输入提示(如点、框、文字等)生成高质量的对象掩模,实现对图像中对象的精确分割。
- 零次学习性能:SAM已经在一个包含1100万张图像和110亿个掩模的数据集上进行了训练,具有强大的零次学习性能。这意味着它可以迁移到新的图像分布和任务中,即使在训练阶段没有见过的物体类别也能够进行分割。
- 网络结构:SAM采用了类似于U-Net的编码器-解码器结构。编码器部分由多个卷积层和池化层组成,用于提取图像特征;解码器部分则由多个反卷积层和上采样层组成,用于将特征图恢复到原始图像大小,并生成分割结果。
- 通用性:SAM足够通用,可以涵盖广泛的用例,具有强大的零样本迁移能力。它已经学会了关于物体的一般概念,可以为任何图像或视频中的任何物体生成掩码,甚至包括在训练过程中没有遇到过的物体和图像类型。
- 交互性和灵活性:SAM采用了可提示的方法,可以根据不同的数据进行训练,并且可以适应特定的任务。它支持灵活的提示,需要分摊实时计算掩码以允许交互使用,并且必须具有模糊性意识。
- 强大的泛化性:SAM在边缘检测、物体检测、显著物体识别、工业异常检测等下游任务上表现出很强的泛化性。
总的来说,Segment Anything Model(SAM)是一种强大且灵活的图像分割模型,可以在广泛的应用场景中实现高质量的图像分割。它的出现将大大推动计算机视觉领域的发展,为自动驾驶、医学图像分析、卫星遥感图像分析等领域提供更准确、更高效的解决方案。
1.2 SAM的kerasCV实现
Segment Anything Model(SAM)通过输入如点或框等提示,产生高质量的对象掩码,并且它可以用于生成图像中所有对象的掩码。该模型在包含1100万张图像和110亿个掩码的数据集上进行了训练,并在各种分割任务上表现出强大的零次学习性能。
在本指南中,我们将展示如何使用KerasCV实现Segment Anything Model,并展示TensorFlow和JAX在性能提升方面的强大功能。
首先,让我们为我们的演示获取所有依赖项和图像。
!pip install -Uq keras-cv
!pip install -Uq keras
!wget -q https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg
1.3 使用准备
import os
os.environ["KERAS_BACKEND"] = "jax"
import timeit
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras import ops
import keras_cv
2 使用SAM
2.1定义辅助函数
让我们定义一些辅助函数,用于可视化图像、提示以及分割结果。
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=375):
pos_points = coords[labels == 1]
neg_points = coords[labels == 0]
ax.scatter(
pos_points[:, 0],
pos_points[:, 1],
color="green",
marker="*",
s=marker_size,
edgecolor="white",
linewidth=1.25,
)
ax.scatter(
neg_points[:, 0],
neg_points[:, 1],
color="red",
marker="*",
s=marker_size,
edgecolor="white",
linewidth=1.25,
)
def show_box(box, ax):
box = box.reshape(-1)
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(
plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
)
def inference_resizing(image, pad=True):
# Compute Preprocess Shape
image = ops.cast(image, dtype="float32")
old_h, old_w = image.shape[0], image.shape[1]
scale = 1024 * 1.0