diff --git a/data_chain/apps/app.py b/data_chain/apps/app.py
index cb4bf7e5439c023c0f2c87d953d0cadc8c0b2689..30d7f7d7b97b0710e27ddba522efc3d1e8141235 100644
--- a/data_chain/apps/app.py
+++ b/data_chain/apps/app.py
@@ -53,6 +53,7 @@ from data_chain.rag import (
     base_searcher,
     keyword_searcher,
     vector_searcher,
+    dynamic_weighted_keyword_and_vector_searcher,
     keyword_and_vector_searcher,
     doc2chunk_searcher,
     doc2chunk_bfs_searcher,
diff --git a/data_chain/apps/base/task/worker/parse_document_worker.py b/data_chain/apps/base/task/worker/parse_document_worker.py
index a1b568a504089ddeb0569215119c6a90de8590d5..50a5996477379dce6c034bae893cd6e3f04c088e 100644
--- a/data_chain/apps/base/task/worker/parse_document_worker.py
+++ b/data_chain/apps/base/task/worker/parse_document_worker.py
@@ -417,8 +417,7 @@ class ParseDocumentWorker(BaseWorker):
             if llm is not None:
                 abstract = await TokenTool.get_abstract_by_llm(abstract, llm)
             else:
-                keywords = TokenTool.get_top_k_keywords(abstract, 20)
-                abstract = ' '.join(keywords)
+                abstract = abstract[:128]
             abstract_vector = await Embedding.vectorize_embedding(abstract)
             await DocumentManager.update_document_by_doc_id(
                 doc_id,
diff --git a/data_chain/apps/service/chunk_service.py b/data_chain/apps/service/chunk_service.py
index e49634649844ebcd1681e0088b42e5e0056b521b..d1b706461a3e191cdfdb9b9403ef0fdc7ba5cecd 100644
--- a/data_chain/apps/service/chunk_service.py
+++ b/data_chain/apps/service/chunk_service.py
@@ -137,8 +137,14 @@
                 chunk.text = TokenTool.compress_tokens(chunk.text)
                 dc = DocChunk(docId=chunk_entity.doc_id, docName=chunk_entity.doc_name, chunks=[chunk])
                 search_chunk_msg.doc_chunks.append(dc)
+        doc_entities = await DocumentManager.list_document_by_doc_ids(
+            [doc_chunk.doc_id for doc_chunk in search_chunk_msg.doc_chunks])
+        doc_map = {doc_entity.id: doc_entity for doc_entity in doc_entities}
         for doc_chunk in search_chunk_msg.doc_chunks:
-            doc_chunk.doc_link = await DocumentService.generate_doc_download_url(doc_chunk.doc_id)
+            doc_entity = doc_map.get(doc_chunk.doc_id)
+            doc_chunk.doc_abstract = doc_entity.abstract if doc_entity else ""
+            doc_chunk.doc_extension = doc_entity.extension if doc_entity else ""
+            doc_chunk.doc_size = doc_entity.size if doc_entity else 0
         return search_chunk_msg

     async def update_chunk_by_id(chunk_id: uuid.UUID, req: UpdateChunkRequest) -> uuid.UUID:
diff --git a/data_chain/apps/service/document_service.py b/data_chain/apps/service/document_service.py
index 2c43a07aa108a733b107740985a448beb2882c41..dfe65b9d1ad5a972f7da38ec682e2ff86f70534d 100644
--- a/data_chain/apps/service/document_service.py
+++ b/data_chain/apps/service/document_service.py
@@ -39,6 +39,8 @@
         """验证用户对文档的操作权限"""
         try:
             doc_entity = await DocumentManager.get_document_by_doc_id(doc_id)
+            if doc_entity is not None and doc_entity.kb_id == DEFAULT_KNOWLEDGE_BASE_ID:
+                return True
             if doc_entity is None:
                 err = f"文档不存在, 文档ID: {doc_id}"
                 logging.error("[DocumentService] %s", err)
diff --git a/data_chain/apps/service/knwoledge_base_service.py b/data_chain/apps/service/knwoledge_base_service.py
index 8dcc3daf20950ae847218a79a6764a0a06749d0f..17b445dc2a603bdeb287ee48e370a57da9a66a0e 100644
--- a/data_chain/apps/service/knwoledge_base_service.py
+++ b/data_chain/apps/service/knwoledge_base_service.py
@@ -23,7 +23,7 @@ from data_chain.entities.response_data import (
 from data_chain.apps.base.zip_handler import ZipHandler
 from data_chain.apps.service.task_queue_service import TaskQueueService
 from data_chain.entities.enum import Tokenizer, ParseMethod, TeamType, TeamStatus, KnowledgeBaseStatus, TaskType
-from data_chain.entities.common import DEFAULT_DOC_TYPE_ID, default_roles, IMPORT_KB_PATH_IN_OS, EXPORT_KB_PATH_IN_MINIO, IMPORT_KB_PATH_IN_MINIO
+from data_chain.entities.common import DEFAULT_KNOWLEDGE_BASE_ID, DEFAULT_DOC_TYPE_ID, default_roles, IMPORT_KB_PATH_IN_OS, EXPORT_KB_PATH_IN_MINIO, IMPORT_KB_PATH_IN_MINIO
 from data_chain.stores.database.database import TeamEntity, KnowledgeBaseEntity, DocumentTypeEntity
 from data_chain.stores.minio.minio import MinIO
 from data_chain.apps.base.convertor import Convertor
@@ -46,6 +46,8 @@ class KnowledgeBaseService:
     async def validate_user_action_to_knowledge_base(
             user_sub: str, kb_id: uuid.UUID, action: str) -> bool:
         """验证用户在知识库中的操作权限"""
+        if kb_id == DEFAULT_KNOWLEDGE_BASE_ID:
+            return True
         try:
             kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id)
             if kb_entity is None:
diff --git a/data_chain/entities/enum.py b/data_chain/entities/enum.py
index 5866976ef61ba05ffdd9b50f13d323bcac93e40f..90d7c702f7be21caca42e5e8f2b18c2b2bb2441b 100644
--- a/data_chain/entities/enum.py
+++ b/data_chain/entities/enum.py
@@ -158,6 +158,7 @@ class SearchMethod(str, Enum):
     """搜索方法"""
     KEYWORD = "keyword"
    VECTOR = "vector"
+    DYNAMIC_WEIGHTED_KEYWORD_AND_VECTOR = "dynamic_weighted_keyword_and_vector"
     KEYWORD_AND_VECTOR = "keyword_and_vector"
     DOC2CHUNK = "doc2chunk"
     DOC2CHUNK_BFS = "doc2chunk_bfs"
diff --git a/data_chain/entities/response_data.py b/data_chain/entities/response_data.py
index 4b9c478c0e2f58d1a7f0c6b47b85e83c5f5dd9f6..5a3bf03b855be861587efb43f5d144a5007d5895 100644
--- a/data_chain/entities/response_data.py
+++ b/data_chain/entities/response_data.py
@@ -326,7 +326,9 @@ class DocChunk(BaseModel):
     """Post /chunk/search 数据结构"""
     doc_id: uuid.UUID = Field(description="文档ID", alias="docId")
     doc_name: str = Field(description="文档名称", alias="docName")
-    doc_link: str = Field(default="", description="文档链接", alias="docLink")
+    doc_abstract: str = Field(default="", description="文档摘要", alias="docAbstract")
+    doc_extension: str = Field(default="", description="文档扩展名", alias="docExtension")
+    doc_size: int = Field(default=0, description="文档大小,单位是KB", alias="docSize")
     chunks: list[Chunk] = Field(default=[], description="分片列表", alias="chunks")


diff --git a/data_chain/manager/chunk_manager.py b/data_chain/manager/chunk_manager.py
index a45c9fb3a552e9ba91e48a99ab1cd2c85cdda407..277abc3e41a18c3a8ea824c4808893ac8bcf36d0 100644
--- a/data_chain/manager/chunk_manager.py
+++ b/data_chain/manager/chunk_manager.py
@@ -224,7 +224,7 @@ class ChunkManager():
     async def get_top_k_chunk_by_kb_id_keyword(
             kb_id: uuid.UUID, query: str, top_k: int, doc_ids: list[uuid.UUID] = None,
             banned_ids: list[uuid.UUID] = [],
-            chunk_to_type: str = None, pre_ids: list[uuid.UUID] = None) -> List[ChunkEntity]:
+            chunk_to_type: str = None, pre_ids: list[uuid.UUID] = None, is_tight: bool = True) -> List[ChunkEntity]:
         """根据知识库ID和向量查询文档解析结果"""
         try:
             async with await DataBase.get_session() as session:
@@ -243,10 +243,21 @@
                     tokenizer = 'zhparser'

                 # 计算相似度分数并选择它
-                similarity_score = func.ts_rank_cd(
-                    func.to_tsvector(tokenizer, ChunkEntity.text),
-                    func.plainto_tsquery(tokenizer, query)
-                ).label("similarity_score")
+                if is_tight:
+                    similarity_score = func.ts_rank_cd(
+                        func.to_tsvector(tokenizer, ChunkEntity.text),
+                        func.plainto_tsquery(tokenizer, query)
+                    ).label("similarity_score")
+                else:
+                    similarity_score = func.ts_rank_cd(
+                        func.to_tsvector(tokenizer, ChunkEntity.text),
+                        func.to_tsquery(
+                            func.replace(
+                                func.text(func.plainto_tsquery(tokenizer, query)),
+                                '&', '|'
+                            )
+                        )
+                    ).label("similarity_score")

                 stmt = (
                     select(ChunkEntity, similarity_score)
diff --git a/data_chain/manager/document_manager.py b/data_chain/manager/document_manager.py
index 973be2b0d5beccbaf42f42faeb37e77df324d462..55cb87a90397d1d0a6df5db6176b423311b301c2 100644
--- a/data_chain/manager/document_manager.py
+++ b/data_chain/manager/document_manager.py
@@ -97,10 +97,17 @@
                 tokenizer = 'zhparser'
             elif kb_entity.tokenizer == Tokenizer.EN.value:
                 tokenizer = 'english'
+
             similarity_score = func.ts_rank_cd(
                 func.to_tsvector(tokenizer, DocumentEntity.abstract),
-                func.plainto_tsquery(tokenizer, query)
+                func.to_tsquery(
+                    func.replace(
+                        func.text(func.plainto_tsquery(tokenizer, query)),
+                        '&', '|'
+                    )
+                )
             ).label("similarity_score")
+
             stmt = (
                 select(DocumentEntity, similarity_score)
                 .where(DocumentEntity.kb_id == kb_id)
diff --git a/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py b/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..5efe05ee080f65920d5fcdff60fe0ae80745ccc8
--- /dev/null
+++ b/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py
@@ -0,0 +1,58 @@
+import asyncio
+import uuid
+from pydantic import BaseModel, Field
+import random
+from data_chain.logger.logger import logger as logging
+from data_chain.stores.database.database import ChunkEntity
+from data_chain.parser.tools.token_tool import TokenTool
+from data_chain.manager.chunk_manager import ChunkManager
+from data_chain.rag.base_searcher import BaseSearcher
+from data_chain.embedding.embedding import Embedding
+from data_chain.entities.enum import SearchMethod
+
+
+class KeywordVectorSearcher(BaseSearcher):
+    """
+    关键词向量检索
+    """
+    name = SearchMethod.DYNAMIC_WEIGHTED_KEYWORD_AND_VECTOR.value
+
+    @staticmethod
+    async def search(
+            query: str, kb_id: uuid.UUID, top_k: int = 5, doc_ids: list[uuid.UUID] = None,
+            banned_ids: list[uuid.UUID] = []
+    ) -> list[ChunkEntity]:
+        """
+        向量检索
+        :param query: 查询
+        :param top_k: 返回的结果数量
+        :return: 检索结果
+        """
+        vector = await Embedding.vectorize_embedding(query)
+        try:
+            chunk_entities_get_by_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_keyword(
+                kb_id, query, max(top_k//3, 1), doc_ids, banned_ids)
+            banned_ids += [chunk_entity.id for chunk_entity in chunk_entities_get_by_keyword]
+            keywords, weights = TokenTool.get_top_k_keywords_and_weights(query)
+            logging.error(f"[KeywordVectorSearcher] keywords: {keywords}, weights: {weights}")
+            chunk_entities_get_by_dynamic_weighted_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, keywords, weights, top_k//2, doc_ids, banned_ids)
+            banned_ids += [chunk_entity.id for chunk_entity in chunk_entities_get_by_dynamic_weighted_keyword]
+            chunk_entities_get_by_vector = []
+            for _ in range(3):
+                try:
+                    import time
+                    start_time = time.time()
+                    chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword)-len(chunk_entities_get_by_dynamic_weighted_keyword), doc_ids, banned_ids), timeout=3)
+                    end_time = time.time()
+                    logging.info(f"[KeywordVectorSearcher] 向量检索成功完成,耗时: {end_time - start_time:.2f}秒")
+                    break
+                except Exception as e:
+                    err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}"
+                    logging.error(err)
+                    continue
+            chunk_entities = chunk_entities_get_by_keyword + chunk_entities_get_by_dynamic_weighted_keyword + chunk_entities_get_by_vector
+        except Exception as e:
+            err = f"[KeywordVectorSearcher] 关键词向量检索失败,error: {e}"
+            logging.exception(err)
+            return []
+        return chunk_entities
diff --git a/data_chain/rag/keyword_and_vector_searcher.py b/data_chain/rag/keyword_and_vector_searcher.py
index c3bbe937a70e6c5aa0f3ded47dd8dc22bc9b742e..86b3b4f5cfca9065c6318caa45ab39c2ae517f74 100644
--- a/data_chain/rag/keyword_and_vector_searcher.py
+++ b/data_chain/rag/keyword_and_vector_searcher.py
@@ -31,18 +31,17 @@ class KeywordVectorSearcher(BaseSearcher):
         vector = await Embedding.vectorize_embedding(query)
         try:
             chunk_entities_get_by_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_keyword(
-                kb_id, query, max(top_k//3, 1), doc_ids, banned_ids)
+                kb_id, query, max(top_k//3, 1), doc_ids, banned_ids, is_tight=True)
+            banned_ids += [chunk_entity.id for chunk_entity in chunk_entities_get_by_keyword]
+            chunk_entities_get_by_keyword += await ChunkManager.get_top_k_chunk_by_kb_id_keyword(
+                kb_id, query, max(top_k//2, 1), doc_ids, banned_ids, is_tight=False)
             banned_ids += [chunk_entity.id for chunk_entity in chunk_entities_get_by_keyword]
-            keywords, weights = TokenTool.get_top_k_keywords_and_weights(query)
-            logging.error(f"[KeywordVectorSearcher] keywords: {keywords}, weights: {weights}")
-            chunk_entities_get_by_dynamic_weighted_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, keywords, weights, top_k//2, doc_ids, banned_ids)
-            banned_ids += [chunk_entity.id for chunk_entity in chunk_entities_get_by_dynamic_weighted_keyword]
             chunk_entities_get_by_vector = []
             for _ in range(3):
                 try:
                     import time
                     start_time = time.time()
-                    chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword)-len(chunk_entities_get_by_dynamic_weighted_keyword), doc_ids, banned_ids), timeout=3)
+                    chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword), doc_ids, banned_ids), timeout=3)
                     end_time = time.time()
                     logging.info(f"[KeywordVectorSearcher] 向量检索成功完成,耗时: {end_time - start_time:.2f}秒")
                     break
@@ -50,7 +49,7 @@ class KeywordVectorSearcher(BaseSearcher):
                     err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}"
                     logging.error(err)
                     continue
-            chunk_entities = chunk_entities_get_by_keyword + chunk_entities_get_by_dynamic_weighted_keyword + chunk_entities_get_by_vector
+            chunk_entities = chunk_entities_get_by_keyword + chunk_entities_get_by_vector
         except Exception as e:
             err = f"[KeywordVectorSearcher] 关键词向量检索失败,error: {e}"
             logging.exception(err)
diff --git a/data_chain/rag/keyword_searcher.py b/data_chain/rag/keyword_searcher.py
index a87cedc2634dd45d7cafed6c9f7ebeb6711d82ba..d622055fe7915b436aea67603f4e7200c6d43cb5 100644
--- a/data_chain/rag/keyword_searcher.py
+++ b/data_chain/rag/keyword_searcher.py
@@ -25,7 +25,9 @@ class KeyWordSearcher(BaseSearcher):
         :return: 检索结果
         """
         try:
-            chunk_entities = await ChunkManager.get_top_k_chunk_by_kb_id_keyword(kb_id, query, top_k, doc_ids, banned_ids)
+            chunk_entities = await ChunkManager.get_top_k_chunk_by_kb_id_keyword(kb_id, query, top_k//3, doc_ids, banned_ids)
+            banned_ids += [chunk_entity.id for chunk_entity in chunk_entities]
+            chunk_entities += await ChunkManager.get_top_k_chunk_by_kb_id_keyword(kb_id, query, top_k-len(chunk_entities), doc_ids, banned_ids, is_tight=False)
         except Exception as e:
             err = f"[KeyWordSearcher] 关键词检索失败,error: {e}"
             logging.exception(err)
diff --git a/test/test_qa.py b/test/test_qa.py
index 78145504a4178dabb894cf0bc35100a7e6753e02..ab3408d5aa1b1ee5aa49c16b2c5509ec332a732a 100644
--- a/test/test_qa.py
+++ b/test/test_qa.py
@@ -361,6 +361,7 @@ class QAScore():
             prompt = prompt_dict['SCORE_QA']
             llm_score_dict = await self.chat_with_llm(llm, prompt, QA['question'], QA['text'],
                                                       QA['witChainD_source'], QA['answer'], QA['witChainD_answer'])
             print(llm_score_dict)
+            QA['context_relevancy'] = llm_score_dict['context_relevancy']
             QA['context_recall'] = llm_score_dict['context_recall']
             QA['faithfulness'] = llm_score_dict['faithfulness']
@@ -406,15 +407,21 @@
         - qa_pairs: list[dict]
         """

+        required_metrics = {
+            "context_relevancy",
+            "context_recall",
+            "faithfulness",
+            "answer_relevancy",
+        }
         for i in range(5):
             try:
-                user_call = '''请对答案打分,并以下面形式返回结果{
+                user_call = """请对答案打分,并以下面形式返回结果{
 \"context_relevancy\": 分数,
 \"context_recall\": 分数,
 \"faithfulness\": 分数,
 \"answer_relevancy\": 分数
 }
-'''
+注意:属性名必须使用双引号,分数为数字,保留两位小数。"""
                 prompt = prompt.format(question=question, meta_chunk=meta_chunk,
                                        chunk=chunk, answer=answer, answer_text=answer_text)
                 # print(prompt)
@@ -423,10 +430,19 @@
                 en = score_dict.rfind('}')
                 if st != -1 and en != -1:
                     score_dict = score_dict[st:en+1]
-                print(score_dict)
+                # print(score_dict)
                 score_dict = json.loads(score_dict)
                 # 提取问题、答案段落对的list,字符串格式为["问题","答案","段落对"]
                 # print(score)
+                present_metrics = set(score_dict.keys())
+                missing_metrics = required_metrics - present_metrics
+                if missing_metrics:
+                    missing = ", ".join(missing_metrics)
+                    print(f"评分结果缺少必要指标: {missing}")
+                for metric in required_metrics:
+                    if metric not in score_dict:
+                        score_dict[metric] = 0.00
+                print(score_dict)
                 return score_dict
             except Exception as e:
                 continue
@@ -579,11 +595,11 @@
     print(f"获取到{len(t_QAs)}个文档")
     for item in t_QAs[0]:
         single_item = {
-            "question": item['问题'],
-            "answer": item['标准答案'],
-            "witChainD_answer": item['llm的回答'],
-            "text": item['原始片段'],
-            "witChainD_source": item['检索片段'],
+            "question": item["问题"],
+            "answer": item["标准答案"],
+            "witChainD_answer": item["llm的回答"],
+            "text": item["原始片段"],
+            "witChainD_source": item["检索片段"],
         }
         # print(single_item)
         ttt_QAs = asyncio.run(QAScore().get_score(single_item))
@@ -638,7 +654,7 @@
             }
         else:
             ReOrderedQA = {
-                '领域': str(QA['type']),
+                # '领域': str(QA['type']),
                 '问题': str(QA['question']),
                 '标准答案': str(QA['answer']),
                 'llm的回答': str(QA['witChainD_answer']),
@@ -682,7 +698,6 @@
         )
         avg[metric] = avg_time_cost

-    print(f"生成测试结果: {avg}")
     excel_path = current_dir / 'answer.xlsx'

     with pd.ExcelWriter(excel_path, engine='xlsxwriter') as writer:
@@ -696,9 +711,9 @@
             **{k: v for k, v in avg.items() if k != "time_cost"},
             **{f"time_cost_{k}": v for k, v in filtered_time_cost.items()},
         }

+        print(f"写入测试结果:{flat_avg}")
         avg_df = pd.DataFrame([flat_avg])
         avg_df.to_excel(writer, sheet_name="测试结果", index=False)

     print(f'测试样例和结果已输出到{excel_path}')
-
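Note on the loose keyword match introduced in chunk_manager.py and document_manager.py above: plainto_tsquery joins query terms with '&' (AND), so the patch casts the generated tsquery to text, swaps '&' for '|', and feeds it back through to_tsquery, letting ts_rank_cd score rows that contain any of the terms. A minimal standalone sketch of the same expression, assuming SQLAlchemy's generic func accessor and a PostgreSQL full-text setup; the helper name loose_keyword_rank is illustrative and not part of the patch:

from sqlalchemy import func


def loose_keyword_rank(tokenizer: str, column, query: str):
    """Rank with plainto_tsquery's implicit AND ('&') relaxed to OR ('|')."""
    # plainto_tsquery('zhparser', '向量 检索') produces "'向量' & '检索'";
    # replacing '&' with '|' keeps a non-zero score for rows matching any term.
    loose_query = func.to_tsquery(
        func.replace(func.text(func.plainto_tsquery(tokenizer, query)), '&', '|')
    )
    return func.ts_rank_cd(func.to_tsvector(tokenizer, column), loose_query).label("similarity_score")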