diff --git a/chat2db/common/.env.example b/chat2db/common/.env.example
index 905a126a13a71cddc370417e88fd38159780d30a..999e50afd1abbb79ed28eaf009874e397e337d94 100644
--- a/chat2db/common/.env.example
+++ b/chat2db/common/.env.example
@@ -1,9 +1,9 @@
 # FastAPI
-UVICORN_IP =
-UVICORN_PORT =
-SSL_CERTFILE =
-SSL_KEYFILE =
-SSL_ENABLE =
+UVICORN_IP = 0.0.0.0
+UVICORN_PORT = 9015
+# SSL_CERTFILE =
+# SSL_KEYFILE =
+# SSL_ENABLE =
 
 # Postgres
 DATABASE_TYPE =
@@ -26,6 +26,6 @@ EMBEDDING_ENDPOINT =
 EMBEDDING_MODEL_NAME =
 
 # security
-HALF_KEY1 =
-HALF_KEY2 =
-HALF_KEY3 =
\ No newline at end of file
+HALF_KEY1 = R4UsZgLB
+HALF_KEY2 = zRTvYV8N
+HALF_KEY3 = 4eQ1wAGA
\ No newline at end of file
diff --git a/data_chain/apps/base/convertor.py b/data_chain/apps/base/convertor.py
index 09cbca0066a803de035f57254e11c015af6ee641..0b9692628000843fb6736286cb49fdee29d979e3 100644
--- a/data_chain/apps/base/convertor.py
+++ b/data_chain/apps/base/convertor.py
@@ -361,7 +361,8 @@ class Convertor:
         finished_time = None
         if task_report is not None:
             task_completed = task_report.current_stage/task_report.stage_cnt*100
