Dify 知识库操作源码解析

Dify 知识库操作源码解析

知识库操作

数据表
接口
接口
名字
功能描述
请求示例
源码位置
请求参数
POST
/datasets

创建空知识库

创建空知识库

curl --location --request POST 'http://127.0.0.1:5001/v1/datasets' \--header 'Authorization: Bearer {api_key}' \--header 'Content-Type: application/json' \--data-raw '{"name": "name"}'
api/controllers/console/datasets/datasets.py
详情

GET
/datasets
知识库列表

知识库列表
curl --location --request GET 'http://127.0.0.1:5001/v1/datasets?page=1&limit=20' \--header 'Authorization: Bearer {api_key}'
api/controllers/console/datasets/datasets.py

DELETE
/datasets/{dataset_id}
删除知识库

删除知识库

curl --location --request DELETE 'http://127.0.0.1:5001/console/api/datasets/{dataset_id}--header 'Authorization: Bearer {api_key}'
位置

POST /datasets/init
上传文件创建知识库
上传文件创建知识库

位置

创建
方式一:先创建空知识库,再上传文件

创建空知识库

代码流程
请添加图片描述

[!TIP]
主要流程就是收集参数 insert 数据表(dataset)中

DatasetListApi-post:处理参数,保存至数据库

请添加图片描述

create_empty_dataset:保存将创建的知识库方法
请添加图片描述

请求参数

# 知识库的详细信息
knowledge_base = {
    # 知识库的唯一ID
    "id": "cbd8a746-a9ab-4d79-8337-99d4ac989691",
    
    # 知识库的名称
    "name": "测试知识库",
    
    # 描述,这里没有提供
    "description": None,
    
    # 知识库的提供商
    "provider": "vendor",
    
    # 权限设置,只对我可见
    "permission": "only_me",
    
    # 数据源类型,未指定
    "data_source_type": None,
    
    # 索引技术,未指定
    "indexing_technique": None,
    
    # 关联的应用数量
    "app_count": 0,
    
    # 文档数量
    "document_count": 0,
    
    # 文档总词数
    "word_count": 0,
    
    # 创建者的唯一ID
    "created_by": "c17d706d-6418-4ca0-9ba5-34b43bb7e32c",
    
    # 创建时间的时间戳
    "created_at": 1719337063,
    
    # 最后更新者的唯一ID
    "updated_by": "c17d706d-6418-4ca0-9ba5-34b43bb7e32c",
    
    # 最后更新时间的时间戳
    "updated_at": 1719337063,
    
    # 嵌入模型信息,未指定
    "embedding_model": None,
    
    # 嵌入模型提供商,未指定
    "embedding_model_provider": None,
    
    # 是否有可用的嵌入模型,未指定
    "embedding_available": None,
    
    # 检索模型配置
    "retrieval_model_dict": {
        # 搜索方法,语义搜索
        "search_method": "semantic_search",
        
        # 是否启用重排序
        "reranking_enable": False,
        
        # 重排序模型信息
        "reranking_model": {
            # 重排序提供商名称
            "reranking_provider_name": "",
            
            # 重排序模型名称
            "reranking_model_name": ""
        },
        
        # 返回的顶部结果数量
        "top_k": 2,
        
        # 是否启用分数阈值
        "score_threshold_enabled": False,
        
        # 分数阈值,未指定
        "score_threshold": None
    },
    
    # 标签列表,目前为空
    "tags": []
}
方法二:上传文件直接创建默认知识库

逻辑流程图

请添加图片描述

上传文件接口

位置:api/controllers/console/datasets/file.py

请添加图片描述

upload_file 代码逻辑:校验文件类型最终保存在/api/storage/upload_files 文件夹中

