dify 知识库构建和文本嵌入源码解读

缘起

我在本地构建知识库的时候发现嵌入一个1mb的文件居然要3个小时,嵌入期间cpu,gpu占用率都很少,本人并且在更换了ollama参数更换了模型之后都没有任何改善,在手动调用ollama的api接口发现,接口性能是有很大提升的,故现在怀疑的是dify本身运行较慢,但是网络上关于dify实现的文章很少,所以只能自己读代码了.

这个文章描述了一些增删改查的过程
https://blog.youkuaiyun.com/Python_cocola/article/details/140558589
本文会在这篇文章的之后描述dify具体怎么调度嵌入任务,和怎么执行嵌入过程的.

一 dify的任务调度方式

dify有一个api的服务,用来提供页面的后端服务,另外还有一个worker进程,使用celery框架实现.
文本的嵌入都是在worker进程中做的.worker获取任务的信息是通过redis来实现.

文本嵌入的任务是在 IndexingRunner这个类里面

在这里插入图片描述

二 runner的处理过程

runner是接收到嵌入任务的处理循环.runner的处理过程整体分为四个阶段
extract:把文件的内容提取到内存中
transform:
load_segments:
load:

三 Extract处理过程

在这里插入图片描述
dify支持三种不同的数据导入形式,分别是文本文件,Notion,和web站点.
在这里插入图片描述
针对本地文件的情况,extract函数直接抛给index_processor进行处理,

        if dataset_document.data_source_type == "upload_file":
            if not data_source_info or "upload_file_id" not in data_source_info:
                raise ValueError("no upload file found")

            file_detail = (
                db.session.query(UploadFile).filter(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
            )

            if file_detail:
                extract_setting = ExtractSetting(
                    datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form
                )
                text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])

index_processor是文件的配置类型,界面对应如下
在这里插入图片描述

可以看到根据不同的文件类型,选择不同的处理器
在这里插入图片描述
pdf的处理器可以看到就是加载文件内容到内存,这里面使用了yield,可以减少内存压力.
解析过程使用了pypdfium2框架

class PdfExtractor(BaseExtractor):
    """Load pdf files.


    Args:
        file_path: Path to the file to load.
    """

    def __init__(self, file_path: str, file_cache_key: Optional[str] = None):
        """Initialize with file path."""
        self._file_path = file_path
        self._file_cache_key = file_cache_key

    def extract(self) -> list[Document]:
        plaintext_file_exists = False
        if self._file_cache_key:
            try:
                text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
                plaintext_file_exists = True
                return [Document(page_content=text)]
            except FileNotFoundError:
                pass
        documents = list(self.load())
        text_list = []
        for document in documents:
            text_list.append(document.page_content)
        text = "\n\n".join(text_list)

        # save plaintext file for caching
        if not plaintext_file_exists and self._file_cache_key:
            storage.save(self._file_cache_key, text.encode("utf-8"))

        return documents

    def load(
        self,
    ) -> Iterator[Document]:
        """Lazy load given path as pages."""
        blob = Blob.from_path(self._file_path)
        yield from self.parse(blob)

    def parse(self, blob: Blob) -> Iterator[Document]:
        """Lazily parse the blob."""
        import pypdfium2  # type: ignore

        with blob.as_bytes_io() as file_path:
            pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
            try:
                for page_number, page in enumerate(pdf_reader):
                    text_page = page.get_textpage()
                    content = text_page.get_text_range()
                    text_page.close()
                    page.close()
                    metadata = {"source": blob.source, "page": page_number}
                    yield Document(page_content=content, metadata=metadata)
            finally:
                pdf_reader.close()

四Transform

调用 _transform 方法对提取的文档进行转换处理,并添加了性能监控:
indexing_running把具体实现交给index_processor中
可以看到首先创建一个分割器,参数包含最大token,每个分隔的重贴,分隔符等等,这些是在界面中配置的参数.
然后把extract的文件进行分隔.

 def transform(self, documents: list[Document], **kwargs) -> list[Document]:
        splitter = self._get_splitter(
            processing_rule_mode=process_rule.get("mode"),
            max_tokens=rules.segmentation.max_tokens,
            chunk_overlap=rules.segmentation.chunk_overlap,
            separator=rules.segmentation.separator,
            embedding_model_instance=kwargs.get("embedding_model_instance"),
        )
        all_documents = []
        for document in documents:
            # document clean
            document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule", {}))
            document.page_content = document_text
            # parse document to nodes
            document_nodes = splitter.split_documents([document])
            split_documents = []
            for document_node in document_nodes:
                if document_node.page_content.strip():
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(document_node.page_content)
                    if document_node.metadata is not None:
                        document_node.metadata["doc_id"] = doc_id
                        document_node.metadata["doc_hash"] = hash
                    # delete Splitter character
                    page_content = remove_leading_symbols(document_node.page_content).strip()
                    if len(page_content) > 0:
                        document_node.page_content = page_content
                        split_documents.append(document_node)
            all_documents.extend(split_documents)
        
        end_time = time.time()
        print(f"[ParagraphIndexProcessor] transform method executed in {end_time - start_time:.2f} seconds")
        return all_documents

