Keras深度学习框架第二十三讲:在KerasCV中使用SAM进行任何图像分割

1 SAM概念

###1.1 SAM定义

Segment Anything Model(SAM)是一种基于深度学习的图像分割模型,其主要特点包括:

  • 高质量的图像分割:SAM可以从输入提示(如点、框、文字等)生成高质量的对象掩模,实现对图像中对象的精确分割。
  • 零次学习性能:SAM已经在一个包含1100万张图像和110亿个掩模的数据集上进行了训练,具有强大的零次学习性能。这意味着它可以迁移到新的图像分布和任务中,即使在训练阶段没有见过的物体类别也能够进行分割。
  • 网络结构:SAM采用了类似于U-Net的编码器-解码器结构。编码器部分由多个卷积层和池化层组成,用于提取图像特征;解码器部分则由多个反卷积层和上采样层组成,用于将特征图恢复到原始图像大小,并生成分割结果。
  • 通用性:SAM足够通用,可以涵盖广泛的用例,具有强大的零样本迁移能力。它已经学会了关于物体的一般概念,可以为任何图像或视频中的任何物体生成掩码,甚至包括在训练过程中没有遇到过的物体和图像类型。
  • 交互性和灵活性:SAM采用了可提示的方法,可以根据不同的数据进行训练,并且可以适应特定的任务。它支持灵活的提示,需要分摊实时计算掩码以允许交互使用,并且必须具有模糊性意识。
  • 强大的泛化性:SAM在边缘检测、物体检测、显著物体识别、工业异常检测等下游任务上表现出很强的泛化性。

总的来说,Segment Anything Model(SAM)是一种强大且灵活的图像分割模型,可以在广泛的应用场景中实现高质量的图像分割。它的出现将大大推动计算机视觉领域的发展,为自动驾驶、医学图像分析、卫星遥感图像分析等领域提供更准确、更高效的解决方案。

1.2 SAM的kerasCV实现

Segment Anything Model(SAM)通过输入如点或框等提示,产生高质量的对象掩码,并且它可以用于生成图像中所有对象的掩码。该模型在包含1100万张图像和110亿个掩码的数据集上进行了训练,并在各种分割任务上表现出强大的零次学习性能。

在本指南中,我们将展示如何使用KerasCV实现Segment Anything Model,并展示TensorFlow和JAX在性能提升方面的强大功能。

首先,让我们为我们的演示获取所有依赖项和图像。

!pip install -Uq keras-cv
!pip install -Uq keras
!wget -q https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg

1.3 使用准备

import os

os.environ["KERAS_BACKEND"] = "jax"

import timeit
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras import ops
import keras_cv

2 使用SAM

2.1定义辅助函数

让我们定义一些辅助函数,用于可视化图像、提示以及分割结果。

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(
        pos_points[:, 0],
        pos_points[:, 1],
        color="green",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )
    ax.scatter(
        neg_points[:, 0],
        neg_points[:, 1],
        color="red",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )


def show_box(box, ax):
    box = box.reshape(-1)
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(
        plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
    )


def inference_resizing(image, pad=True):
    # Compute Preprocess Shape
    image = ops.cast(image, dtype="float32")
    old_h, old_w = image.shape[0], image.shape[1]
    scale = 1024 * 1.0 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

MUKAMO

你的鼓励是我们创作最大的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值