-            finished_time = task_report.created_time.strftime('%Y-%m-%d %H:%M')
+            if task_entity.status == TaskStatus.SUCCESS.value:
+                finished_time = task_report.created_time.strftime('%Y-%m-%d %H:%M')
         task = Task(
             opId=task_entity.op_id,
             opName=task_entity.op_name,
diff --git a/data_chain/apps/base/task/process_handler.py b/data_chain/apps/base/task/process_handler.py
index 97a9143df692ce4a48bed3af79f76f7ad6e41224..3492bd41282b0facf5c1ee753ae98ece68e0af54 100644
--- a/data_chain/apps/base/task/process_handler.py
+++ b/data_chain/apps/base/task/process_handler.py
@@ -16,7 +16,7 @@ class ProcessHandler:
     lock = multiprocessing.Lock()  # 创建一个锁对象
     max_processes = min(
         max((os.cpu_count() or 1) // 2, 1),
-        config['DOCUMENT_PARSE_USE_CPU_LIMIT'])  # 获取CPU核心数作为最大进程数,默认为1
+        config['USE_CPU_LIMIT'])  # 获取CPU核心数作为最大进程数,默认为1
     time_out = 10
 
     @staticmethod
diff --git a/data_chain/apps/base/task/worker/acc_testing_worker.py b/data_chain/apps/base/task/worker/acc_testing_worker.py
index 5f3e219416a18ebd3444ead80f9b4551f8d1d35f..45ae92740338379506641f2db485b2834f1aa196 100644
--- a/data_chain/apps/base/task/worker/acc_testing_worker.py
+++ b/data_chain/apps/base/task/worker/acc_testing_worker.py
@@ -142,7 +142,7 @@ class TestingWorker(BaseWorker):
             question = qa_entity.question
             answer = qa_entity.answer
             chunk = qa_entity.chunk
-            chunk_entities = await BaseSearcher.search(testing_entity.search_method, testing_entity.kb_id, question, top_k=2*testing_entity.top_k, doc_ids=None, banned_ids=[])
+            chunk_entities = await BaseSearcher.search(testing_entity.search_method, testing_entity.kb_id, question, top_k=testing_entity.top_k, doc_ids=None, banned_ids=[])
             related_chunk_entities = []
             banned_ids = [chunk_entity.id for chunk_entity in chunk_entities]
             divide_tokens = llm.max_tokens // len(chunk_entities) if chunk_entities else llm.max_tokens
diff --git a/data_chain/apps/base/task/worker/parse_document_worker.py b/data_chain/apps/base/task/worker/parse_document_worker.py
index b6373c0f7a8d8589d6f07217f3e2f4e72c161c46..989de81176fe39f352a95e0bd78b803de6297da4 100644
--- a/data_chain/apps/base/task/worker/parse_document_worker.py
+++ b/data_chain/apps/base/task/worker/parse_document_worker.py
@@ -402,11 +402,8 @@ class ParseDocumentWorker(BaseWorker):
         if llm is not None:
             abstract = await TokenTool.get_abstract_by_llm(abstract, llm)
         else:
-            sentences = TokenTool.get_top_k_keysentence(abstract, 1)
-            if sentences:
-                abstract = sentences[0]
-            else:
-                abstract = ''
+            keywords = TokenTool.get_top_k_keywords(abstract, 20)
+            abstract = ' '.join(keywords)
         abstract_vector = await Embedding.vectorize_embedding(abstract)
         await DocumentManager.update_document_by_doc_id(
             doc_id,
diff --git a/data_chain/common/.env.example b/data_chain/common/.env.example
index 359c85ca2277b1c6fc8a2a0b30db77c79aba9bba..063eddddebb8439b62ee5770471dadab15bd0564 100644
--- a/data_chain/common/.env.example
+++ b/data_chain/common/.env.example
@@ -2,13 +2,13 @@
 # debug
 DEBUG =
 # FastAPI
-UVICORN_IP =
-UVICORN_PORT =
-SSL_CERTFILE =
-SSL_KEYFILE =
-SSL_ENABLE =
+UVICORN_IP = 0.0.0.0
+UVICORN_PORT = 9988
+# SSL_CERTFILE =
+# SSL_KEYFILE =
+# SSL_ENABLE =
 # LOG METHOD
-LOG_METHOD =
+LOG_METHOD = stout
 # Database
 DATABASE_TYPE =
 DATABASE_HOST =
@@ -28,7 +28,7 @@ MONGODB_HOST =
 MONGODB_PORT =
 MONGODB_DATABASE =
 # Task
-TASK_RETRY_TIME =
+TASK_RETRY_TIME = 3
 # LLM
 MODEL_NAME =
 OPENAI_API_BASE =
@@ -45,14 +45,14 @@ EMBEDDING_MODEL_NAME =
 SESSION_TTL =
 CSRF_KEY =
 # Security
-HALF_KEY1 =
-HALF_KEY2 =
-HALF_KEY3 =
+HALF_KEY1 = 4QLg8bxe
+HALF_KEY2 = Bm571Gcq
+HALF_KEY3 = VpVF6Tuj
 # Prompt file
-PROMPT_PATH =
+PROMPT_PATH = ./data_chain/common/prompt.yaml
 # Stop Words PATH
-STOP_WORDS_PATH =
-# DOCUMENT PARSER
-DOCUMENT_PARSE_USE_CPU_LIMIT =
+STOP_WORDS_PATH = ./data_chain/common/stopwords.txt
+# CPU Limit
+USE_CPU_LIMIT = 64
 # Task Retry Time limit
-TASK_RETRY_TIME_LIMIT =
+TASK_RETRY_TIME_LIMIT = 3
diff --git a/data_chain/common/prompt.yaml b/data_chain/common/prompt.yaml
index 611949edaafa4c7f0043e7a11d7d4fe0491dfe8f..43e5ea321a563cfed7cf7c274d019d35abee8203 100644
--- a/data_chain/common/prompt.yaml
+++ b/data_chain/common/prompt.yaml
@@ -65,7 +65,10 @@ QA_TO_STATEMENTS_PROMPT: '你是一个文本分解专家,你的任务是根据
 ANSWER_TO_ANSWER_PROMPT: '你是一个文本分析专家,你的任务对比两个文本之间的相似度,并输出一个0-100之间的分数且保留两位小数:
 注意:
 #01 请根据文本在语义、语序和关键字上的相似度进行打分
-#02 请仅输出分数,不要输出其他内容
+#02 如果两个文本在核心表达上一致,那么分数也相对高
+#03 一个文本包含另一个文本的核心内容,那么分数也相对高
+#04 两个文本间内容有重合,那么按照重合内容的比例打分
+#05 请仅输出分数,不要输出其他内容
 例子:
 输入1:
 文本1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
@@ -108,7 +111,8 @@ STATEMENTS_TO_FRAGMENT_PROMPT: '你是一个文本专家,你的任务是根据
 注意:
 #01 如果陈诉与片段强相关或者来自于片段,请输出YES
 #02 如果陈诉中的内容与片段无关,请输出NO
-#03 请仅输出YES或NO,不要输出其他内容
+#03 如果陈诉是片段中某部分的提炼,请输出YES
+#05 请仅输出YES或NO,不要输出其他内容
 
 例子:
 输入1:
@@ -130,7 +134,7 @@ STATEMENTS_TO_QUESTION_PROMPT: '你是一个文本分析专家,你的任务是
 #01 如果陈诉是否与问题相关,请输出YES
 #02 如果陈诉与问题不相关,请输出NO
 #03 请仅输出YES或NO,不要输出其他内容
-#04 陈诉与问题相关是指,陈诉中的内容可以回答问题或者与问题在内容上有交际
+#04 陈诉与问题相关是指,陈诉中的内容可以回答问题或者与问题在内容上有交集
 例子:
 输入1:
 陈诉:openEuler是一个开源的操作系统。
diff --git a/data_chain/common/stopwords.txt b/data_chain/common/stopwords.txt
index 5784b4462a67442a7301abb939b8ca17fa791598..bfb5f302afa87935686501368c011a0a99de855e 100644
--- a/data_chain/common/stopwords.txt
+++ b/data_chain/common/stopwords.txt
@@ -1276,7 +1276,6 @@ indeed
 第三句
 更
 看上去
-安全
 零
 也好
 上去
@@ -3702,7 +3701,6 @@ sup
 它们的
 它是
 它的
-安全
 完全
 完成
 定
diff --git a/data_chain/config/config.py b/data_chain/config/config.py
index 1eb3c9de8d1fded94185268feef462c57d5d8626..742e87d236d4652ca82c844954accd0e17e19c71 100644
--- a/data_chain/config/config.py
+++ b/data_chain/config/config.py
@@ -67,8 +67,8 @@ class ConfigModel(DictBaseModel):
     PROMPT_PATH: str = Field(None, description="prompt路径")
     # Stop Words PATH
     STOP_WORDS_PATH: str = Field(None, description="停用词表存放位置")
-    # DOCUMENT PARSER
-    DOCUMENT_PARSE_USE_CPU_LIMIT: int = Field(default=4, description="文档解析器使用CPU核数")
+    # CPU Limit
+    USE_CPU_LIMIT: int = Field(default=64, description="文档解析器使用CPU核数")
     # Task Retry Time limit
     TASK_RETRY_TIME_LIMIT: int = Field(default=3,
                                        description="任务重试次数限制")
diff --git a/data_chain/entities/request_data.py b/data_chain/entities/request_data.py
index 92195dd20f2cae37eb471bd1d1d8d7060959fd00..0c0b5825251e4c95f78879e3f9ee4d6c197f5fe1 100644
--- a/data_chain/entities/request_data.py
+++ b/data_chain/entities/request_data.py
@@ -221,7 +221,7 @@ class ListTestingRequest(BaseModel):
     testing_id: Optional[uuid.UUID] = Field(default=None, description="测试id", alias="testingId")
     testing_name: Optional[str] = Field(default=None, description="测试名称", alias="testingName")
     llm_ids: Optional[list[str]] = Field(default=None, description="测试使用的大模型id", alias="llmIds")
-    search_methods: Optional[List[SearchMethod]] = Field(default=None, description="测试使用的检索方法", alias="searchMethods")
+    search_method: Optional[List[SearchMethod]] = Field(default=None, description="测试使用的检索方法", alias="searchMethod")
    run_status: Optional[List[TaskStatus]] = Field(default=None, description="测试运行状态", alias="runStatus")
     scores_order: Optional[OrderType] = Field(default=OrderType.DESC, description="测试评分", alias="scoresOrder")
     author_name: Optional[str] = Field(default=None, description="测试创建者", alias="authorName")
diff --git a/data_chain/manager/document_manager.py b/data_chain/manager/document_manager.py
index 29b6cab863600646ad5bb83b8f6c3ef7c6184fe5..d10e3260f140774e18b8b287a74396ef8dd32ab7 100644
--- a/data_chain/manager/document_manager.py
+++ b/data_chain/manager/document_manager.py
@@ -1,5 +1,5 @@
 # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-from sqlalchemy import select, delete, update, func, between, asc, desc, and_
+from sqlalchemy import select, delete, update, func, between, asc, desc, and_, Float, literal_column, text
 from datetime import datetime, timezone
 import uuid
 from typing import Dict, List, Tuple
@@ -50,7 +50,7 @@ class DocumentManager():
         """根据知识库ID和向量获取前K个文档"""
         try:
             async with await DataBase.get_session() as session:
-                similarity_score = DocumentEntity.abstract.cosine_distance(vector).label("similarity_score")
+                similarity_score = DocumentEntity.abstract_vector.cosine_distance(vector).label("similarity_score")
                 stmt = (
                     select(DocumentEntity, similarity_score)
                     .where(similarity_score > 0)
@@ -58,7 +58,6 @@ class DocumentManager():
                     .where(DocumentEntity.id.notin_(banned_ids))
                     .where(DocumentEntity.status != DocumentStatus.DELETED.value)
                     .where(DocumentEntity.enabled == True)
-                    .where(DocumentEntity.abstract_vector.cosine_distance(vector).desc())
                 )
                 if doc_ids:
                     stmt = stmt.where(DocumentEntity.id.in_(doc_ids))
@@ -116,6 +115,87 @@ class DocumentManager():
             logging.exception("[DocumentManager] %s", err)
             raise e
 
+    async def get_top_k_document_by_kb_id_dynamic_weighted_keyword(
+            kb_id: uuid.UUID, keywords: List[str], weights: List[float],
+            top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = []) -> List[DocumentEntity]:
+        """根据知识库ID和关键词和关键词权重查询文档解析结果"""
+        try:
+            async with await DataBase.get_session() as session:
+                kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id)
+                if kb_entity.tokenizer == Tokenizer.ZH.value:
+                    if config['DATABASE_TYPE'].lower() == 'opengauss':
+                        tokenizer = 'chparser'
+                    else:
+                        tokenizer = 'zhparser'
+                elif kb_entity.tokenizer == Tokenizer.EN.value:
+                    tokenizer = 'english'
+                else:
+                    if config['DATABASE_TYPE'].lower() == 'opengauss':
+                        tokenizer = 'chparser'
+                    else:
+                        tokenizer = 'zhparser'
+
+                # 构建VALUES子句的参数
+                params = {}
+                values_clause = []
+
+                for i, (term, weight) in enumerate(zip(keywords, weights)):
+                    # 使用单独的参数名,避免与类型转换冲突
+                    params[f"term_{i}"] = term
+                    params[f"weight_{i}"] = weight
+                    # 在VALUES子句中使用类型转换函数
+                    values_clause.append(f"(CAST(:term_{i} AS TEXT), CAST(:weight_{i} AS FLOAT8))")
+
+                # 构建VALUES子句
+                values_text = f"(VALUES {', '.join(values_clause)}) AS t(term, weight)"
+
+                # 创建weighted_terms CTE
+                weighted_terms = (
+                    select(
+                        literal_column("t.term").label("term"),
+                        literal_column("t.weight").cast(Float).label("weight")
+                    )
+                    .select_from(text(values_text))
+                    .cte("weighted_terms")
+                )
+
+                # 计算相似度得分
+                similarity_score = func.sum(
+                    func.ts_rank_cd(
+                        func.to_tsvector(tokenizer, DocumentEntity.abstract),
+                        func.to_tsquery(tokenizer, weighted_terms.c.term)
+                    ) * weighted_terms.c.weight
+                ).label("similarity_score")
+
+                stmt = (
+                    select(DocumentEntity, similarity_score)
+                    .where(DocumentEntity.enabled == True)
+                    .where(DocumentEntity.status != DocumentStatus.DELETED.value)
+                    .where(DocumentEntity.kb_id == kb_id)
+                    .where(DocumentEntity.id.notin_(banned_ids))
+                )
+                # 添加 GROUP BY 子句,按 ChunkEntity.id 分组
+                stmt = stmt.group_by(DocumentEntity.id)
+
+                stmt.having(similarity_score > 0)
+
+                if doc_ids is not None:
+                    stmt = stmt.where(DocumentEntity.id.in_(doc_ids))
+
+                # 按相似度分数排序
+                stmt = stmt.order_by(similarity_score.desc())
+                stmt = stmt.limit(top_k)
+
+                # 执行最终查询
+                result = await session.execute(stmt, params=params)
+                doc_entites = result.scalars().all()
+
+                return doc_entites
+        except Exception as e:
+            err = f"根据知识库ID和关键字动态查询文档失败: {str(e)}"
+            logging.exception("[ChunkManager] %s", err)
+            return []
+
     @staticmethod
     async def get_doc_cnt_by_kb_id(kb_id: uuid.UUID) -> int:
         """根据知识库ID获取文档数量"""
diff --git a/data_chain/manager/qa_manager.py b/data_chain/manager/qa_manager.py
index 29dd224ca9dcb7e95c6e60defaeca52afa948ba8..215379f0834cf4400ef05e3d233d4af4516bf9bb 100644
--- a/data_chain/manager/qa_manager.py
+++ b/data_chain/manager/qa_manager.py
@@ -81,6 +81,7 @@ class QAManager:
             stmt = (
                 select(QAEntity)
                 .where(QAEntity.dataset_id == dataset_id)
+                .where(QAEntity.status != QAStatus.DELETED.value)
             )
             result = await session.execute(stmt)
             return result.scalars().all()
diff --git a/data_chain/manager/testing_manager.py b/data_chain/manager/testing_manager.py
index e2b4cb3f2464b1889eb441f3433dfad9f6126325..1fb5c171ea668e8ecdcb345697d76444d6f71179 100644
--- a/data_chain/manager/testing_manager.py
+++ b/data_chain/manager/testing_manager.py
@@ -109,9 +109,9 @@ class TestingManager():
             inner_stmt = inner_stmt.where(TestingEntity.name.ilike(f"%{req.testing_name}%"))
         if req.llm_ids is not None:
             inner_stmt = inner_stmt.where(TestingEntity.llm_id.in_(req.llm_ids))
-        if req.search_methods is not None:
+        if req.search_method is not None:
             inner_stmt = inner_stmt.where(TestingEntity.search_method.in_(
-                [search_method.value for search_method in req.search_methods]))
+                [search_method.value for search_method in req.search_method]))
         if req.run_status is not None:
             inner_stmt = inner_stmt.where(subq.c.status.in_([status.value for status in req.run_status]))
         if req.author_name is not None:
@@ -165,9 +165,9 @@ class TestingManager():
             stmt = stmt.where(TestingEntity.name.ilike(f"%{req.testing_name}%"))
         if req.llm_ids is not None:
             stmt = stmt.where(TestingEntity.llm_id.in_(req.llm_ids))
-        if req.search_methods is not None:
+        if req.search_method is not None:
             stmt = stmt.where(TestingEntity.search_method.in_(
-                [search_method.value for search_method in req.search_methods]))
+                [search_method.value for search_method in req.search_method]))
         if req.run_status is not None:
             stmt = stmt.where(subq.c.status.in_([status.value for status in req.run_status]))
         if req.author_name is not None:
diff --git a/data_chain/rag/doc2chunk_searcher.py b/data_chain/rag/doc2chunk_searcher.py
index fc438a0b11b3632bd14a6d3b48b86d2031f60985..40510d513c9f2355a87a3ac3efe64f1fb008794c 100644
--- a/data_chain/rag/doc2chunk_searcher.py
+++ b/data_chain/rag/doc2chunk_searcher.py
@@ -31,30 +31,35 @@ class Doc2ChunkSearcher(BaseSearcher):
         """
         vector = await Embedding.vectorize_embedding(query)
         try:
-            doc_entities_keyword = await DocumentManager.get_top_k_document_by_kb_id_keyword(kb_id, query, top_k, doc_ids, banned_ids)
+            keywords, weights = TokenTool.get_top_k_keywords_and_weights(query)
+            doc_entities_keyword = await DocumentManager.get_top_k_document_by_kb_id_dynamic_weighted_keyword(kb_id, keywords, weights, top_k//2, doc_ids, [])
             use_doc_ids = [doc_entity.id for doc_entity in doc_entities_keyword]
             doc_entities_vector = []
             for _ in range(3):
                 try:
-                    doc_entities_vector = await asyncio.wait_for(DocumentManager.get_top_k_document_by_kb_id_vector(kb_id, vector, top_k-len(doc_entities_keyword), doc_ids, banned_ids), timeout=3)
+                    doc_entities_vector = await asyncio.wait_for(DocumentManager.get_top_k_document_by_kb_id_vector(kb_id, vector, top_k-len(doc_entities_keyword), use_doc_ids, banned_ids), timeout=3)
                     break
                 except Exception as e:
                     err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}"
                     logging.error(err)
                     continue
             use_doc_ids += [doc_entity.id for doc_entity in doc_entities_vector]
-            chunk_entities_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_keyword(kb_id, query, top_k//2, use_doc_ids, banned_ids)
-            chunk_ids = [chunk_entity.id for chunk_entity in chunk_entities_keyword]
+            chunk_entities_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_keyword(kb_id, query, top_k//3, use_doc_ids, banned_ids)
+            banned_ids += [chunk_entity.id for chunk_entity in chunk_entities_keyword]
+            keywords, weights = TokenTool.get_top_k_keywords_and_weights(query)
+            logging.error(f"[KeywordVectorSearcher] keywords: {keywords}, weights: {weights}")
+            chunk_entities_get_by_dynamic_weighted_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, keywords, weights, top_k//2, use_doc_ids, banned_ids)
+            banned_ids += [chunk_entity.id for chunk_entity in chunk_entities_get_by_dynamic_weighted_keyword]
             chunk_entities_vector = []
             for _ in range(3):
                 try:
-                    chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_keyword), use_doc_ids, banned_ids+chunk_ids), timeout=3)
+                    chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_keyword), use_doc_ids, banned_ids), timeout=3)
                     break
                 except Exception as e:
                     err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}"
                     logging.error(err)
                     continue
-            chunk_entities = chunk_entities_keyword + chunk_entities_vector
+            chunk_entities = chunk_entities_keyword + chunk_entities_get_by_dynamic_weighted_keyword + chunk_entities_vector
         except Exception as e:
             err = f"[KeywordVectorSearcher] 关键词向量检索失败,error: {e}"
             logging.exception(err)
diff --git a/data_chain/rag/dynamic_weighted_keyword_searcher.py b/data_chain/rag/dynamic_weighted_keyword_searcher.py
index 6a8ac64780eddcde3c0c0f81d9a6242a76532a91..53cbb9625d2c657b3abb284deaf8dc339f8e3235 100644
--- a/data_chain/rag/dynamic_weighted_keyword_searcher.py
+++ b/data_chain/rag/dynamic_weighted_keyword_searcher.py
@@ -28,8 +28,7 @@ class DynamicWeightKeyWordSearcher(BaseSearcher):
         :return: 检索结果
         """
         try:
-            query_filtered = TokenTool.filter_stopwords(query)
-            keywords, weights = TokenTool.get_top_k_keywords_and_weights(query_filtered)
+            keywords, weights = TokenTool.get_top_k_keywords_and_weights(query)
             chunk_entities = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, keywords, weights, top_k, doc_ids, banned_ids)
         except Exception as e:
             err = f"[KeywordVectorSearcher] 关键词向量检索失败,error: {e}"
diff --git a/data_chain/rag/enhanced_by_llm_searcher.py b/data_chain/rag/enhanced_by_llm_searcher.py
index 9ac2dd4ad844d20267e29d770e29f094a8ff2e23..c2381eee49d92c639263f77b06660cb2b6124cf1 100644
--- a/data_chain/rag/enhanced_by_llm_searcher.py
+++ b/data_chain/rag/enhanced_by_llm_searcher.py
@@ -46,9 +46,11 @@ class EnhancedByLLMSearcher(BaseSearcher):
                 model_name=config['MODEL_NAME'],
                 max_tokens=config['MAX_TOKENS'],
             )
+            keywords, weights = TokenTool.get_top_k_keywords_and_weights(query)
             while len(chunk_entities) < top_k and rd < max_retry:
                 rd += 1
-                sub_chunk_entities_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_keyword(kb_id, query, top_k, doc_ids, banned_ids)
+                sub_chunk_entities_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(
+                    kb_id, keywords, weights, top_k, doc_ids, banned_ids)
                 chunk_ids = [chunk_entity.id for chunk_entity in sub_chunk_entities_keyword]
                 banned_ids += chunk_ids
                 sub_chunk_entities_vector = []