在这里插入图片描述

五 Load Segments

代码如下:
在这里插入图片描述
这段代码使用doc_store 去了add_documents
剩下两个是更新数据库记录的状态到indexing

重点看下doc_store 的实现

doc_store主要是创建DocumentSegment数据库记录,然后存储到数据库中

六Load 过程

这是整个嵌入过程的重头戏

6.1 获取嵌入模型

获取llm的配置

        embedding_model_instance = None
        if dataset.indexing_technique == "high_quality":
            embedding_model_instance = self.model_manager.get_model_instance(
                tenant_id=dataset.tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model,
            )

6.2 创建关键词

开启一个线程,使用_process_keyword_index处理文档关键词

        if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
            # create keyword index
            create_keyword_thread = threading.Thread(
                target=self._process_keyword_index,
                args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),  # type: ignore
            )
            create_keyword_thread.start()

6.2.1 关键词处理:

看代码最主要是使用Keyword进行数据处理

    @staticmethod
    def _process_keyword_index(flask_app, dataset_id, document_id, documents):
        with flask_app.app_context():
            dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
            if not dataset:
                raise ValueError("no dataset found")
            keyword = Keyword(dataset)
            keyword.create(documents)
            if dataset.indexing_technique != "high_quality":
                document_ids = [document.metadata["doc_id"] for document in documents]
                db.session.query(DocumentSegment).filter(
                    DocumentSegment.document_id == document_id,
                    DocumentSegment.dataset_id == dataset_id,
                    DocumentSegment.index_node_id.in_(document_ids),
                    DocumentSegment.status == "indexing",
                ).update(
                    {
                        DocumentSegment.status: "completed",
                        DocumentSegment.enabled: True,
                        DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                    }
                )
                db.session.commit()

6.2.2 keyword处理

在这里插入图片描述
通过代码可以看到是借给jieba处理
在这里插入图片描述
extract_tags = tfidf = default_tfidf.extract_tags

看来使用了tfidf对文档进行词频分析

TF-IDF是一种用于衡量词语在文档集中重要性的统计方法,它结合了**词频(Term Frequency,TF)和逆文档频率(Inverse Document Frequency,IDF)**两个概念。 一个词的TF-IDF值越高,说明它在文档中出现的次数越多,同时在整个语料库中出现的次数越少,因此具有更强的区分文档的能力。
TF-IDF的组成部分
词频(Term Frequency, TF): 指某个词语在一篇特定文档中出现的次数。
逆文档频率(Inverse Document Frequency, IDF): 用来衡量一个词语的普遍性。 如果一个词语在越少的文档中出现,那么它的IDF值就越大,这说明它区分文档的能力越强。
IDF值计算的是词语在整个语料库中出现的频率的倒数。

这一步应该是对应了低质量的索引方式,只使用关键词进行文本的索引
在这里插入图片描述

6.3 创建嵌入向量

 if dataset.indexing_technique == "high_quality":
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = []

                # Distribute documents into multiple groups based on the hash values of page_content
                # This is done to prevent multiple threads from processing the same document,
                # Thereby avoiding potential database insertion deadlocks
                document_groups: list[list[Document]] = [[] for _ in range(max_workers)]
                for document in documents:
                    hash = helper.generate_text_hash(document.page_content)
                    group_index = int(hash, 16) % max_workers
                    document_groups[group_index].append(document)
                for chunk_documents in document_groups:
                    if len(chunk_documents) == 0:
                        continue
                    futures.append(
                        executor.submit(
                            self._process_chunk,
                            current_app._get_current_object(),  # type: ignore
                            index_processor,
                            chunk_documents,
                            dataset,
                            dataset_document,
                            embedding_model_instance,
                        )
                    )

                for future in futures:
                    tokens += future.result()

一大段代码其实就是开启一个线程,然后进行_process_chunk的处理

6.3.1_process_chunk处理

indexing_runner还是交给具体的index_processor处理:

先处理vector,再处理keywords

    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
        if dataset.indexing_technique == "high_quality":
            vector = Vector(dataset)
            vector.create(documents)
        if with_keywords:
            keywords_list = kwargs.get("keywords_list")
            keyword = Keyword(dataset)
            if keywords_list and len(keywords_list) > 0:
                keyword.add_texts(documents, keywords_list=keywords_list)
            else:
                keyword.add_texts(documents)

6.3.2 Vector初始化

