diff --git a/data_chain/apps/app.py b/data_chain/apps/app.py index d072df7d46df6cc1e7d0974c62b555d5f4d220a0..f8cd6fa6ec140263874bbb5f6cddb7e2ca35ddb9 100644 --- a/data_chain/apps/app.py +++ b/data_chain/apps/app.py @@ -103,15 +103,33 @@ app.add_exception_handler(HTTPException, http_exception_handler) app.add_exception_handler(StarletteHTTPException, starlette_http_exception_handler) app.add_exception_handler(Exception, general_exception_handler) - +import time @app.on_event("startup") async def startup_event(): + st = time.time() await configure() + en = time.time() + logging.info(f"[App] configure completed, time used: {en - st:.2f} seconds") + st = time.time() await add_acitons() + en = time.time() + logging.info(f"[App] add_actions completed, time used: {en - st:.2f} seconds") + st = time.time() await TaskQueueService.init_task_queue() + en = time.time() + logging.info(f"[App] init_task_queue completed, time used: {en - st:.2f} seconds") + st = time.time() await add_knowledge_base() + en = time.time() + logging.info(f"[App] add_knowledge_base completed, time used: {en - st:.2f} seconds") + st = time.time() await add_document_type() + en = time.time() + logging.info(f"[App] add_document_type completed, time used: {en - st:.2f} seconds") + st = time.time() await init_path() + en = time.time() + logging.info(f"[App] init_path completed, time used: {en - st:.2f} seconds") scheduler.add_job(TaskQueueService.handle_tasks, 'interval', seconds=5) scheduler.start() diff --git a/data_chain/apps/service/chunk_service.py b/data_chain/apps/service/chunk_service.py index f6c51d46281056ae642cce670d5a62298c3cacac..300843b4261c31819e2872dc2e8a3dc8fbf1696a 100644 --- a/data_chain/apps/service/chunk_service.py +++ b/data_chain/apps/service/chunk_service.py @@ -1,5 +1,6 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
-import aiofiles +import time +import asyncio from fastapi import APIRouter, Depends, Query, Body, File, UploadFile import uuid import traceback @@ -28,7 +29,7 @@ from data_chain.manager.chunk_manager import ChunkManager from data_chain.manager.role_manager import RoleManager from data_chain.manager.task_manager import TaskManager from data_chain.manager.task_report_manager import TaskReportManager -from data_chain.stores.database.database import DocumentEntity +from data_chain.stores.database.database import ChunkEntity from data_chain.stores.minio.minio import MinIO from data_chain.entities.enum import ParseMethod, DataSetStatus, DocumentStatus, TaskType from data_chain.entities.common import DOC_PATH_IN_OS, DOC_PATH_IN_MINIO, DEFAULT_KNOWLEDGE_BASE_ID, DEFAULT_DOC_TYPE_ID @@ -39,7 +40,9 @@ from data_chain.embedding.embedding import Embedding class ChunkService: + """Chunk Service""" + @staticmethod async def validate_user_action_to_chunk(user_sub: str, chunk_id: uuid.UUID, action: str) -> bool: """验证用户对分片的操作权限""" try: @@ -58,6 +61,7 @@ class ChunkService: logging.exception("[ChunkService] %s", err) raise e + @staticmethod async def list_chunks_by_document_id(req: ListChunkRequest) -> ListChunkMsg: """根据文档ID列出分片""" try: @@ -75,99 +79,159 @@ class ChunkService: logging.exception("[ChunkService] %s", err) raise e + @staticmethod + async def search_chunks_from_kb(user_sub: str, action: str, search_method: str, kb_id: uuid.UUID, query: str, top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = [], is_rerank: bool = False) -> list[ChunkEntity]: + """从知识库搜索分片""" + top_k_search = top_k + if is_rerank: + top_k_search = top_k * 3 + try: + chunk_entities = await BaseSearcher.search(search_method, kb_id, query, top_k_search, doc_ids, banned_ids) + except Exception as e: + err = f"搜索分片失败,error: {e}" + logging.exception("[ChunkService] %s", err) + return [] + return chunk_entities + + @staticmethod + async def rerank_chunks(chunk_entities: list[ChunkEntity], rerank_method: str, query: str, top_k: int) -> list[ChunkEntity]: + """对分片进行重排序""" + chunk_indexs = await BaseSearcher.rerank(chunk_entities, rerank_method, query) + chunk_entities = [chunk_entities[i] for i in chunk_indexs] + chunk_entities = chunk_entities[:top_k] + return chunk_entities + # 关联上下文 + + @staticmethod + async def relate_surrounding_chunks(chunk_entities: list[ChunkEntity], tokens_limit: int) -> list[ChunkEntity]: + """关联上下文到搜索分片结果中""" + chunk_ids = [chunk_entity.id for chunk_entity in chunk_entities] + tokens_limit_every_chunk = tokens_limit // len( + chunk_entities) if len(chunk_entities) > 0 else tokens_limit + leave_tokens = 0 + related_chunk_entities = [] + token_sum = 0 + for chunk_entity in chunk_entities: + token_sum += chunk_entity.tokens + for chunk_entity in chunk_entities: + leave_tokens = tokens_limit_every_chunk+leave_tokens + try: + sub_related_chunk_entities = await BaseSearcher.related_surround_chunk(chunk_entity, leave_tokens-chunk_entity.tokens, chunk_ids) + except Exception as e: + leave_tokens += tokens_limit_every_chunk + err = f"[ChunkService] 关联上下文失败,error: {e}" + logging.exception(err) + continue + for related_chunk_entity in sub_related_chunk_entities: + token_sum += related_chunk_entity.tokens + leave_tokens -= related_chunk_entity.tokens + if leave_tokens < 0: + leave_tokens = 0 + chunk_ids += [chunk_entity.id for chunk_entity in sub_related_chunk_entities] + related_chunk_entities += sub_related_chunk_entities + if token_sum >= tokens_limit: + break + return related_chunk_entities + 
# 补全文档信息 + + @staticmethod + async def enrich_doc_info_to_search_chunks(search_chunk_msg: SearchChunkMsg) -> None: + """补全文档信息到搜索分片结果中""" + doc_entities = await DocumentManager.list_document_by_doc_ids( + [doc_chunk.doc_id for doc_chunk in search_chunk_msg.doc_chunks]) + doc_map = {doc_entity.id: doc_entity for doc_entity in doc_entities} + for doc_chunk in search_chunk_msg.doc_chunks: + doc_entity = doc_map.get(doc_chunk.doc_id) + doc_chunk.doc_author = doc_entity.author_name if doc_entity else "" + doc_chunk.doc_created_at = doc_entity.created_time.strftime( + '%Y-%m-%d %H:%M') if doc_entity else "" + doc_chunk.doc_abstract = doc_entity.abstract if doc_entity else "" + doc_chunk.doc_extension = doc_entity.extension if doc_entity else "" + doc_chunk.doc_size = doc_entity.size if doc_entity else 0 + + @staticmethod async def search_chunks(user_sub: str, action: str, req: SearchChunkRequest) -> SearchChunkMsg: """根据查询条件搜索分片""" + search_chunk_msg = SearchChunkMsg(docChunks=[]) + kb_ids_after_validate = [] + for kb_id in req.kb_ids: + if kb_id == DEFAULT_KNOWLEDGE_BASE_ID or await KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, kb_id, action): + kb_ids_after_validate.append(kb_id) + else: + logging.error( + "[ChunkService] 用户没有权限访问该知识库,知识库ID: %s", str(kb_id)) + req.kb_ids = kb_ids_after_validate logging.error("[ChunkService] 搜索分片,查询条件: %s", req) chunk_entities = [] + search_tasks = [] + st = time.time() for kb_id in req.kb_ids: - try: - kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id) - if kb_entity is None: - err = f"知识库不存在,知识库ID: {kb_id}" - logging.warning("[ChunkService] %s", err) - continue - if kb_id != DEFAULT_KNOWLEDGE_BASE_ID and not await KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, kb_id, action): - err = f"用户没有权限访问该知识库,知识库ID: {kb_id}" - logging.warning("[ChunkService] %s", err) - continue - top_k = req.top_k - if req.is_rerank: - top_k = req.top_k*3 - sub_chunk_entities = await BaseSearcher.search(req.search_method.value, kb_id, req.query, top_k, req.doc_ids, req.banned_ids) - if req.is_rerank: - sub_chunk_indexs = await BaseSearcher.rerank(sub_chunk_entities, kb_entity.rerank_method, req.query) - sub_chunk_entities = [sub_chunk_entities[i] - for i in sub_chunk_indexs] - sub_chunk_entities = sub_chunk_entities[:req.top_k] - chunk_entities += sub_chunk_entities - except Exception as e: - err = f"[ChunkService] 搜索分片失败,error: {e}" - logging.exception(err) - return SearchChunkMsg(docChunks=[]) + search_task = ChunkService.search_chunks_from_kb( + user_sub, action, req.search_method, kb_id, req.query, req.top_k, req.doc_ids, req.banned_ids, req.is_rerank) + search_tasks.append(search_task) + search_results = await asyncio.gather(*search_tasks) + en = time.time() + if req.is_testing_scene: + search_chunk_msg.t_used_in_search = round(en - st, 3) + if req.is_rerank: + st = time.time() + for i in range(len(req.kb_ids)): + kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(req.kb_ids[i]) + search_results[i] = await ChunkService.rerank_chunks( + search_results[i], kb_entity.rerank_method, req.query, req.top_k) + chunk_entities += search_results[i] + en = time.time() + if req.is_testing_scene: + search_chunk_msg.t_used_in_rerank = round(en - st, 3) + else: + for result in search_results: + chunk_entities += result if len(chunk_entities) == 0: return SearchChunkMsg(docChunks=[]) if req.is_rerank: + st = time.time() chunk_indexs = await BaseSearcher.rerank(chunk_entities, None, req.query) chunk_entities = 
[chunk_entities[i] for i in chunk_indexs] + en = time.time() + if req.is_testing_scene: + search_chunk_msg.t_used_in_rerank = round( + (search_chunk_msg.t_used_in_rerank or 0) + (en - st), 3) chunk_entities = chunk_entities[:req.top_k] - chunk_ids = [chunk_entity.id for chunk_entity in chunk_entities] logging.error("[ChunkService] 搜索分片,查询结果数量: %s", len(chunk_entities)) if req.is_related_surrounding: # 关联上下文 - tokens_limit = req.tokens_limit - tokens_limit_every_chunk = tokens_limit // len( - chunk_entities) if len(chunk_entities) > 0 else tokens_limit - leave_tokens = 0 - related_chunk_entities = [] - token_sum = 0 - for chunk_entity in chunk_entities: - token_sum += chunk_entity.tokens - for chunk_entity in chunk_entities: - leave_tokens = tokens_limit_every_chunk+leave_tokens - try: - sub_related_chunk_entities = await BaseSearcher.related_surround_chunk(chunk_entity, leave_tokens-chunk_entity.tokens, chunk_ids) - except Exception as e: - leave_tokens += tokens_limit_every_chunk - err = f"[ChunkService] 关联上下文失败,error: {e}" - logging.exception(err) - continue - for related_chunk_entity in sub_related_chunk_entities: - token_sum += related_chunk_entity.tokens - leave_tokens -= related_chunk_entity.tokens - if leave_tokens < 0: - leave_tokens = 0 - chunk_ids += [chunk_entity.id for chunk_entity in sub_related_chunk_entities] - related_chunk_entities += sub_related_chunk_entities - if token_sum >= tokens_limit: - break - chunk_entities += related_chunk_entities - search_chunk_msg = SearchChunkMsg(docChunks=[]) + st = time.time() + chunk_entities_related = await ChunkService.relate_surrounding_chunks( + chunk_entities, req.tokens_limit) + chunk_entities += chunk_entities_related + en = time.time() + if req.is_testing_scene: + search_chunk_msg.t_used_in_surrounding_text_relation = round( + en - st, 3) if req.is_classify_by_doc: doc_chunks = await BaseSearcher.classify_by_doc_id(chunk_entities) - for doc_chunk in doc_chunks: - for chunk in doc_chunk.chunks: - if req.is_compress: - chunk.text = TokenTool.compress_tokens(chunk.text) - search_chunk_msg.doc_chunks.append(doc_chunk) + search_chunk_msg.doc_chunks = doc_chunks else: for chunk_entity in chunk_entities: chunk = await Convertor.convert_chunk_entity_to_chunk(chunk_entity) - if req.is_compress: - chunk.text = TokenTool.compress_tokens(chunk.text) dc = DocChunk(docId=chunk_entity.doc_id, docName=chunk_entity.doc_name, chunks=[chunk]) search_chunk_msg.doc_chunks.append(dc) - doc_entities = await DocumentManager.list_document_by_doc_ids( - [doc_chunk.doc_id for doc_chunk in search_chunk_msg.doc_chunks]) - doc_map = {doc_entity.id: doc_entity for doc_entity in doc_entities} - for doc_chunk in search_chunk_msg.doc_chunks: - doc_entity = doc_map.get(doc_chunk.doc_id) - doc_chunk.doc_author = doc_entity.author_name if doc_entity else "" - doc_chunk.doc_created_at = doc_entity.created_time.strftime( - '%Y-%m-%d %H:%M') if doc_entity else "" - doc_chunk.doc_abstract = doc_entity.abstract if doc_entity else "" - doc_chunk.doc_extension = doc_entity.extension if doc_entity else "" - doc_chunk.doc_size = doc_entity.size if doc_entity else 0 + if req.is_compress: + st = time.time() + for doc_chunk in search_chunk_msg.doc_chunks: + for chunk in doc_chunk.chunks: + chunk.text = await TokenTool.compress_tokens(chunk.text) + en = time.time() + if req.is_testing_scene: + search_chunk_msg.t_used_in_text_compression = round( + en - st, 3) + if req.is_testing_scene: + for doc_chunk in search_chunk_msg.doc_chunks: + for chunk in doc_chunk.chunks: + chunk.score 
= await TokenTool.cal_jac(req.query, chunk.text) + await ChunkService.enrich_doc_info_to_search_chunks(search_chunk_msg) + logging.error(f"{search_chunk_msg}") return search_chunk_msg async def update_chunk_by_id(chunk_id: uuid.UUID, req: UpdateChunkRequest) -> uuid.UUID: diff --git a/data_chain/apps/service/document_service.py b/data_chain/apps/service/document_service.py index 7ec05556bb5381615813ff6bc6829f8dc35ea9ec..8f95f1905fe433507b2c601b1a660b1333798089 100644 --- a/data_chain/apps/service/document_service.py +++ b/data_chain/apps/service/document_service.py @@ -1,4 +1,5 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +import time import aiofiles from fastapi import APIRouter, Depends, Query, Body, File, UploadFile import uuid @@ -67,7 +68,8 @@ class DocumentService: doc_ids = [doc_entity.id for doc_entity in doc_entities] task_entities = await TaskManager.list_current_tasks_by_op_ids(doc_ids) task_ids = [task_entity.id for task_entity in task_entities] - task_dict = {task_entity.op_id: task_entity for task_entity in task_entities} + task_dict = { + task_entity.op_id: task_entity for task_entity in task_entities} task_report_entities = await TaskReportManager.list_current_task_report_by_task_ids(task_ids) task_report_dict = {task_report_entity.task_id: task_report_entity for task_report_entity in task_report_entities} @@ -82,7 +84,8 @@ class DocumentService: task = await Convertor.convert_task_entity_to_task(task_entity, task_report) document.parse_task = task documents.append(document) - list_document_msg = ListDocumentMsg(total=total, documents=documents) + list_document_msg = ListDocumentMsg( + total=total, documents=documents) return list_document_msg except Exception as e: err = "列出文档失败" @@ -166,7 +169,8 @@ class DocumentService: continue doc_ids.append(doc_entity.id) task_entities = await TaskManager.list_current_tasks_by_op_ids(doc_ids) - task_dict = {task_entity.op_id: task_entity for task_entity in task_entities} + task_dict = { + task_entity.op_id: task_entity for task_entity in task_entities} doc_status_list = [] for doc_id in doc_ids: task_entity = task_dict.get(doc_id, None) @@ -313,6 +317,7 @@ class DocumentService: err = f"上传文档数量或大小超过限制, 知识库ID: {kb_id}, 上传文档数量: {doc_cnt}, 上传文档大小: {doc_sz}" logging.error("[DocumentService] %s", err) raise ValueError(err) + st = time.time() doc_entities = [] for doc in docs: try: @@ -356,7 +361,9 @@ class DocumentService: err = f"上传文档失败, 文档名: {doc.filename}, 错误信息: {e}" logging.error("[DocumentService] %s", err) continue + logging.error("[DocumentService] 上传文档总耗时: %.2f 秒", time.time() - st) index = 0 + st = time.time() while index < len(doc_entities): try: await DocumentManager.add_documents(doc_entities[index:index+1024]) @@ -365,10 +372,16 @@ class DocumentService: err = f"上传文档失败, 文档名: {doc_entity.name}, 错误信息: {e}" logging.error("[DocumentService] %s", err) continue + logging.error("[DocumentService] 入库文档总耗时: %.2f 秒", time.time() - st) + st = time.time() for doc_entity in doc_entities: await TaskQueueService.init_task(TaskType.DOC_PARSE.value, doc_entity.id) + logging.error("[DocumentService] 初始化任务总耗时: %.2f 秒", time.time() - st) doc_ids = [doc_entity.id for doc_entity in doc_entities] + st = time.time() await KnowledgeBaseManager.update_doc_cnt_and_doc_size(kb_id=kb_entity.id) + logging.error( + "[DocumentService] 更新知识库文档数量和大小总耗时: %.2f 秒", time.time() - st) return doc_ids @staticmethod @@ -455,7 +468,8 @@ class DocumentService: doc_entities = await DocumentManager.update_document_by_doc_ids( doc_ids, 
{"status": DocumentStatus.DELETED.value}) doc_ids = [doc_entity.id for doc_entity in doc_entities] - kb_ids = [doc_entity.kb_id for doc_entity in doc_entities if doc_entity.kb_id is not None] + kb_ids = [ + doc_entity.kb_id for doc_entity in doc_entities if doc_entity.kb_id is not None] kb_ids = list(set(kb_ids)) for kb_id in kb_ids: await KnowledgeBaseManager.update_doc_cnt_and_doc_size(kb_id=kb_id) diff --git a/data_chain/apps/service/task_queue_service.py b/data_chain/apps/service/task_queue_service.py index 042d0c3da6a77360958616a9faafb5bebc9b224e..259cbac6a7826776403f697ea682a44bfeaae248 100644 --- a/data_chain/apps/service/task_queue_service.py +++ b/data_chain/apps/service/task_queue_service.py @@ -16,23 +16,45 @@ class TaskQueueService: @staticmethod async def init_task_queue(): + import time + st = time.time() task_entities = await TaskManager.list_task_by_task_status(TaskStatus.PENDING.value) + en = time.time() + logging.info(f"[TaskQueueService] 获取待处理任务耗时 {en-st} 秒") + st = time.time() task_entities += await TaskManager.list_task_by_task_status(TaskStatus.RUNNING.value) + en = time.time() + logging.info(f"[TaskQueueService] 获取运行中任务耗时 {en-st} 秒") for task_entity in task_entities: + # 将所有任务取消 + try: if task_entity.status == TaskStatus.RUNNING.value: + st = time.time() flag = await BaseWorker.reinit(task_entity.id) + en = time.time() + logging.info(f"[TaskQueueService] 重新初始化任务耗时 {en-st} 秒") if flag: - task = TaskQueueEntity(id=task_entity.id, status=TaskStatus.PENDING.value) - await TaskQueueManager.update_task_by_id(task_entity.id, task) + st = time.time() + await TaskQueueManager.update_task_by_id(task_entity.id, TaskStatus.PENDING) + en = time.time() else: + st = time.time() await BaseWorker.stop(task_entity.id) await TaskQueueManager.delete_task_by_id(task_entity.id) + en = time.time() else: + st = time.time() task = await TaskQueueManager.get_task_by_id(task_entity.id) + en = time.time() + logging.info(f"[TaskQueueService] 获取任务耗时 {en-st} 秒") if task is None: - task = TaskQueueEntity(id=task_entity.id, status=TaskStatus.PENDING.value) + st = time.time() + task = TaskQueueEntity( + id=task_entity.id, status=TaskStatus.PENDING.value) await TaskQueueManager.add_task(task) + en = time.time() + logging.info(f"[TaskQueueService] 添加任务耗时 {en-st} 秒") except Exception as e: warning = f"[TaskQueueService] 初始化任务失败 {e}" logging.warning(warning) @@ -103,8 +125,7 @@ class TaskQueueService: await TaskQueueManager.delete_task_by_id(task.id) continue if flag: - task.status = TaskStatus.PENDING.value - await TaskQueueManager.update_task_by_id(task.id, task) + await TaskQueueManager.update_task_by_id(task.id, TaskStatus.PENDING) else: await TaskQueueManager.delete_task_by_id(task.id) diff --git a/data_chain/entities/request_data.py b/data_chain/entities/request_data.py index 3f1c12991e047d572f709f43c67f1b86b54de5e2..67a40ee8ba8440664a01dac582821f7088be840a 100644 --- a/data_chain/entities/request_data.py +++ b/data_chain/entities/request_data.py @@ -263,6 +263,8 @@ class SearchChunkRequest(BaseModel): default=False, description="是否压缩", alias="isCompress") tokens_limit: int = Field( default=8192, description="token限制", alias="tokensLimit") + is_testing_scene: Optional[bool] = Field( + default=False, description="是否是评测场景", alias="isTestingScene") class ListDatasetRequest(BaseModel): diff --git a/data_chain/entities/response_data.py b/data_chain/entities/response_data.py index cb636eedb43522fe2dee3981184cd4ca44a38c33..f8cdffe55ab175fae7240da0d6021ba47d8cbe15 100644 --- 
a/data_chain/entities/response_data.py +++ b/data_chain/entities/response_data.py @@ -349,6 +349,8 @@ class Chunk(BaseModel): chunk_type: ChunkType = Field(description="分片类型", alias="chunkType") text: str = Field(description="分片文本") enabled: bool = Field(description="分片是否启用", alias="enabled") + score: Optional[float] = Field( + default=None, description="分片得分", alias="score") class ListChunkMsg(BaseModel): @@ -392,6 +394,14 @@ class SearchChunkMsg(BaseModel): """Post /chunk/search 数据结构""" doc_chunks: list[DocChunk] = Field( default=[], description="文档分片列表", alias="docChunks") + t_used_in_search: Optional[float] = Field( + default=None, description="搜索中使用的时间", alias="tUsedInSearch") + t_used_in_rerank: Optional[float] = Field( + default=None, description="重排序中使用的时间", alias="tUsedInRerank") + t_used_in_surrounding_text_relation: Optional[float] = Field( + default=None, description="上下文关系中使用的时间", alias="tUsedInSurroundingTextRelation") + t_used_in_text_compression: Optional[float] = Field( + default=None, description="文本压缩中使用的时间", alias="tUsedInTextCompression") class SearchChunkResponse(ResponseData): diff --git a/data_chain/manager/chunk_manager.py b/data_chain/manager/chunk_manager.py index d0c9ac679ead62753d0e51477f4bc9efe06ad3c3..7f83ee4dc54ed55d058b87b9f61a2875b3be22d4 100644 --- a/data_chain/manager/chunk_manager.py +++ b/data_chain/manager/chunk_manager.py @@ -1,5 +1,5 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -from sqlalchemy import select, update, func, text, or_, and_, Float, literal_column +from sqlalchemy import select, update, func, text, or_, and_, bindparam from typing import List, Tuple, Dict, Optional import uuid from datetime import datetime @@ -197,7 +197,7 @@ class ChunkManager(): else: # PostgreSQL需要将向量转换为字符串格式 vector_param = str(vector) - + params = { "vector": vector_param, "kb_id": kb_id, @@ -246,7 +246,8 @@ class ChunkManager(): # 方案1:临时关闭全表扫描(优先推荐,简单有效,会话级生效,不影响其他查询) # 执行SET命令:关闭全表扫描后,数据库会优先选择可用索引(text_vector_index) await session.execute(text("SET enable_seqscan = off;")) - + # 增加searcher数量,提升向量检索性能(可选) + await session.execute(text("SET hnsw_ef_search = 1000;")) # 方案2(备选):使用OpenGauss查询计划hints(需确保数据库开启hints支持) # 在SELECT后添加 /*+ IndexScan(chunk text_vector_index) */ 强制索引扫描 # 若用方案2,需将下面SELECT行改为:SELECT /*+ IndexScan(chunk text_vector_index) */ @@ -259,12 +260,12 @@ class ChunkManager(): chunk.pre_id_in_parse_topology, chunk.parse_topology_type, chunk.global_offset, chunk.local_offset, chunk.enabled, chunk.status, chunk.created_time, chunk.updated_time, - chunk.text_vector <#> :vector AS similarity_score + chunk.text_vector <=> :vector AS similarity_score FROM chunk JOIN document ON document.id = chunk.doc_id WHERE {where_clause} - AND (chunk.text_vector <#> :vector) IS NOT NULL - AND (chunk.text_vector <#> :vector) = (chunk.text_vector <#> :vector) + AND (chunk.text_vector <=> :vector) IS NOT NULL + AND (chunk.text_vector <=> :vector) = (chunk.text_vector <=> :vector) ORDER BY similarity_score ASC NULLS LAST LIMIT :limit """ @@ -392,73 +393,42 @@ class ChunkManager(): @staticmethod async def get_top_k_chunk_by_kb_id_dynamic_weighted_keyword( - kb_id: uuid.UUID, keywords: List[str], weights: List[float], + kb_id: uuid.UUID, query: str, # 关键词列表改为单查询文本 top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = [], chunk_to_type: str = None, pre_ids: list[uuid.UUID] = None) -> List[ChunkEntity]: - """根据知识库ID和关键词权重查询文档解析结果(修复NoneType报错+强制索引)""" + """根据知识库ID和查询文本查询文档解析结果(使用BM25直接打分)""" try: st = datetime.now() async with 
await DataBase.get_session() as session: - # 1. 分词器选择(保留原逻辑) - kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id) - if kb_entity.tokenizer == Tokenizer.ZH.value: - tokenizer = 'chparser' if config['DATABASE_TYPE'].lower( - ) == 'opengauss' else 'zhparser' - elif kb_entity.tokenizer == Tokenizer.EN.value: - tokenizer = 'english' - else: - tokenizer = 'chparser' if config['DATABASE_TYPE'].lower( - ) == 'opengauss' else 'zhparser' - - # 2. 构建加权关键词CTE(保留原逻辑) - params = {} - values_clause = [] - for idx, (term, weight) in enumerate(zip(keywords, weights)): - params[f"term_{idx}"] = term - params[f"weight_{idx}"] = weight - values_clause.append( - f"(CAST(:term_{idx} AS TEXT), CAST(:weight_{idx} AS FLOAT8))") - values_text = f"(VALUES {', '.join(values_clause)}) AS t(term, weight)" - weighted_terms = ( - select( - literal_column("t.term").label("term"), - literal_column("t.weight").cast(Float).label("weight") - ) - .select_from(text(values_text)) - .cte("weighted_terms") - ) + # 1. 构建查询文本参数(单文本,无需CTE列表) + params = {"query": query} - # 3. 初始化查询(确保stmt始终是Select对象,不直接赋值None) + # 2. 初始化查询(直接使用查询文本计算BM25分数) + # 使用bindparam定义参数,避免混合使用占位符和美元符号引用 + query_param = bindparam("query") stmt = ( select( ChunkEntity, - func.sum( - func.ts_rank_cd(ChunkEntity.text_ts_vector, func.to_tsquery( - tokenizer, weighted_terms.c.term)) - * weighted_terms.c.weight - ).label("similarity_score") + # 计算查询文本与chunk的BM25分数 + ChunkEntity.text.op('<&>')(query_param).label("similarity_score") ) # 关联文档表 .join(DocumentEntity, DocumentEntity.id == ChunkEntity.doc_id) - .join( # 关联CTE+强制触发GIN索引(核心优化) - weighted_terms, - ChunkEntity.text_ts_vector.op( - '@@')(func.to_tsquery(tokenizer, weighted_terms.c.term)), - isouter=False - ) # 基础过滤条件 .where(DocumentEntity.enabled == True) .where(DocumentEntity.status != DocumentStatus.DELETED.value) .where(ChunkEntity.kb_id == kb_id) .where(ChunkEntity.enabled == True) .where(ChunkEntity.status != ChunkStatus.DELETED.value) + # 过滤BM25分数大于0的结果(确保有相关性) + .where(ChunkEntity.text.op('<&>')(query_param) > 0) ) - # 4. 动态条件:禁用ID(修复关键:用if-else确保stmt不被赋值为None) + # 3. 动态条件:禁用ID if banned_ids: stmt = stmt.where(ChunkEntity.id.notin_(banned_ids)) - # 5. 其他动态条件(同样用if-else确保链式调用不中断) + # 4. 其他动态条件 if doc_ids is not None: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) if chunk_to_type is not None: @@ -468,44 +438,31 @@ class ChunkManager(): stmt = stmt.where( ChunkEntity.pre_id_in_parse_topology.in_(pre_ids)) - # 6. 分组、过滤分数、排序、限制行数(链式调用安全) + # 5. 排序、限制(直接使用BM25分数排序) stmt = (stmt - .group_by(ChunkEntity.id) # 按chunk分组计算总权重 - .having( # 过滤总分数>0的结果 - func.sum( - func.ts_rank_cd(ChunkEntity.text_ts_vector, func.to_tsquery( - tokenizer, weighted_terms.c.term)) - * weighted_terms.c.weight - ) > 0 - ) - .order_by( # 按总分数降序 - func.sum( - func.ts_rank_cd(ChunkEntity.text_ts_vector, func.to_tsquery( - tokenizer, weighted_terms.c.term)) - * weighted_terms.c.weight - ).desc() + .order_by( + ChunkEntity.text.op('<&>')(query_param).desc() ) - .limit(top_k) # 限制返回数量 + .limit(top_k) ) - # 7. 执行查询与结果处理(保留原逻辑) + # 6. 执行查询与结果处理 result = await session.execute(stmt, params=params) chunk_entities = result.scalars().all() - # 8. 日志输出 + # 7. 日志输出 cost = (datetime.now() - st).total_seconds() logging.warning( - f"[ChunkManager] get_top_k_chunk_by_kb_id_dynamic_weighted_keyword cost: {cost}s " - f"| kb_id: {kb_id} | keywords: {keywords[:2]}... | match_count: {len(chunk_entities)}" + f"[ChunkManager] BM25查询耗时: {cost}s " + f"| kb_id: {kb_id} | query: {query[:50]}... 
| 匹配数量: {len(chunk_entities)}" ) return chunk_entities except Exception as e: - # 异常日志补充关键上下文 - err = f"根据知识库ID和关键词权重查询失败: kb_id={kb_id}, keywords={keywords[:2]}..., error={str(e)[:150]}" + err = f"BM25查询失败: kb_id={kb_id}, query={query[:50]}..., error={str(e)[:150]}" logging.exception("[ChunkManager] %s", err) return [] - + @staticmethod async def fetch_surrounding_chunk_by_doc_id_and_global_offset( doc_id: uuid.UUID, global_offset: int, diff --git a/data_chain/manager/document_manager.py b/data_chain/manager/document_manager.py index f03433b65797bc225faa9cbfb5e758d14d794694..205ca937243c6d727568c9106886655990ca1913 100644 --- a/data_chain/manager/document_manager.py +++ b/data_chain/manager/document_manager.py @@ -1,5 +1,5 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. -from sqlalchemy import select, delete, update, func, between, asc, desc, and_, Float, literal_column, text +from sqlalchemy import select, delete, update, func, between, asc, desc, and_, Float, literal_column, text, bindparam from datetime import datetime, timezone import uuid from typing import Dict, List, Tuple @@ -105,15 +105,17 @@ class DocumentManager(): document.chunk_size, document.type_id, document.enabled, document.status, document.full_text, document.abstract, document.abstract_vector, document.created_time, document.updated_time, - document.abstract_vector <#> :vector AS similarity_score + document.abstract_vector <=> :vector AS similarity_score FROM document WHERE {where_clause} - AND (document.abstract_vector <#> :vector) IS NOT NULL - AND (document.abstract_vector <#> :vector) = (document.abstract_vector <#> :vector) + AND (document.abstract_vector <=> :vector) IS NOT NULL + AND (document.abstract_vector <=> :vector) = (document.abstract_vector <=> :vector) ORDER BY similarity_score ASC NULLS LAST LIMIT :limit """ + # 增加searcher数量,提升向量检索性能(可选) async with await DataBase.get_session() as session: + await session.execute(text("SET hnsw_ef_search = 1000;")) params["limit"] = top_k result = await session.execute(text(base_sql), params) rows = result.fetchall() @@ -209,108 +211,65 @@ class DocumentManager(): logging.exception("[DocumentManager] %s", err) raise e + @staticmethod async def get_top_k_document_by_kb_id_dynamic_weighted_keyword( - kb_id: uuid.UUID, keywords: List[str], weights: List[float], + kb_id: uuid.UUID, query: str, # 关键词列表改为单查询文本,移除weights参数 top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = []) -> List[DocumentEntity]: - """根据知识库ID和关键词权重查询文档(修复NoneType报错+强制索引)""" + """根据知识库ID和查询文本查询文档(BM25检索版,匹配abstract_bm25_index索引)""" try: - st = datetime.now() # 新增计时日志 + st = datetime.now() async with await DataBase.get_session() as session: - # 1. 分词器选择(与第一个方法保持一致) - kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id) - if kb_entity.tokenizer == Tokenizer.ZH.value: - tokenizer = 'chparser' if config['DATABASE_TYPE'].lower( - ) == 'opengauss' else 'zhparser' - elif kb_entity.tokenizer == Tokenizer.EN.value: - tokenizer = 'english' - else: - tokenizer = 'chparser' if config['DATABASE_TYPE'].lower( - ) == 'opengauss' else 'zhparser' - - # 2. 
构建加权关键词CTE(保留原逻辑) - params = {} - values_clause = [] - for idx, (term, weight) in enumerate(zip(keywords, weights)): - params[f"term_{idx}"] = term - params[f"weight_{idx}"] = weight - values_clause.append( - f"(CAST(:term_{idx} AS TEXT), CAST(:weight_{idx} AS FLOAT8))") - values_text = f"(VALUES {', '.join(values_clause)}) AS t(term, weight)" - weighted_terms = ( - select( - literal_column("t.term").label("term"), - literal_column("t.weight").cast(Float).label("weight") - ) - .select_from(text(values_text)) - .cte("weighted_terms") - ) + # 1. 构建查询文本参数(单文本,无需CTE列表) + params = {"query": query} - # 3. 初始化查询(确保stmt始终是Select对象) + # 2. 初始化查询(直接使用查询文本计算BM25分数) + query_param = bindparam("query") stmt = ( select( DocumentEntity, - func.sum( - func.ts_rank_cd(DocumentEntity.abstract_ts_vector, func.to_tsquery( - tokenizer, weighted_terms.c.term)) - * weighted_terms.c.weight - ).label("similarity_score") - ) - # 关联CTE+强制触发GIN索引(核心优化) - .join( - weighted_terms, - DocumentEntity.abstract_ts_vector.op( - '@@')(func.to_tsquery(tokenizer, weighted_terms.c.term)), - isouter=False + # 计算查询文本与文档abstract的BM25分数 + DocumentEntity.abstract.op('<&>')( + query_param).label("similarity_score") ) # 基础过滤条件 .where(DocumentEntity.enabled == True) .where(DocumentEntity.status != DocumentStatus.DELETED.value) .where(DocumentEntity.kb_id == kb_id) + # 过滤BM25分数大于0的结果(确保有相关性) + .where(DocumentEntity.abstract.op('<&>')(query_param) > 0) ) - # 4. 动态条件:禁用ID(确保stmt链式调用不中断) + # 3. 动态条件:禁用ID if banned_ids: stmt = stmt.where(DocumentEntity.id.notin_(banned_ids)) - # 5. 其他动态条件 + # 4. 其他动态条件:指定文档ID过滤 if doc_ids is not None: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) - # 6. 分组、过滤分数、排序、限制行数(链式调用安全) + # 5. 排序、限制(直接使用BM25分数排序) stmt = (stmt - .group_by(DocumentEntity.id) # 按文档ID分组计算总权重 - .having( # 过滤总分数>0的结果 - func.sum( - func.ts_rank_cd(DocumentEntity.abstract_ts_vector, func.to_tsquery( - tokenizer, weighted_terms.c.term)) - * weighted_terms.c.weight - ) > 0 - ) - .order_by( # 按总分数降序 - func.sum( - func.ts_rank_cd(DocumentEntity.abstract_ts_vector, func.to_tsquery( - tokenizer, weighted_terms.c.term)) - * weighted_terms.c.weight - ).desc() + .order_by( + DocumentEntity.abstract.op( + '<&>')(query_param).desc() ) - .limit(top_k) # 限制返回数量 + .limit(top_k) ) - # 7. 执行查询与结果处理 + # 6. 执行查询与结果处理 result = await session.execute(stmt, params=params) doc_entities = result.scalars().all() - # 8. 新增执行时间日志 + # 7. 日志输出 cost = (datetime.now() - st).total_seconds() logging.warning( - f"[DocumentManager] get_top_k_document_by_kb_id_dynamic_weighted_keyword cost: {cost}s " - f"| kb_id: {kb_id} | keywords: {keywords[:2]}... | match_count: {len(doc_entities)}" + f"[DocumentManager] BM25检索文档耗时: {cost}s " + f"| kb_id: {kb_id} | query: {query[:50]}... 
| 匹配数量: {len(doc_entities)}" ) return doc_entities except Exception as e: - # 异常日志补充关键上下文 - err = f"根据知识库ID和关键词权重查询文档失败: kb_id={kb_id}, keywords={keywords[:2]}..., error={str(e)[:150]}" + err = f"BM25检索文档失败: kb_id={kb_id}, query={query[:50]}..., error={str(e)[:150]}" logging.exception("[DocumentManager] %s", err) return [] diff --git a/data_chain/manager/task_queue_mamanger.py b/data_chain/manager/task_queue_mamanger.py index b0df886ba3e7a15691f3b533d6d3993903c9a638..2d1c33f01ec6dec80fe63f78192a0deea3f5cbe6 100644 --- a/data_chain/manager/task_queue_mamanger.py +++ b/data_chain/manager/task_queue_mamanger.py @@ -27,7 +27,8 @@ class TaskQueueManager(): """根据任务ID删除任务""" try: async with await DataBase.get_session() as session: - stmt = delete(TaskQueueEntity).where(TaskQueueEntity.id == task_id) + stmt = delete(TaskQueueEntity).where( + TaskQueueEntity.id == task_id) await session.execute(stmt) await session.commit() except Exception as e: @@ -58,7 +59,8 @@ class TaskQueueManager(): """根据任务ID获取任务""" try: async with await DataBase.get_session() as session: - stmt = select(TaskQueueEntity).where(TaskQueueEntity.id == task_id) + stmt = select(TaskQueueEntity).where( + TaskQueueEntity.id == task_id) result = await session.execute(stmt) return result.scalars().first() except Exception as e: @@ -67,14 +69,14 @@ class TaskQueueManager(): raise e @staticmethod - async def update_task_by_id(task_id: uuid.UUID, task: TaskQueueEntity): + async def update_task_by_id(task_id: uuid.UUID, status: TaskStatus): """根据任务ID更新任务""" try: async with await DataBase.get_session() as session: stmt = ( update(TaskQueueEntity) .where(TaskQueueEntity.id == task_id) - .values(status=task.status) + .values(status=status.value) ) await session.execute(stmt) await session.commit() @@ -82,3 +84,20 @@ class TaskQueueManager(): err = "更新任务失败" logging.exception("[TaskQueueManager] %s", err) raise e + + @staticmethod + async def update_task_by_ids(task_ids: List[uuid.UUID], status: TaskStatus): + """根据任务ID列表批量更新任务状态""" + try: + async with await DataBase.get_session() as session: + stmt = ( + update(TaskQueueEntity) + .where(TaskQueueEntity.id.in_(task_ids)) + .values(status=status.value) + ) + await session.execute(stmt) + await session.commit() + except Exception as e: + err = "批量更新任务状态失败" + logging.exception("[TaskQueueManager] %s", err) + raise e diff --git a/data_chain/parser/tools/token_tool.py b/data_chain/parser/tools/token_tool.py index 1d4765c52e2be9f5e5b41ae573bdcf0eb1829167..b71ef981a38216e4fd1b0de1553cc8db3e28034e 100644 --- a/data_chain/parser/tools/token_tool.py +++ b/data_chain/parser/tools/token_tool.py @@ -447,6 +447,25 @@ class TokenTool: # 计算余弦距离 cosine_dist = 1 - cosine_similarity return cosine_dist + # 通过向量计算语义相似度 + + @staticmethod + async def cal_semantic_similarity(content1: str, content2: str) -> float: + """ + 计算语义相似度 + 参数: + content1:内容1 + content2:内容2 + """ + try: + vector1 = await Embedding.vectorize_embedding(content1) + vector2 = await Embedding.vectorize_embedding(content2) + similarity = TokenTool.cosine_distance_numpy(vector1, vector2) + return (1 - similarity) * 100 + except Exception as e: + err = f"[TokenTool] 计算语义相似度失败 {e}" + logging.exception("[TokenTool] %s", err) + return -1 @staticmethod async def cal_relevance(question: str, answer: str, llm: LLM, language: str) -> float: diff --git a/data_chain/rag/doc2chunk_searcher.py b/data_chain/rag/doc2chunk_searcher.py index 6b3aca75cb1cc0d915acabe029a8d100364cd9ff..774048d5bc62aa6b54f1422b4daff2f9cbb91bf7 100644 --- 
a/data_chain/rag/doc2chunk_searcher.py +++ b/data_chain/rag/doc2chunk_searcher.py @@ -46,9 +46,7 @@ class Doc2ChunkSearcher(BaseSearcher): use_doc_ids += [doc_entity.id for doc_entity in doc_entities_vector] chunk_entities_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_keyword(kb_id, query, top_k//3, use_doc_ids, banned_ids) banned_ids += [chunk_entity.id for chunk_entity in chunk_entities_keyword] - keywords, weights = TokenTool.get_top_k_keywords_and_weights(query) - logging.error(f"[KeywordVectorSearcher] keywords: {keywords}, weights: {weights}") - chunk_entities_get_by_dynamic_weighted_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, keywords, weights, top_k//2, use_doc_ids, banned_ids) + chunk_entities_get_by_dynamic_weighted_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, query, top_k//2, use_doc_ids, banned_ids) banned_ids += [chunk_entity.id for chunk_entity in chunk_entities_get_by_dynamic_weighted_keyword] chunk_entities_vector = [] for _ in range(3): diff --git a/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py b/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py index a5467d41d52546c0b5ae7fb879afcc148a8f28a9..aed0ffcc25010f2cc0cede28c4f7cbe30ae390eb 100644 --- a/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py +++ b/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py @@ -29,35 +29,39 @@ class DynamicKeywordVectorSearcher(BaseSearcher): :return: 检索结果 """ vector = await Embedding.vectorize_embedding(query) + logging.error(f"[DynamicKeywordVectorSearcher] vector {vector}") try: chunk_entities_get_by_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_keyword( kb_id, query, max(top_k//3, 1), doc_ids, banned_ids) banned_ids += [chunk_entity.id for chunk_entity in chunk_entities_get_by_keyword] - keywords, weights = TokenTool.get_top_k_keywords_and_weights(query) - logging.error(f"[DynamicKeywordVectorSearcher] keywords: {keywords}, weights: {weights}") import time start_time = time.time() - chunk_entities_get_by_dynamic_weighted_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, keywords, weights, top_k//2, doc_ids, banned_ids) + chunk_entities_get_by_dynamic_weighted_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, query, top_k//2, doc_ids, banned_ids) end_time = time.time() - logging.info(f"[DynamicKeywordVectorSearcher] 动态关键字检索成功完成,耗时: {end_time - start_time:.2f}秒") + logging.info( + f"[DynamicKeywordVectorSearcher] 动态关键字检索成功完成,耗时: {end_time - start_time:.2f}秒") banned_ids += [chunk_entity.id for chunk_entity in chunk_entities_get_by_dynamic_weighted_keyword] chunk_entities_get_by_vector = [] for _ in range(3): try: import time start_time = time.time() - logging.error(f"[DynamicKeywordVectorSearcher] 开始进行向量检索,top_k: {top_k-len(chunk_entities_get_by_keyword)-len(chunk_entities_get_by_dynamic_weighted_keyword)}") + logging.error( + f"[DynamicKeywordVectorSearcher] 开始进行向量检索,top_k: {top_k-len(chunk_entities_get_by_keyword)-len(chunk_entities_get_by_dynamic_weighted_keyword)}") chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword)-len(chunk_entities_get_by_dynamic_weighted_keyword), doc_ids, banned_ids), timeout=20) end_time = time.time() - logging.info(f"[DynamicKeywordVectorSearcher] 向量检索成功完成,耗时: {end_time - start_time:.2f}秒") + logging.info( + f"[DynamicKeywordVectorSearcher] 
向量检索成功完成,耗时: {end_time - start_time:.2f}秒") break except Exception as e: import traceback err = f"[DynamicKeywordVectorSearcher] 向量检索失败,error: {e}, traceback: {traceback.format_exc()}" logging.error(err) continue - logging.error(f"[DynamicKeywordVectorSearcher] chunk_entities_get_by_keyword: {len(chunk_entities_get_by_keyword)}, chunk_entities_get_by_dynamic_weighted_keyword: {len(chunk_entities_get_by_dynamic_weighted_keyword)}, chunk_entities_get_by_vector: {len(chunk_entities_get_by_vector)}") - chunk_entities = chunk_entities_get_by_keyword + chunk_entities_get_by_dynamic_weighted_keyword + chunk_entities_get_by_vector + logging.error( + f"[DynamicKeywordVectorSearcher] chunk_entities_get_by_keyword: {len(chunk_entities_get_by_keyword)}, chunk_entities_get_by_dynamic_weighted_keyword: {len(chunk_entities_get_by_dynamic_weighted_keyword)}, chunk_entities_get_by_vector: {len(chunk_entities_get_by_vector)}") + chunk_entities = chunk_entities_get_by_keyword + \ + chunk_entities_get_by_dynamic_weighted_keyword + chunk_entities_get_by_vector except Exception as e: err = f"[DynamicKeywordVectorSearcher] 关键词向量检索失败,error: {e}" logging.exception(err) diff --git a/data_chain/rag/dynamic_weighted_keyword_searcher.py b/data_chain/rag/dynamic_weighted_keyword_searcher.py index 53cbb9625d2c657b3abb284deaf8dc339f8e3235..7860ae36005a5a1759bb4a23952a96fd995cc2f9 100644 --- a/data_chain/rag/dynamic_weighted_keyword_searcher.py +++ b/data_chain/rag/dynamic_weighted_keyword_searcher.py @@ -28,8 +28,7 @@ class DynamicWeightKeyWordSearcher(BaseSearcher): :return: 检索结果 """ try: - keywords, weights = TokenTool.get_top_k_keywords_and_weights(query) - chunk_entities = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, keywords, weights, top_k, doc_ids, banned_ids) + chunk_entities = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, query, top_k, doc_ids, banned_ids) except Exception as e: err = f"[KeywordVectorSearcher] 关键词向量检索失败,error: {e}" logging.exception(err) diff --git a/data_chain/rag/enhanced_by_llm_searcher.py b/data_chain/rag/enhanced_by_llm_searcher.py index 738eaac974ef3121aff613532edec4bff75afe8b..b27cf7294ec2928b593e8bcc45f124551f60c65c 100644 --- a/data_chain/rag/enhanced_by_llm_searcher.py +++ b/data_chain/rag/enhanced_by_llm_searcher.py @@ -49,11 +49,10 @@ class EnhancedByLLMSearcher(BaseSearcher): model_name=config['MODEL_NAME'], max_tokens=config['MAX_TOKENS'], ) - keywords, weights = TokenTool.get_top_k_keywords_and_weights(query) while len(chunk_entities) < top_k and rd < max_retry: rd += 1 sub_chunk_entities_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword( - kb_id, keywords, weights, top_k, doc_ids, banned_ids) + kb_id, query, top_k, doc_ids, banned_ids) chunk_ids = [chunk_entity.id for chunk_entity in sub_chunk_entities_keyword] banned_ids += chunk_ids sub_chunk_entities_vector = [] diff --git a/data_chain/stores/database/database.py b/data_chain/stores/database/database.py index 7d159913ee4c3a6ffcbea2845aca8d31a1ab8905..429a862e80079c4b224142c40953476a1e0212ab 100644 --- a/data_chain/stores/database/database.py +++ b/data_chain/stores/database/database.py @@ -7,7 +7,7 @@ from uuid import uuid4 import urllib.parse from data_chain.logger.logger import logger as logging from pgvector.sqlalchemy import Vector -from sqlalchemy import Boolean, Column, ForeignKey, BigInteger, Float, String, func +from sqlalchemy import Boolean, Column, ForeignKey, BigInteger, Float, Text, func from sqlalchemy.types import 
TIMESTAMP, UUID from sqlalchemy.dialects.postgresql import TSVECTOR from sqlalchemy.orm import declarative_base, DeclarativeBase, MappedAsDataclass, Mapped, mapped_column @@ -45,13 +45,13 @@ class TeamEntity(Base): __tablename__ = 'team' id = Column(UUID, default=uuid4, primary_key=True) - author_id = Column(String) - author_name = Column(String) - name = Column(String) - description = Column(String) + author_id = Column(Text) + author_name = Column(Text) + name = Column(Text) + description = Column(Text) member_cnt = Column(BigInteger, default=0) is_public = Column(Boolean, default=True) - status = Column(String, default=TeamStatus.EXISTED.value) + status = Column(Text, default=TeamStatus.EXISTED.value) created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -62,19 +62,25 @@ class TeamEntity(Base): server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('team_id_index', id), + Index('team_name_index', name), + Index('team_author_id_index', author_id) + ) class TeamMessageEntity(Base): __tablename__ = 'team_message' id = Column(UUID, default=uuid4, primary_key=True) - team_id = Column(UUID, ForeignKey('team.id')) - author_id = Column(String) - author_name = Column(String) - message_level = Column(String) - zh_message = Column(String, default='') - en_message = Column(String, default='') - status = Column(String, default=TeamMessageStatus.EXISTED.value) + team_id = Column(UUID) + author_id = Column(Text) + author_name = Column(Text) + message_level = Column(Text) + zh_message = Column(Text, default='') + en_message = Column(Text, default='') + status = Column(Text, default=TeamMessageStatus.EXISTED.value) created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -86,16 +92,23 @@ class TeamMessageEntity(Base): onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('team_message_id_index', id), + Index('team_message_team_id_index', team_id), + Index('team_message_author_id_index', author_id) + ) + class RoleEntity(Base): __tablename__ = 'role' id = Column(UUID, default=uuid4, primary_key=True) - team_id = Column(UUID, ForeignKey('team.id')) - name = Column(String) + team_id = Column(UUID) + name = Column(Text) is_unique = Column(Boolean, default=False) editable = Column(Boolean, default=True) - status = Column(String, default=RoleStatus.EXISTED.value) # 角色状态 + status = Column(Text, default=RoleStatus.EXISTED.value) # 角色状态 created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -106,14 +119,20 @@ class RoleEntity(Base): server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('role_id_index', id), + Index('role_team_id_index', team_id), + Index('role_name_index', name) + ) class ActionEntity(Base): __tablename__ = 'action' - action = Column(String, primary_key=True) - name = Column(String) - type = Column(String) + action = Column(Text, primary_key=True) + name = Column(Text) + type = Column(Text) created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -130,9 +149,9 @@ class RoleActionEntity(Base): __tablename__ = 'role_action' id = Column(UUID, default=uuid4, primary_key=True) - role_id = Column(UUID, ForeignKey('role.id', ondelete="CASCADE")) - action = Column(String) - status = Column(String, default=RoleActionStatus.EXISTED.value) + role_id = Column(UUID) + action = Column(Text) + status = Column(Text, default=RoleActionStatus.EXISTED.value) created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -144,13 
+163,20 @@ class RoleActionEntity(Base): onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('role_action_id_index', id), + Index('role_action_role_id_index', role_id), + Index('role_action_action_index', action) + ) + class UserEntity(Base): __tablename__ = 'users' - id = Column(String, primary_key=True) - name = Column(String) - status = Column(String, default=UserStatus.ACTIVE.value) + id = Column(Text, primary_key=True) + name = Column(Text) + status = Column(Text, default=UserStatus.ACTIVE.value) created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -162,37 +188,50 @@ class UserEntity(Base): onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('user_id_index', id), + Index('user_name_index', name) + ) + class UserMessageEntity(Base): __tablename__ = 'user_message' id = Column(UUID, default=uuid4, primary_key=True) team_id = Column(UUID) - team_name = Column(String) + team_name = Column(Text) role_id = Column(UUID) - sender_id = Column(String) - sender_name = Column(String) - status_to_sender = Column(String, default=UserMessageStatus.UNREAD.value) - receiver_id = Column(String) - receiver_name = Column(String) + sender_id = Column(Text) + sender_name = Column(Text) + status_to_sender = Column(Text, default=UserMessageStatus.UNREAD.value) + receiver_id = Column(Text) + receiver_name = Column(Text) is_to_all = Column(Boolean, default=False) - status_to_receiver = Column(String, default=UserMessageStatus.UNREAD.value) - message = Column(String) - type = Column(String) + status_to_receiver = Column(Text, default=UserMessageStatus.UNREAD.value) + message = Column(Text) + type = Column(Text) created_time = Column( TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('user_message_id_index', id), + Index('user_message_sender_id_index', sender_id), + Index('user_message_receiver_id_index', receiver_id) + ) + class TeamUserEntity(Base): __tablename__ = 'team_user' id = Column(UUID, default=uuid4, primary_key=True) - team_id = Column(UUID, ForeignKey('team.id', ondelete="CASCADE")) # 团队id - user_id = Column(String) # 用户id - status = Column(String, default=TeamUserStaus.EXISTED.value) # 用户在团队中的状态 + team_id = Column(UUID) # 团队id + user_id = Column(Text) # 用户id + status = Column(Text, default=TeamUserStaus.EXISTED.value) # 用户在团队中的状态 created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -204,14 +243,21 @@ class TeamUserEntity(Base): onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('team_user_id_index', id), + Index('team_user_team_id_index', team_id), + Index('team_user_user_id_index', user_id) + ) + class UserRoleEntity(Base): __tablename__ = 'user_role' id = Column(UUID, default=uuid4, primary_key=True) - team_id = Column(UUID, ForeignKey('team.id', ondelete="CASCADE")) # 团队id - user_id = Column(String) # 用户id + team_id = Column(UUID) # 团队id + user_id = Column(Text) # 用户id role_id = Column(UUID) # 角色id - status = Column(String, default=UserRoleStatus.EXISTED.value) # 用户角色状态 + status = Column(Text, default=UserRoleStatus.EXISTED.value) # 用户角色状态 created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -223,30 +269,36 @@ class UserRoleEntity(Base): onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('user_role_id_index', id), + Index('user_role_team_id_index', team_id), + Index('user_role_user_id_index', user_id) + ) + class KnowledgeBaseEntity(Base): __tablename__ = 'knowledge_base' id = Column(UUID, 
default=uuid4, primary_key=True) - team_id = Column(UUID, ForeignKey( - 'team.id', ondelete="CASCADE"), nullable=True) # 团队id - author_id = Column(String) # 作者id - author_name = Column(String) # 作者名称 - name = Column(String, default='') # 知识库名资产名 - tokenizer = Column(String, default=Tokenizer.ZH.value) # 分词器 - description = Column(String, default='') # 资产描述 - embedding_model = Column(String) # 资产向量化模型 - rerank_method = Column(String) - rerank_name = Column(String) - spearating_characters = Column(String) # 资产分块的分隔符 + team_id = Column(UUID, nullable=True) # 团队id + author_id = Column(Text) # 作者id + author_name = Column(Text) # 作者名称 + name = Column(Text, default='') # 知识库名资产名 + tokenizer = Column(Text, default=Tokenizer.ZH.value) # 分词器 + description = Column(Text, default='') # 资产描述 + embedding_model = Column(Text) # 资产向量化模型 + rerank_method = Column(Text) + rerank_name = Column(Text) + spearating_characters = Column(Text) # 资产分块的分隔符 doc_cnt = Column(BigInteger, default=0) # 资产文档个数 doc_size = Column(BigInteger, default=0) # 资产下所有文档大小(TODO: 单位kb或者字节) upload_count_limit = Column(BigInteger, default=128) # 更新次数限制 upload_size_limit = Column(BigInteger, default=512) # 更新大小限制 default_parse_method = Column( - String, default=ParseMethod.GENERAL.value) # 默认解析方法 + Text, default=ParseMethod.GENERAL.value) # 默认解析方法 default_chunk_size = Column(BigInteger, default=1024) # 默认分块大小 - status = Column(String, default=KnowledgeBaseStatus.IDLE.value) + status = Column(Text, default=KnowledgeBaseStatus.IDLE.value) created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -258,14 +310,20 @@ class KnowledgeBaseEntity(Base): onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('knowledge_base_id_index', id), + Index('knowledge_base_team_id_index', team_id), + Index('knowledge_base_name_index', name) + ) + class DocumentTypeEntity(Base): __tablename__ = 'document_type' id = Column(UUID, default=uuid4, primary_key=True) - kb_id = Column(UUID, ForeignKey('knowledge_base.id', - ondelete="CASCADE"), nullable=True) - name = Column(String) + kb_id = Column(UUID, nullable=True) + name = Column(Text) created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -277,28 +335,34 @@ class DocumentTypeEntity(Base): onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('document_type_id_index', id), + Index('document_type_kb_id_index', kb_id), + Index('document_type_name_index', name) + ) + class DocumentEntity(Base): __tablename__ = 'document' id = Column(UUID, default=uuid4, primary_key=True) team_id = Column(UUID) # 文档所属团队id - kb_id = Column(UUID, ForeignKey( - 'knowledge_base.id', ondelete="CASCADE")) # 文档所属资产id - author_id = Column(String) # 文档作者id - author_name = Column(String) # 文档作者名称 - name = Column(String) # 文档名 - extension = Column(String) # 文件后缀 + kb_id = Column(UUID) # 文档所属资产id + author_id = Column(Text) # 文档作者id + author_name = Column(Text) # 文档作者名称 + name = Column(Text) # 文档名 + extension = Column(Text) # 文件后缀 size = Column(BigInteger) # 文档大小 - parse_method = Column(String, default=ParseMethod.GENERAL.value) # 文档解析方法 + parse_method = Column(Text, default=ParseMethod.GENERAL.value) # 文档解析方法 parse_relut_topology = Column( - String, default=DocParseRelutTopology.LIST.value) # 文档解析结果拓扑结构 + Text, default=DocParseRelutTopology.LIST.value) # 文档解析结果拓扑结构 chunk_size = Column(BigInteger) # 文档分块大小 type_id = Column(UUID) # 文档类别 enabled = Column(Boolean) # 文档是否启用 - status = Column(String, default=DocumentStatus.IDLE.value) # 文档状态 - full_text = Column(String) # 文档全文 - abstract = 
Column(String) # 文档摘要 + status = Column(Text, default=DocumentStatus.IDLE.value) # 文档状态 + full_text = Column(Text) # 文档全文 + abstract = Column(Text) # 文档摘要 abstract_ts_vector = Column(TSVECTOR) # 文档摘要词向量 abstract_vector = Column(Vector(1024)) # 文档摘要向量 created_time = Column( @@ -312,15 +376,22 @@ class DocumentEntity(Base): onupdate=func.current_timestamp() ) __table_args__ = ( + Index("document_id_index", id), + Index("document_team_id_index", team_id), + Index("document_kb_id_index", kb_id), + Index("document_author_id_index", author_id), + Index("document_author_name_index", author_name), + Index("document_name_index", name), Index('abstract_ts_vector_index', abstract_ts_vector, postgresql_using='gin'), Index( 'abstract_vector_index', abstract_vector, postgresql_using='hnsw', - postgresql_with={'m': 16, 'ef_construction': 200}, + postgresql_with={'m': 32, 'ef_construction': 200}, postgresql_ops={'abstract_vector': 'vector_cosine_ops'} ), + Index('abstract_bm25_index', abstract, postgresql_using='bm25') ) @@ -330,23 +401,22 @@ class ChunkEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) # chunk id team_id = Column(UUID) # 团队id kb_id = Column(UUID) # 知识库id - doc_id = Column(UUID, ForeignKey( - 'document.id', ondelete="CASCADE")) # 片段所属文档id - doc_name = Column(String) # 片段所属文档名称 - text = Column(String) # 片段文本内容 + doc_id = Column(UUID) # 片段所属文档id + doc_name = Column(Text) # 片段所属文档名称 + text = Column(Text) # 片段文本内容 text_ts_vector = Column(TSVECTOR) # 片段文本词向量 text_vector = Column(Vector(1024)) # 文本向量 tokens = Column(BigInteger) # 片段文本token数 - type = Column(String, default=ChunkType.TEXT.value) # 片段类型 + type = Column(Text, default=ChunkType.TEXT.value) # 片段类型 # 前一个chunk的id(假如解析结果为链表,那么这里是前一个节点的id,如果文档解析结果为树,那么这里是父节点的id) pre_id_in_parse_topology = Column(UUID) # chunk的在解析结果中的拓扑类型(假如解析结果为链表,那么这里为链表头、中间和尾;假如解析结果为树,那么这里为树根、树的中间节点和叶子节点) parse_topology_type = Column( - String, default=ChunkParseTopology.LISTHEAD.value) + Text, default=ChunkParseTopology.LISTHEAD.value) global_offset = Column(BigInteger) # chunk在文档中的相对偏移 local_offset = Column(BigInteger) # chunk在块中的相对偏移 enabled = Column(Boolean) # chunk是否启用 - status = Column(String, default=ChunkStatus.EXISTED.value) # chunk状态 + status = Column(Text, default=ChunkStatus.EXISTED.value) # chunk状态 created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -357,14 +427,19 @@ class ChunkEntity(Base): server_default=func.current_timestamp(), onupdate=func.current_timestamp()) __table_args__ = ( + Index("chunk_id_index", id), + Index("chunk_team_id_index", team_id), + Index("chunk_kb_id_index", kb_id), + Index("chunk_doc_id_index", doc_id), Index('text_ts_vector_index', text_ts_vector, postgresql_using='gin'), Index( 'text_vector_index', text_vector, postgresql_using='hnsw', - postgresql_with={'m': 16, 'ef_construction': 200}, + postgresql_with={'m': 32, 'ef_construction': 200}, postgresql_ops={'text_vector': 'vector_cosine_ops'} ), + Index('text_bm25_index', text, postgresql_using='bm25') ) @@ -374,8 +449,8 @@ class ImageEntity(Base): team_id = Column(UUID) # 团队id doc_id = Column(UUID) # 图片所属文档id chunk_id = Column(UUID) # 图片所属chunk的id - extension = Column(String) # 图片后缀 - status = Column(String, default=ImageStatus.EXISTED.value) # 图片状态 + extension = Column(Text) # 图片后缀 + status = Column(Text, default=ImageStatus.EXISTED.value) # 图片状态 created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -387,24 +462,31 @@ class ImageEntity(Base): onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('image_id_index', 
+    kb_id = Column(UUID)  # ID of the asset the dataset belongs to
+    author_id = Column(Text)  # Creator ID of the data
+    author_name = Column(Text)  # Creator name of the data
+    llm_id = Column(Text)  # ID of the LLM used to generate the data
+    name = Column(Text, nullable=False)  # Dataset name
+    description = Column(Text)  # Dataset description
     data_cnt = Column(BigInteger)  # Number of data entries in the dataset
     is_data_cleared = Column(Boolean, default=False)  # Whether the dataset has been cleaned
     is_chunk_related = Column(Boolean, default=False)  # Whether surrounding context is related for the dataset
     is_imported = Column(Boolean, default=False)  # Whether the dataset was imported
-    status = Column(String, default=DataSetStatus.IDLE)  # Dataset status
+    status = Column(Text, default=DataSetStatus.IDLE.value)  # Dataset status
     score = Column(Float, default=-1)  # Dataset score
     created_at = Column(
         TIMESTAMP(timezone=True),
@@ -417,13 +499,20 @@ class DataSetEntity(Base):
         onupdate=func.current_timestamp()
     )
+    # Indexes
+    __table_args__ = (
+        Index('dataset_id_index', id),
+        Index('dataset_team_id_index', team_id),
+        Index('dataset_kb_id_index', kb_id),
+        Index('dataset_name_index', name)
+    )
+
 
 
 class DataSetDocEntity(Base):
     __tablename__ = 'dataset_doc'
     id = Column(UUID, default=uuid4, primary_key=True)  # Dataset-document ID
-    dataset_id = Column(UUID, ForeignKey(
-        'dataset.id', ondelete="CASCADE"))  # Dataset ID
+    dataset_id = Column(UUID)  # Dataset ID
     doc_id = Column(UUID)  # Document ID
     created_at = Column(
         TIMESTAMP(timezone=True),
@@ -436,20 +525,26 @@ class DataSetDocEntity(Base):
         onupdate=func.current_timestamp()
     )
+    # Indexes
+    __table_args__ = (
+        Index('dataset_doc_id_index', id),
+        Index('dataset_doc_dataset_id_index', dataset_id),
+        Index('dataset_doc_doc_id_index', doc_id)
+    )
+
 
 
 class QAEntity(Base):
     __tablename__ = 'qa'
     id = Column(UUID, default=uuid4, primary_key=True)  # Data ID
-    dataset_id = Column(UUID, ForeignKey(
-        'dataset.id', ondelete="CASCADE"))  # ID of the dataset the data belongs to
+    dataset_id = Column(UUID)  # ID of the dataset the data belongs to
     doc_id = Column(UUID)  # ID of the document the data relates to
-    doc_name = Column(String, default="未知文档")  # Name of the document the data relates to
-    question = Column(String)  # Question of the data
-    answer = Column(String)  # Answer of the data
-    chunk = Column(String)  # Chunk of the data
-    chunk_type = Column(String, default="未知片段类型")  # Chunk type of the data
-    status = Column(String, default=QAStatus.EXISTED.value)  # Status of the data
+    doc_name = Column(Text, default="未知文档")  # Name of the document the data relates to
+    question = Column(Text)  # Question of the data
+    answer = Column(Text)  # Answer of the data
+    chunk = Column(Text)  # Chunk of the data
+    chunk_type = Column(Text, default="未知片段类型")  # Chunk type of the data
+    status = Column(Text, default=QAStatus.EXISTED.value)  # Status of the data
     created_at = Column(
         TIMESTAMP(timezone=True),
         nullable=True,
@@ -460,6 +555,12 @@ class QAEntity(Base):
         server_default=func.current_timestamp(),
         onupdate=func.current_timestamp()
     )
+    # Indexes
+    __table_args__ = (
+        Index('qa_id_index', id),
+        Index('qa_dataset_id_index', dataset_id),
+        Index('qa_doc_id_index', doc_id)
+    )
 
 
 class TestingEntity(Base):
@@ -468,17 +569,16 @@ class TestingEntity(Base):
     id = Column(UUID, default=uuid4, primary_key=True)  # Testing task ID
     team_id = Column(UUID)  # ID of the team the testing task belongs to
     kb_id = Column(UUID)  # ID of the asset the testing task belongs to
-    dataset_id = Column(UUID, ForeignKey(
-        'dataset.id', ondelete="CASCADE"))  # ID of the dataset used by the testing task
-    author_id = Column(String)  # Creator ID of the testing task
-    author_name = Column(String)  # Creator name of the testing task
-    name = Column(String)  # Name of the testing task
-    description = Column(String)  # Description of the testing task
-    llm_id = Column(String)  # LLM used by the testing task
+    dataset_id = Column(UUID)  # ID of the dataset used by the testing task
+    author_id = Column(Text)  # Creator ID of the testing task
+    author_name = Column(Text)  # Creator name of the testing task
+    name = Column(Text)  # Name of the testing task
+    description = Column(Text)  # Description of the testing task
+    llm_id = Column(Text)  # LLM used by the testing task
     search_method = Column(
-        String, default=SearchMethod.KEYWORD_AND_VECTOR.value)  # Retrieval method used by the testing task
+        Text, default=SearchMethod.KEYWORD_AND_VECTOR.value)  # Retrieval method used by the testing task
     top_k = Column(BigInteger, default=5)  # top_k of the retrieval method
-    status = Column(String, default=TestingStatus.IDLE.value)  # Status of the testing task
+    status = Column(Text, default=TestingStatus.IDLE.value)  # Status of the testing task
     ave_score = Column(Float, default=-1)  # Overall score of the testing task
     ave_pre = Column(Float, default=-1)  # Average recall of the testing task
     ave_rec = Column(Float, default=-1)  # Average precision of the testing task
@@ -498,19 +598,26 @@ class TestingEntity(Base):
         onupdate=func.current_timestamp()
     )
+    # Indexes
+    __table_args__ = (
+        Index('testing_id_index', id),
+        Index('testing_team_id_index', team_id),
+        Index('testing_kb_id_index', kb_id),
+        Index('testing_dataset_id_index', dataset_id)
+    )
+
 
 
 class TestCaseEntity(Base):
    __tablename__ = 'testcase'
     id = Column(UUID, default=uuid4, primary_key=True)  # Test case ID
-    testing_id = Column(UUID, ForeignKey(
-        'testing.id', ondelete="CASCADE"))  # Testing task ID
-    question = Column(String)  # Question of the data
-    answer = Column(String)  # Answer of the data
-    chunk = Column(String)  # Chunk of the data
-    llm_answer = Column(String)  # Answer produced in the test
-    related_chunk = Column(String)  # Chunks retrieved in the test
-    doc_name = Column(String)  # Names of the documents retrieved in the test
+    testing_id = Column(UUID)  # Testing task ID
+    question = Column(Text)  # Question of the data
+    answer = Column(Text)  # Answer of the data
+    chunk = Column(Text)  # Chunk of the data
+    llm_answer = Column(Text)  # Answer produced in the test
+    related_chunk = Column(Text)  # Chunks retrieved in the test
+    doc_name = Column(Text)  # Names of the documents retrieved in the test
     score = Column(Float)  # Test score
     pre = Column(Float)  # Recall
     rec = Column(Float)  # Precision
@@ -519,7 +626,7 @@ class TestCaseEntity(Base):
     lcs = Column(Float)  # Longest common subsequence score
     leve = Column(Float)  # Edit distance score
     jac = Column(Float)  # Jaccard similarity coefficient
-    status = Column(String, default=TestCaseStatus.EXISTED.value)  # Test status
+    status = Column(Text, default=TestCaseStatus.EXISTED.value)  # Test status
     created_at = Column(
         TIMESTAMP(timezone=True),
         nullable=True,
@@ -531,19 +638,24 @@ class TestCaseEntity(Base):
         onupdate=func.current_timestamp()
     )
+    # Indexes
+    __table_args__ = (
+        Index('testcase_id_index', id),
+        Index('testcase_testing_id_index', testing_id)
+    )
+
 
 
 class TaskEntity(Base):
     __tablename__ = 'task'
     id = Column(UUID, default=uuid4, primary_key=True)
     team_id = Column(UUID)  # Team ID
-    user_id = Column(String, ForeignKey(
-        'users.id', ondelete="CASCADE"))  # Creator ID
+    user_id = Column(Text)  # Creator ID
     op_id = Column(UUID)  # ID of the entity the task relates to (asset or document)
-    op_name = Column(String)  # Name of the entity the task relates to
-    type = Column(String)  # Task type
+    op_name = Column(Text)  # Name of the entity the task relates to
+    type = Column(Text)  # Task type
     retry = Column(BigInteger)  # Retry count
-    status = Column(String)  # Task status
+    status = Column(Text)  # Task status
     created_time = Column(
         TIMESTAMP(timezone=True),
         nullable=True,
@@ -555,13 +667,23 @@ class TaskEntity(Base):
         onupdate=func.current_timestamp()
     )
+    # Indexes
+    __table_args__ = (
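+        # NOTE: PostgreSQL already creates a unique index to back the primary key, so
+        # the extra index on `id` here (and in the other tables above) is redundant.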
+        Index('task_id_index', id),
+        Index('task_team_id_index', team_id),
+        Index('task_user_id_index', user_id),
+        Index('task_op_id_index', op_id),
+        Index('task_type_index', type),
+        Index('task_status_index', status)
+    )
+
 
 
 class TaskReportEntity(Base):
     __tablename__ = 'task_report'
     id = Column(UUID, default=uuid4, primary_key=True)  # Task report ID
-    task_id = Column(UUID, ForeignKey('task.id', ondelete="CASCADE"))  # Task ID
-    message = Column(String)  # Task report message
+    task_id = Column(UUID)  # Task ID
+    message = Column(Text)  # Task report message
     current_stage = Column(BigInteger)  # Current stage of the task
     stage_cnt = Column(BigInteger)  # Total number of stages of the task
     created_time = Column(
@@ -575,12 +697,18 @@ class TaskReportEntity(Base):
         onupdate=func.current_timestamp()
     )
+    # Indexes
+    __table_args__ = (
+        Index('task_report_id_index', id),
+        Index('task_report_task_id_index', task_id)
+    )
+
 
 
 class TaskQueueEntity(Base):
     __tablename__ = 'task_queue'
     id = Column(UUID, default=uuid4, primary_key=True)  # Task ID
-    status = Column(String)  # Task status
+    status = Column(Text)  # Task status
     created_time = Column(
         TIMESTAMP(timezone=True),
         nullable=True,
@@ -607,6 +735,7 @@ class DataBase:
     pool_size = os.cpu_count()
     if pool_size is None:
         pool_size = 5
+    logging.info(f"Database pool size set to: {pool_size}")
     engine = create_async_engine(
         database_url,
         echo=False,