diff --git a/data_chain/rag/keyword_and_vector_searcher.py b/data_chain/rag/keyword_and_vector_searcher.py
index 0237eaa61e30fb2168a19df12874e4ee090d0d77..7d535e252ecc64c0da712c7cda796e8da219f777 100644
--- a/data_chain/rag/keyword_and_vector_searcher.py
+++ b/data_chain/rag/keyword_and_vector_searcher.py
@@ -30,21 +30,23 @@ class KeywordVectorSearcher(BaseSearcher):
         """
         vector = await Embedding.vectorize_embedding(query)
         try:
-            query_filtered = TokenTool.filter_stopwords(query)
-            keywords, weights = TokenTool.get_top_k_keywords_and_weights(query_filtered)
+            chunk_entities_get_by_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_keyword(
+                kb_id, query, max(top_k//3, 1), doc_ids, banned_ids)
+            banned_ids += [chunk_entity.id for chunk_entity in chunk_entities_get_by_keyword]
+            keywords, weights = TokenTool.get_top_k_keywords_and_weights(query)
             logging.error(f"[KeywordVectorSearcher] keywords: {keywords}, weights: {weights}")
-            chunk_entities_get_by_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, keywords, weights, top_k//2, doc_ids, banned_ids)
-            chunk_ids = [chunk_entity.id for chunk_entity in chunk_entities_get_by_keyword]
+            chunk_entities_get_by_dynamic_weighted_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, keywords, weights, top_k//2, doc_ids, banned_ids)
+            banned_ids += [chunk_entity.id for chunk_entity in chunk_entities_get_by_dynamic_weighted_keyword]
             chunk_entities_get_by_vector = []
             for _ in range(3):
                 try:
-                    chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword), doc_ids, banned_ids+chunk_ids), timeout=3)
+                    chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword), doc_ids, banned_ids), timeout=3)
                     break
                 except Exception as e:
                     err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}"
                     logging.error(err)
                     continue
-            chunk_entities = chunk_entities_get_by_keyword + chunk_entities_get_by_vector
+            chunk_entities = chunk_entities_get_by_keyword + chunk_entities_get_by_dynamic_weighted_keyword + chunk_entities_get_by_vector
             for chunk_entity in chunk_entities:
                 logging.error(
                     f"[KeywordVectorSearcher] chunk_entity: {chunk_entity.id}, text: {chunk_entity.text[:100]}...")