vector的初始化函数,这里面主要初始化embedding的大语言模型调用方式,和 具体的存储向量数据库

    def __init__(self, dataset: Dataset, attributes: Optional[list] = None):
        if attributes is None:
            attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
        self._dataset = dataset
        self._embeddings = self._get_embeddings()
        self._attributes = attributes
        self._vector_processor = self._init_vector()
    def _get_embeddings(self) -> Embeddings:
        model_manager = ModelManager()

        embedding_model = model_manager.get_model_instance(
            tenant_id=self._dataset.tenant_id,
            provider=self._dataset.embedding_model_provider,
            model_type=ModelType.TEXT_EMBEDDING,
            model=self._dataset.embedding_model,
        )
        return CacheEmbedding(embedding_model)
    def _init_vector(self) -> BaseVector:
        vector_type = dify_config.VECTOR_STORE

        if self._dataset.index_struct_dict:
            vector_type = self._dataset.index_struct_dict["type"]
        else:
            if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
                whitelist = (
                    db.session.query(Whitelist)
                    .filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
                    .one_or_none()
                )
                if whitelist:
                    vector_type = VectorType.TIDB_ON_QDRANT

        if not vector_type:
            raise ValueError("Vector store must be specified.")

        vector_factory_cls = self.get_vector_factory(vector_type)
        return vector_factory_cls().init_vector(self._dataset, self._attributes, self._embeddings)
   @staticmethod
    def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
        match vector_type:
            case VectorType.CHROMA:
                from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory

                return ChromaVectorFactory
            case VectorType.MILVUS:
                from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory

                return MilvusVectorFactory
            case VectorType.MYSCALE:
                from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory

                return MyScaleVectorFactory

6.3.3 Vector的处理

    def create(self, texts: Optional[list] = None, **kwargs):
        if texts:
            embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
            self._vector_processor.create(texts=texts, embeddings=embeddings, **kwargs)

第一步使用大语言模型创建embedding,这里最终执行的类在CacheEmbedding中,主要获取text的嵌入值
第二部使用向量数据库处理,dify默认配置的是Weaviate数据库,所以使用的processor是WeaviateVector

6.3.4 WeaviateVector的处理

WeaviateVector

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        # create collection
        self._create_collection()
        # create vector
        self.add_texts(texts, embeddings)

第一步创建一个collection,可以理解为mysql中的表,然后增加一条记录

    def _create_collection(self):
        lock_name = "vector_indexing_lock_{}".format(self._collection_name)
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
            if redis_client.get(collection_exist_cache_key):
                return
            schema = self._default_schema(self._collection_name)
            if not self._client.schema.contains(schema):
                # create collection
                self._client.schema.create_class(schema)
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        uuids = self._get_uuids(documents)
        texts = [d.page_content for d in documents]
        metadatas = [d.metadata for d in documents]

        ids = []

        with self._client.batch as batch:
            for i, text in enumerate(texts):
                data_properties = {Field.TEXT_KEY.value: text}
                if metadatas is not None:
                    # metadata maybe None
                    for key, val in (metadatas[i] or {}).items():
                        data_properties[key] = self._json_serializable(val)
                start_time = time.time()
                batch.add_data_object(
                    data_object=data_properties,
                    class_name=self._collection_name,
                    uuid=uuids[i],
                    vector=embeddings[i] if embeddings else None,
                )
                end_time = time.time()
                execution_time = (end_time - start_time) * 1000  # 转换为毫秒
                #print(f"[Weaviate] Index###: {uuids[i]}: Time={execution_time:.2f}ms")
                # 打印到控制台
                ids.append(uuids[i])
        return ids

最后调用 _load 方法执行实际的索引创建:

6.3.5 关键字处理

这部分同样最终落入jieba.py处理, jieba使用tfidf方法处理关键字

def add_texts(self, texts: list[Document], **kwargs):
        lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
        with redis_client.lock(lock_name, timeout=600):
            keyword_table_handler = JiebaKeywordTableHandler()

            keyword_table = self._get_dataset_keyword_table()
            keywords_list = kwargs.get("keywords_list")
            for i in range(len(texts)):
                text = texts[i]
                if keywords_list:
                    keywords = keywords_list[i]
                    if not keywords:
                        keywords = keyword_table_handler.extract_keywords(
                            text.page_content, self._config.max_keywords_per_chunk
                        )
                else:
                    keywords = keyword_table_handler.extract_keywords(
                        text.page_content, self._config.max_keywords_per_chunk
                    )
                if text.metadata is not None:
                    self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
                    keyword_table = self._add_text_to_keyword_table(
                        keyword_table or {}, text.metadata["doc_id"], list(keywords)
                    )

            self._save_dataset_keyword_table(keyword_table)

index举止执行在BaseIndexProcessor,不过这是一个基类,实现类有ParagraphIndexProcessor,QAIndexProcessor,ParentChildIndexProcessor
选取一个举例说明

def create(self, texts: Optional[list] = None, **kwargs):
    if not texts:
        return
    
    try:
        # 向量计算
        embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
        # 向量存储
        self._vector_processor.create(texts=texts, embeddings=embeddings, **kwargs)
    except Exception as e:
        logging.error(f"Failed to create vector index: {str(e)}")
        # 可以考虑添加重试机制或更具体的异常处理
        raise
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值