def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile:
    # 获取文件名和扩展名
    filename = file.filename
    extension = file.filename.split('.')[-1]
    # 如果文件名过长,截断并保留扩展名
    if len(filename) > 200:
        filename = filename.split('.')[0][:200] + '.' + extension
    # 根据配置获取允许的文件类型
    etl_type = dify_config.ETL_TYPE
    allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if etl_type == 'Unstructured' \
        else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS
    # 检查文件类型是否在允许的范围内
    if extension.lower() not in allowed_extensions:
        raise UnsupportedFileTypeError()
    elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
        raise UnsupportedFileTypeError()

    #  # 读取文件内容
    file_content = file.read()

    # 获取文件大小
    file_size = len(file_content)
    # 设置文件大小限制,图片文件和非图片文件有不同的限制
    if extension.lower() in IMAGE_EXTENSIONS:
        file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024
    else:
        file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024
    # 检查文件大小是否超过限制
    if file_size > file_size_limit:
        message = f'File size exceeded. {file_size} > {file_size_limit}'
        raise FileTooLargeError(message)

    # 生成文件的唯一UUID
    file_uuid = str(uuid.uuid4())
    # 确定当前租户ID
    if isinstance(user, Account):
        current_tenant_id = user.current_tenant_id
    else:
        # end_user
        current_tenant_id = user.tenant_id
    # 构建文件在存储系统中的路径
    file_key = 'upload_files/' + current_tenant_id + '/' + file_uuid + '.' + extension

    #  # 保存文件到存储系统
    storage.save(file_key, file_content)

    #  # 创建UploadFile模型实例,并填充必要的字段
    upload_file = UploadFile(
        tenant_id=current_tenant_id,
        storage_type=dify_config.STORAGE_TYPE,
        key=file_key,
        name=filename,
        size=file_size,
        extension=extension,
        mime_type=file.mimetype,
        created_by_role=('account' if isinstance(user, Account) else 'end_user'),
        created_by=user.id,
        created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
        used=False,
        hash=hashlib.sha3_256(file_content).hexdigest()
    )

    db.session.add(upload_file)
    db.session.commit()

    return upload_file
知识库创建

DatasetInitApi-post:主要逻辑校验整理参数交给 DocumentService.save_document_without_dataset_id

代码地址:api/controllers/console/datasets/datasets_document.py

请添加图片描述

这些参数对应了前端创建知识库时的参数配置信息:

indexing_technique:索引方式,提供了两种方式:[ ‘high_quality’, ‘economy’ ],其中高质量会使用 Embedding 模型做向量索引

data_source:数据源,前端上传文档之后的文件 id,用于原始文档获取

process_rule:数据的前处理规则配置

doc_form:文档格式

doc_language:文档语言

retrieval_model:检索模型,用于设置检索时搜索方式、各种参数

请添加图片描述

DatasetInitApi 类是一个资源类,它继承自 Resource 类。在这个类中,定义了一个 post 方法,这个方法对应 HTTP 的 POST 请求。

post 方法的主要功能是初始化一个数据集。首先检查用户是否已经设置、登录并完成了初始化。然后,它会检查用户是否有足够的权限来创建一个新的向量空间。

post 方法中,首先通过 reqparse.RequestParser() 解析请求中的参数,包括索引技术(indexing_technique)、数据源(data_source)、处理规则(process_rule)、文档形式(doc_form)、文档语言(doc_language)和检索模型(retrieval_model)。

如果索引技术是’high_quality’,则会尝试获取默认的嵌入模型实例。如果获取失败,会抛出相应的错误。然后,它会验证请求参数是否有效。如果参数有效,它会调用 DocumentService.save_document_without_dataset_id 方法来创建一个新的数据集并在其中保存文档。

最后,它会返回一个包含新创建的数据集、文档和批次信息的响应。

  • save_document_without_dataset_id

1.验证租户信息 然后各种校验:是否开启账单、是否有上传文章限制、是否超过批量上传限制。接着判断索引技术模型 接着构造 Dataset 入库,这里的 Dataset 就是知识库的信息总览,包含了该知识库的 embedding 模型信息、检索模型等信息

