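Deploying the lang-segment-anything (Language Segment-Anything) model on Ubuntu
I. Downloading the models and weight files
lang-segment-anything relies on SAM and GroundingDINO, and GroundingDINO in turn relies on bert-base-uncased. Downloading these at run time is slow and often fails, so it is best to download the files for all of these models to local disk beforehand.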
(1) lang-segment-anything: https://github.com/luca-medeiros/lang-segment-anything
The models and weights downloaded in the following steps can all be placed in the lang-segment-anything directory.
(2) Download the SAM model weights (vit-h SAM): https://github.com/facebookresearch/segment-anything?tab=readme-ov-file
(3) Download the GroundingDINO model (the repository itself) and its weights: https://github.com/IDEA-Research/GroundingDINO
(4) Download the bert-base-uncased model: https://huggingface.co/google-bert/bert-base-uncased/tree/main
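Before moving on, it can help to confirm that everything landed where the later steps expect it. The sketch below is only an illustration: the file names (`sam_vit_h_4b8939.pth`, `groundingdino_swint_ogc.pth`, and the contents of the `bert-base-uncased` folder) are assumptions based on the usual release names, and the helper `find_missing` is a hypothetical name, so adjust the list to whatever you actually downloaded.

```python
from pathlib import Path

# Assumed layout: everything sits under the lang-segment-anything checkout.
# The file names below are assumptions; change them to match your downloads.
EXPECTED_FILES = [
    "sam_vit_h_4b8939.pth",
    "groundingdino_swint_ogc.pth",
    "bert-base-uncased/config.json",
    "bert-base-uncased/vocab.txt",
]


def find_missing(root, expected=EXPECTED_FILES):
    """Return the entries of `expected` that are not files under `root`."""
    root = Path(root)
    return [rel for rel in expected if not (root / rel).is_file()]


if __name__ == "__main__":
    # Run from inside the lang-segment-anything directory.
    missing = find_missing(".")
    if missing:
        print("Missing files:")
        for rel in missing:
            print("  " + rel)
    else:
        print("All expected files are present.")
```

Running this from the lang-segment-anything directory lists whatever is still missing, which is cheaper than finding out through a failed download at run time.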
II. Environment setup
1. Conda environment setup
conda env create -f environment.yml
conda activate lsa
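Once the `lsa` environment is active, a quick import check can catch a broken install early. This is a minimal sketch: the package list is an assumption (the authoritative list is the repository's environment.yml), and `missing_packages` is a hypothetical helper name.

```python
import importlib.util

# Assumed package list; the authoritative one is environment.yml in the repo.
REQUIRED = ["torch", "torchvision", "transformers"]


def missing_packages(names=REQUIRED):
    """Return the top-level packages from `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Not importable in this environment: " + ", ".join(missing))
    else:
        print("All listed packages were found.")
```

`find_spec` only locates a package without importing it, so the check stays fast even for heavy libraries.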
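2. Docker installation (if this step reports errors, you can move on to the next section first and fix the configuration according to the errors that come up while debugging the code)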
docker build --tag lang-segment-anything:latest .
docker run --gpus all -it lang-segment-anything:latest
III. Modifying the relevant files and running
1. In pyproject.toml, comment out the line underlined below
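If you would rather script the change than edit the file by hand, something along these lines could work. It is a sketch under stated assumptions: the text does not say which line is meant, so the `needle` argument and the `comment_out` helper are hypothetical, and you have to supply a substring that identifies the line yourself.

```python
def comment_out(text, needle):
    """Prefix every uncommented line containing `needle` with '# '.

    TOML uses '#' for comments, so this disables the matching lines.
    """
    out = []
    for line in text.splitlines(keepends=True):
        if needle in line and not line.lstrip().startswith("#"):
            line = "# " + line
        out.append(line)
    return "".join(out)


def comment_out_in_file(path, needle):
    """Apply comment_out to a file in place."""
    with open(path, encoding="utf-8") as f:
        original = f.read()
    with open(path, "w", encoding="utf-8") as f:
        f.write(comment_out(original, needle))
```

For example, `comment_out_in_file("pyproject.toml", "some-dependency")` would disable every line mentioning the placeholder `some-dependency`. Lines that are already commented are left alone, so running it twice is harmless.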
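2. Modify the get_tokenlizer.py file in the groundingdino folder
Change bert-base-uncased to be loaded offline (this assumes the model has already been downloaded as described above). The code then always loads the same pretrained model, bert-base-uncased, so the model can be loaded and used even when no network is available.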
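The idea of offline loading can also be sketched on its own, separately from the GroundingDINO file. The helper names `resolve_local_model` and `load_tokenizer_offline` are hypothetical, and the assumption is that the bert-base-uncased folder sits under the directory passed as `root`. `local_files_only=True` and the two environment variables are standard Hugging Face options, but treat this as an illustration rather than the code used in this guide.

```python
import os
from pathlib import Path


def resolve_local_model(name="bert-base-uncased", root="."):
    """Return the absolute path of a locally downloaded model folder.

    Raises FileNotFoundError if the folder or its config.json is missing,
    so a bad path fails early instead of triggering a download attempt.
    """
    model_dir = (Path(root) / name).resolve()
    if not (model_dir / "config.json").is_file():
        raise FileNotFoundError("no config.json under {}".format(model_dir))
    return str(model_dir)


def load_tokenizer_offline(model_dir):
    """Load the tokenizer strictly from disk (requires transformers)."""
    # Ask the Hugging Face libraries not to touch the network. These
    # variables are read at import time, so they only help if transformers
    # has not been imported yet; local_files_only=True covers the other case.
    os.environ["HF_HUB_OFFLINE"] = "1"
    os.environ["TRANSFORMERS_OFFLINE"] = "1"
    from transformers import AutoTokenizer

    return AutoTokenizer.from_pretrained(model_dir, local_files_only=True)
```

Passing an absolute path avoids depending on the current working directory, which is the main fragility of a bare relative name such as 'bert-base-uncased'.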
The modified get_tokenlizer.py file is as follows:
from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
import os


def get_tokenlizer(text_encoder_type):
    if not isinstance(text_encoder_type, str):
        # print("text_encoder_type is not a str")
        if hasattr(text_encoder_type, "text_encoder_type"):
            text_encoder_type = text_encoder_type.text_encoder_type
        elif text_encoder_type.get("text_encoder_type", False):
            text_encoder_type = text_encoder_type.get("text_encoder_type")
        elif os.path.isdir(text_encoder_type) and os.path.exists(text_encoder_type):
            pass
        else:
            raise ValueError(
                "Unknown type of text_encoder_type: {}".format(type(text_encoder_type))
            )
    print("final text_encoder_type: {}".format(text_encoder_type))
    # tokenizer = AutoTokenizer.from_pretrained(text_encoder_type)
    # Always load bert-base-uncased. The name resolves to the downloaded
    # folder only when a directory called 'bert-base-uncased' exists in the
    # current working directory, so run from the folder that contains it.
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    return tokenizer
def get_pretrained_language_model(text_encoder_type):
    if text_encoder_type == "bert-base-uncased" or (os.path.isdir(text_encoder_type) and os.path.exists(text_encoder_type)):
        # return BertModel.from_pretrained(text_encoder_type)