import chromadb
from chromadb.api.types import Documents,EmbeddingFunction, Embeddings
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(model_name_or_path=embedding_model_dir)
print('============>',model.encode('你好'))
class MyEmbeddingFunction(EmbeddingFunction):
def __call__(self, texts: Documents) -> Embeddings:
embeddings = [model.encode(x) for x in texts]
return embeddings
def validate_ids(ids):
if not all(isinstance(id, str) for id in ids):
raise ValueError("All IDs must be strings.")
return ids
valid_ids = validate_ids(collectionDbs['ids'])
valid_ids
# 初始化客户端
client = chromadb.PersistentClient(path="./vector_db")
collection = client.create_collection(
name="papers4",
embedding_function=MyEmbeddingFunction,
metadata={"hnsw:space": "cosine"} # 相似度计算方式
)
# 添加数据
collection.add(
documents=["3.1 实验设计...", "3.2 数据采集..." ], # 文本块
metadatas=[
{
'title': 'USR-DR801 使用手册',
'run_item': '场馆',
'keywords':'["量子计算", "密码学", "Shor算法"]',
'chapter': '1. 产品简介',
'page_range': '',
'source_file': ''
},
{
'title': 'USR-DR801 使用手册',
'run_item': '场馆',
'keywords': '["量子计算", "密码学", "Shor算法"]',
'chapter': '2. 产品操作入门',
'page_range': '',
'source_file': ''
}],
ids=['paper01_chunk018','paper01_chunk020'])
以上代码执行报错:Expected EmbeddingFunction.__call__ to have the following signature: odict_keys(['self', 'input']), got odict_keys(['self', 'args', 'kwargs'])