@staticmethod
def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account):
    # 获取租户的特性配置
    features = FeatureService.get_features(current_user.current_tenant_id)

    # 如果账单功能已启用,则检查文档上传限制
    if features.billing.enabled:
        count = 0
        if document_data["data_source"]["type"] == "upload_file":
            upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
            count = len(upload_file_list)
        elif document_data["data_source"]["type"] == "notion_import":
            notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
            for notion_info in notion_info_list:
                count = count + len(notion_info['pages'])
        elif document_data["data_source"]["type"] == "website_crawl":
            website_info = document_data["data_source"]['info_list']['website_info_list']
            count = len(website_info['urls'])
        # 检查是否超过批量上传限制
        batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
        if count > batch_upload_limit:
            raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

        # 检查文档上传配额
        DocumentService.check_documents_upload_quota(count, features)
    # 初始化嵌入模型和数据集绑定信息
    embedding_model = None
    dataset_collection_binding_id = None
    retrieval_model = None
    # 如果索引技术要求高质量,则获取默认的嵌入模型实例
    if document_data['indexing_technique'] == 'high_quality':
        model_manager = ModelManager()
        embedding_model = model_manager.get_default_model_instance(
            tenant_id=current_user.current_tenant_id,
            model_type=ModelType.TEXT_EMBEDDING
        )
        # 获取与嵌入模型关联的数据集集合绑定
        dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
            embedding_model.provider,
            embedding_model.model
        )
        dataset_collection_binding_id = dataset_collection_binding.id
        # 设置检索模型,优先使用文档数据中的模型,否则使用默认模型
        if document_data.get('retrieval_model'):
            retrieval_model = document_data['retrieval_model']
        else:
            default_retrieval_model = {
                'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
                'reranking_enable': False,
                'reranking_model': {
                    'reranking_provider_name': '',
                    'reranking_model_name': ''
                },
                'top_k': 2,
                'score_threshold_enabled': False
            }
            retrieval_model = default_retrieval_model
    #  # 创建数据集实例
    dataset = Dataset(
        tenant_id=tenant_id,
        name='',
        data_source_type=document_data["data_source"]["type"],
        indexing_technique=document_data["indexing_technique"],
        created_by=account.id,
        embedding_model=embedding_model.model if embedding_model else None,
        embedding_model_provider=embedding_model.provider if embedding_model else None,
        collection_binding_id=dataset_collection_binding_id,
        retrieval_model=retrieval_model
    )
    # 将数据集添加到数据库会话并刷新以获取ID
    db.session.add(dataset)
    db.session.flush()
    #.....

2.调用 save_document_with_dataset_id 构造 document 以及启动异步处理任务

@staticmethod
def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account):
    #.......
    # 保存文档并获取文档列表、批次信息
    documents, batch = DocumentService.save_document_with_dataset_id(dataset, document_data, account)
    # 截断文档名称以适应显示
    cut_length = 18
    cut_name = documents[0].name[:cut_length]
    dataset.name = cut_name + '...'

    # 更新数据集描述
    dataset.description = 'useful for when you want to answer queries about the ' + documents[0].name
    db.session.commit()

    # 返回数据集、文档列表和批次信息
    return dataset, documents, batch

通过调试得到 document_data 一个示例数据如下所示:

{
        'indexing_technique': 'high_quality',
        'data_source': {
                'type': 'upload_file',
                'info_list': {
                        'data_source_type': 'upload_file',
                        'file_info_list': {
                                'file_ids': ['6f393937-d0ec-41b3-a6cb-56f38081eb94']
                        }
                }
        },
        'process_rule': {
                'rules': {},
                'mode': 'automatic'
        },
        'duplicate': True,
        'original_document_id': None,
        'doc_form': 'text_model',
        'doc_language': 'Chinese',
        'retrieval_model': {
                'search_method': 'semantic_search',
                'reranking_enable': False,
                'reranking_model': {
                        'reranking_provider_name': '',
                        'reranking_model_name': ''
                },
                'top_k': 2,
                'score_threshold_enabled': False,
                'score_threshold': None
        }
}

最后将 Dataset,document_data,account 传入 save_document_with_dataset_id 中创建数据集文件

save_document_with_dataset_id

_总结来说就是 保存文档到指定数据集 Document ,并处理各种数据源类型,包括上传文件、Notion 导入和网站抓取 _

1.构造 Document 入库,这里的 Document 是知识库中的单个文档(我们创建知识库的时候可以上传多个文档),也就是一个 Dataset 实际上包含了多个 Document,Document 的信息就包含了文档信息以及数据处理相关的参数

2.启动异步索引任务,针对知识库中的每个文档创建索引

@staticmethod
def save_document_with_dataset_id(
    dataset: Dataset, document_data: dict,
    account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
    created_from: str = 'web'
):
    _"""_
_       保存文档到指定数据集,并处理各种数据源类型,包括上传文件、Notion导入和网站抓取。_

_       :param dataset: 数据集对象。_
_       :param document_data: 包含文档数据和元数据的字典。_
_       :param account: 用户账户对象。_
_       :param dataset_process_rule: 可选的数据集处理规则对象。_
_       :param created_from: 创建文档的来源,例如 'web'。_
_       :return: 一个包含已保存文档的列表和批次标识符的元组。_
_       """_
_    _# 验证并处理文档数量限制
    features = FeatureService.get_features(current_user.current_tenant_id)

    if features.billing.enabled:
        if 'original_document_id' not in document_data or not document_data['original_document_id']:
            count = 0
            if document_data["data_source"]["type"] == "upload_file":
                upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
                count = len(upload_file_list)
            elif document_data["data_source"]["type"] == "notion_import":
                notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
                for notion_info in notion_info_list:
                    count = count + len(notion_info['pages'])
            elif document_data["data_source"]["type"] == "website_crawl":
                website_info = document_data["data_source"]['info_list']['website_info_list']
                count = len(website_info['urls'])
            batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
            if count > batch_upload_limit:
                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

            DocumentService.check_documents_upload_quota(count, features)

    #  # 更新数据集的数据源类型
    if not dataset.data_source_type:
        dataset.data_source_type = document_data["data_source"]["type"]
    # 设置索引技术
    if not dataset.indexing_technique:
        if 'indexing_technique' not in document_data \
            or document_data['indexing_technique'] not in Dataset.INDEXING_TECHNIQUE_LIST:
            raise ValueError("Indexing technique is required")

        dataset.indexing_technique = document_data["indexing_technique"]
        if document_data["indexing_technique"] == 'high_quality':
            # 获取模型管理器实例
            model_manager = ModelManager()
            # 获取默认的文本嵌入模型实例
            embedding_model = model_manager.get_default_model_instance(
                tenant_id=current_user.current_tenant_id,
                model_type=ModelType.TEXT_EMBEDDING
            )
            # 设置数据集的嵌入模型和提供商信息
            dataset.embedding_model = embedding_model.model
            dataset.embedding_model_provider = embedding_model.provider
            # 获取与嵌入模型绑定的数据集集合绑定
            dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                embedding_model.provider,
                embedding_model.model
            )
            # 设置数据集的集合绑定ID
            dataset.collection_binding_id = dataset_collection_binding.id
            # 如果数据集的检索模型未设置,则设置默认的检索模型配置
            if not dataset.retrieval_model:
                default_retrieval_model = {
                    'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
                    'reranking_enable': False,
                    'reranking_model': {
                        'reranking_provider_name': '',
                        'reranking_model_name': ''
                    },
                    'top_k': 2,
                    'score_threshold_enabled': False
                }
                # 使用文档数据中的检索模型配置覆盖默认配置,如果存在的话
                dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get(
                    'retrieval_model'
                ) else default_retrieval_model
    # 初始化文档列表和批次标识符
    documents = []
    batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))

    # 根据document_data是否存在"original_document_id"字段判断是否更新已有文档
    if document_data.get("original_document_id"):
        document = DocumentService.update_document_with_dataset_id(dataset, document_data, account)
        documents.append(document)
    else:
        #如果没有,则需要创建或者更新数据集规则
        if not dataset_process_rule:
            process_rule = document_data["process_rule"]

            # 根据process_rule的模式创建DatasetProcessRule实例
            if process_rule["mode"] == "custom":
                dataset_process_rule = DatasetProcessRule(
                    dataset_id=dataset.id,
                    mode=process_rule["mode"],
                    rules=json.dumps(process_rule["rules"]),
                    created_by=account.id
                )
            elif process_rule["mode"] == "automatic":
                dataset_process_rule = DatasetProcessRule(
                    dataset_id=dataset.id,
                    mode=process_rule["mode"],
                    rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
                    created_by=account.id
                )
            # 将新创建的DatasetProcessRule添加到数据库会话中
            db.session.add(dataset_process_rule)
            db.session.commit()
        # 获取数据集中文档的位置信息
        position = DocumentService.get_documents_position(dataset.id)
        document_ids = []
        duplicate_document_ids = []
        if document_data["data_source"]["type"] == "upload_file":
            # 获取文件ID列表
            upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
            for file_id in upload_file_list:
                # 查询文件信息
                file = db.session.query(UploadFile).filter(
                    UploadFile.tenant_id == dataset.tenant_id,
                    UploadFile.id == file_id
                ).first()

                # raise error if file not found
                if not file:
                    raise FileNotExistsError()
                # 文件名和数据源信息
                file_name = file.name
                data_source_info = {
                    "upload_file_id": file_id,
                }
                # # 检查是否允许导入重复文档
                if document_data.get('duplicate', False):
                    document = Document.query.filter_by(
                        dataset_id=dataset.id,
                        tenant_id=current_user.current_tenant_id,
                        data_source_type='upload_file',
                        enabled=True,
                        name=file_name
                    ).first()
                    if document:
                        # 更新现有文档
                        document.dataset_process_rule_id = dataset_process_rule.id
                        document.updated_at = datetime.datetime.utcnow()
                        document.created_from = created_from
                        document.doc_form = document_data['doc_form']
                        document.doc_language = document_data['doc_language']
                        document.data_source_info = json.dumps(data_source_info)
                        document.batch = batch
                        document.indexing_status = 'waiting'
                        db.session.add(document)
                        documents.append(document)
                        duplicate_document_ids.append(document.id)
                        continue
                # 创建新文档
                document = DocumentService.build_document(
                    dataset, dataset_process_rule.id,
                    document_data["data_source"]["type"],
                    document_data["doc_form"],
                    document_data["doc_language"],
                    data_source_info, created_from, position,
                    account, file_name, batch
                )
                db.session.add(document)
                db.session.flush()
                
                
                
                document_ids.append(document.id)
                documents.append(document)
                position += 1
        # 处理Notion导入数据源
        elif document_data["data_source"]["type"] == "notion_import":
            # 获取Notion信息列表
            notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
            # 初始化已存在的Notion页面ID列表
            exist_page_ids = []
            # 初始化已存在的文档字典,键为Notion页面ID,值为文档ID
            exist_document = {}
            # 查询已存在的Notion导入类型的文档
            documents = Document.query.filter_by(
                dataset_id=dataset.id,
                tenant_id=current_user.current_tenant_id,
                data_source_type='notion_import',
                enabled=True
            ).all()
            if documents:
                for document in documents:
                    # 解析数据源信息
                    data_source_info = json.loads(document.data_source_info)
                    exist_page_ids.append(data_source_info['notion_page_id'])
                    exist_document[data_source_info['notion_page_id']] = document.id
            # 遍历Notion信息列表
            for notion_info in notion_info_list:
                # 获取工作空间ID
                workspace_id = notion_info['workspace_id']
                # 查询数据源绑定信息
                data_source_binding = DataSourceOauthBinding.query.filter(
                    db.and_(
                        DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
                        DataSourceOauthBinding.provider == 'notion',
                        DataSourceOauthBinding.disabled == False,
                        DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                    )
                ).first()
                if not data_source_binding:
                    # 如果数据源绑定不存在,抛出错误
                    raise ValueError('Data source binding not found.')
                # 遍历Notion页面
                for page in notion_info['pages']:
                    # 如果页面ID不在已存在的页面ID列表中
                    if page['page_id'] not in exist_page_ids:
                        # 创建数据源信息字典
                        data_source_info = {
                            "notion_workspace_id": workspace_id,
                            "notion_page_id": page['page_id'],
                            "notion_page_icon": page['page_icon'],
                            "type": page['type']
                        }
                        # 创建新文档
                        document = DocumentService.build_document(
                            dataset, dataset_process_rule.id,
                            document_data["data_source"]["type"],
                            document_data["doc_form"],
                            document_data["doc_language"],
                            data_source_info, created_from, position,
                            account, page['page_name'], batch
                        )
                        db.session.add(document)
                        db.session.flush()
                        document_ids.append(document.id)
                        documents.append(document)
                        position += 1
                    else:
                        exist_document.pop(page['page_id'])
            # # 删除未被选择的文档
            if len(exist_document) > 0:
                clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
        # 处理网站抓取数据源
        elif document_data["data_source"]["type"] == "website_crawl":
            # 获取网站信息
            website_info = document_data["data_source"]['info_list']['website_info_list']
            urls = website_info['urls']
            # 遍历URL列表
            for url in urls:
                data_source_info = {
                    'url': url,
                    'provider': website_info['provider'],
                    'job_id': website_info['job_id'],
                    'only_main_content': website_info.get('only_main_content', False),
                    'mode': 'crawl',
                }
                if url.length > 255:
                    document_name = url[:200] + '...'
                else:
                    document_name = url
                # 创建新文档
                document = DocumentService.build_document(
                    dataset, dataset_process_rule.id,
                    document_data["data_source"]["type"],
                    document_data["doc_form"],
                    document_data["doc_language"],
                    data_source_info, created_from, position,
                    account, document_name, batch
                )
                db.session.add(document)
                db.session.flush()
                document_ids.append(document.id)
                documents.append(document)

                position += 1
        # 提交数据库会话,确保所有更改被持久化
        db.session.commit()

        #  # 触发异步任务,对新创建或更新的文档进行索引
        if document_ids:
            document_indexing_task.delay(dataset.id, document_ids)
        if duplicate_document_ids:
            duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)

    return documents, batch
RAG

duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)

启动异步索引任务,针对知识库中的每个文档创建索引

删除

DatasetApi-delete

位置:api/controllers/console/datasets/datasets.py

请添加图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值