From 145f08b06cdcab10755405e394f7f3ff8751d11c Mon Sep 17 00:00:00 2001
From: zxstty
Date: Fri, 7 Nov 2025 15:44:18 +0800
Subject: [PATCH] Improve fullwidth-character handling, add management of
 extra LLM request parameters, and add a JSON repair mechanism
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../task/worker/generate_dataset_worker.py | 53 ++++++++++-----
 data_chain/llm/llm.py                      | 32 ++++++---
 data_chain/manager/chunk_manager.py        | 13 +++-
 data_chain/manager/document_manager.py     |  7 ++
 data_chain/parser/tools/token_tool.py      | 65 ++++++++++++++++++-
 requirements.txt                           |  3 +-
 6 files changed, 138 insertions(+), 35 deletions(-)

diff --git a/data_chain/apps/base/task/worker/generate_dataset_worker.py b/data_chain/apps/base/task/worker/generate_dataset_worker.py
index 552cdac8..fb39656c 100644
--- a/data_chain/apps/base/task/worker/generate_dataset_worker.py
+++ b/data_chain/apps/base/task/worker/generate_dataset_worker.py
@@ -135,18 +135,25 @@ class GenerateDataSetWorker(BaseWorker):
         chunk_cnt = len(chunk_index_list)
         division = data_cnt // chunk_cnt
         remainder = data_cnt % chunk_cnt
-        logging.error(f"数据集总条目 {dataset_entity.data_cnt}, 分块数量: {chunk_cnt}, 每块数据量: {division}, 余数: {remainder}")
+        logging.error(
+            f"数据集总条目 {dataset_entity.data_cnt}, 分块数量: {chunk_cnt}, 每块数据量: {division}, 余数: {remainder}")
         index = 0
         d_index = 0
         random.shuffle(doc_chunks)
         with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f:
             prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
-        q_generate_prompt_template = prompt_dict.get('GENERATE_QUESTION_FROM_CONTENT_PROMPT', {})
-        q_generate_prompt_template = q_generate_prompt_template.get(language, '')
-        answer_generate_prompt_template = prompt_dict.get('GENERATE_ANSWER_FROM_QUESTION_AND_CONTENT_PROMPT', {})
-        answer_generate_prompt_template = answer_generate_prompt_template.get(language, '')
-        cal_qa_score_prompt_template = prompt_dict.get('CAL_QA_SCORE_PROMPT', {})
-        cal_qa_score_prompt_template = cal_qa_score_prompt_template.get(language, '')
+        q_generate_prompt_template = prompt_dict.get(
+            'GENERATE_QUESTION_FROM_CONTENT_PROMPT', {})
+        q_generate_prompt_template = q_generate_prompt_template.get(
+            language, '')
+        answer_generate_prompt_template = prompt_dict.get(
+            'GENERATE_ANSWER_FROM_QUESTION_AND_CONTENT_PROMPT', {})
+        answer_generate_prompt_template = answer_generate_prompt_template.get(
+            language, '')
+        cal_qa_score_prompt_template = prompt_dict.get(
+            'CAL_QA_SCORE_PROMPT', {})
+        cal_qa_score_prompt_template = cal_qa_score_prompt_template.get(
+            language, '')
         dataset_score = 0
         logging.error(f"{chunk_index_list}")
         exist_q_set = set()
@@ -168,20 +175,24 @@ class GenerateDataSetWorker(BaseWorker):
                     break
                 if tokens_sub > 0:
                     if l >= 0:
-                        tokens_sub -= TokenTool.get_tokens(doc_chunk.chunks[l].text)
+                        tokens_sub -= TokenTool.get_tokens(
+                            doc_chunk.chunks[l].text)
                         chunk = doc_chunk.chunks[l].text+chunk
                         l -= 1
                     else:
-                        tokens_sub += TokenTool.get_tokens(doc_chunk.chunks[r].text)
+                        tokens_sub += TokenTool.get_tokens(
+                            doc_chunk.chunks[r].text)
                         chunk += doc_chunk.chunks[r].text
                         r += 1
                 else:
                     if r < len(doc_chunk.chunks):
-                        tokens_sub += TokenTool.get_tokens(doc_chunk.chunks[r].text)
+                        tokens_sub += TokenTool.get_tokens(
+                            doc_chunk.chunks[r].text)
                         chunk += doc_chunk.chunks[r].text
                         r += 1
                     else:
-                        tokens_sub -= TokenTool.get_tokens(doc_chunk.chunks[l].text)
+                        tokens_sub -= TokenTool.get_tokens(
+                            doc_chunk.chunks[l].text)
                         chunk = doc_chunk.chunks[l].text+chunk
                         l -= 1
                 qa_cnt = division+(d_index <= remainder)
@@ -193,11 +204,12 @@ class GenerateDataSetWorker(BaseWorker):
                 try:
                     sys_call = q_generate_prompt_template.format(
                         k=5*(qa_cnt-len(qs)),
-                        content=TokenTool.get_k_tokens_words_from_content(chunk, llm.max_tokens)
+                        content=TokenTool.get_k_tokens_words_from_content(
+                            chunk, llm.max_tokens)
                     )
                     usr_call = '请输出问题的列表'
                     sub_qs = await llm.nostream([], sys_call, usr_call, st_str='[', en_str=']')
-                    sub_qs = json.loads(sub_qs)
+                    sub_qs = TokenTool.loads_json_string(sub_qs)
                 except Exception as e:
                     err = f"[GenerateDataSetWorker] 生成问题失败,错误信息: {e}"
                     logging.exception(err)
@@ -214,8 +226,10 @@ class GenerateDataSetWorker(BaseWorker):
                 try:
                     for q in sub_qs:
                         sys_call = answer_generate_prompt_template.format(
-                            content=TokenTool.get_k_tokens_words_from_content(chunk, llm.max_tokens//8*7),
-                            question=TokenTool.get_k_tokens_words_from_content(q, llm.max_tokens//8)
+                            content=TokenTool.get_k_tokens_words_from_content(
+                                chunk, llm.max_tokens//8*7),
+                            question=TokenTool.get_k_tokens_words_from_content(
+                                q, llm.max_tokens//8)
                         )
                         usr_call = '请输出答案'
                         sub_answer = await llm.nostream([], sys_call, usr_call)
@@ -230,9 +244,12 @@ class GenerateDataSetWorker(BaseWorker):
                 try:
                     if dataset_entity.is_data_cleared:
                         sys_call = cal_qa_score_prompt_template.format(
-                            fragment=TokenTool.get_k_tokens_words_from_content(chunk, llm.max_tokens//9*4),
-                            question=TokenTool.get_k_tokens_words_from_content(q, llm.max_tokens//9),
-                            answer=TokenTool.get_k_tokens_words_from_content(answer, llm.max_tokens//9*4)
+                            fragment=TokenTool.get_k_tokens_words_from_content(
+                                chunk, llm.max_tokens//9*4),
+                            question=TokenTool.get_k_tokens_words_from_content(
+                                q, llm.max_tokens//9),
+                            answer=TokenTool.get_k_tokens_words_from_content(
+                                answer, llm.max_tokens//9*4)
                         )
                         usr_call = '请输出分数'
                         score = await llm.nostream([], sys_call, usr_call)
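Note on the json.loads -> TokenTool.loads_json_string change above: model-generated
question lists frequently come back with trailing commas, single quotes, or a
truncated closing bracket, which strict json.loads rejects. A minimal standalone
sketch of the fallback, mirroring TokenTool.loads_json_string; the broken sample
string is invented for illustration:

    import json

    import json_repair

    def loads_json_string(json_string: str):
        """Parse strictly first; fall back to json_repair on failure."""
        try:
            return json.loads(json_string)
        except json.JSONDecodeError:
            # repair_json tolerates trailing commas, single quotes,
            # unclosed brackets and similar LLM formatting slips.
            repaired = json_repair.repair_json(json_string, ensure_ascii=False)
            return json.loads(repaired)

    broken = '["第一个问题", "第二个问题",'  # truncated array from a model
    print(loads_json_string(broken))        # -> ['第一个问题', '第二个问题']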
diff --git a/data_chain/llm/llm.py b/data_chain/llm/llm.py
index 031af5bb..bad100b5 100644
--- a/data_chain/llm/llm.py
+++ b/data_chain/llm/llm.py
@@ -26,16 +26,28 @@ class LLM:

     async def create_stream(
             self, message):
-        return await self._client.chat.completions.create(
-            model=self.model_name,
-            messages=message,  # type: ignore[]
-            max_completion_tokens=self.max_tokens,
-            temperature=self.temperature,
-            stream=True,
-            stream_options={"include_usage": True},
-            timeout=300,
-            extra_body={"enable_thinking": False}
-        )  # type: ignore[]
+        try:
+            return await self._client.chat.completions.create(
+                model=self.model_name,
+                messages=message,  # type: ignore[]
+                max_completion_tokens=self.max_tokens,
+                temperature=self.temperature,
+                stream=True,
+                stream_options={"include_usage": True},
+                timeout=300,
+                extra_body={"enable_thinking": False}
+            )  # type: ignore[]
+        except Exception as e:
+            warning = f"[LLM] create_stream 出现异常: {e}"
+            logger.warning(warning)
+            return await self._client.chat.completions.create(
+                model=self.model_name,
+                messages=message,  # type: ignore[]
+                max_completion_tokens=self.max_tokens,
+                temperature=self.temperature,
+                stream=True,
+                timeout=300
+            )  # type: ignore[]

     async def data_producer(self, q: asyncio.Queue, history, system_call, user_call):
         message = self.assemble_chat(history, system_call, user_call)
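Note on the create_stream change above: some OpenAI-compatible backends reject
unknown fields such as the enable_thinking flag passed via extra_body, so the
first call is retried in a degraded form without the vendor-specific options
(the fallback also drops stream_options). A minimal sketch of the same
degrade-and-retry pattern with the openai client; the base_url, api_key, and
model name are placeholders, not values from this repository:

    import asyncio

    from openai import AsyncOpenAI

    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    async def create_stream(messages, model="my-model", max_tokens=1024):
        base_kwargs = dict(
            model=model,
            messages=messages,
            max_completion_tokens=max_tokens,
            temperature=0.7,
            stream=True,
            timeout=300,
        )
        try:
            # First attempt: include the vendor-specific flag.
            return await client.chat.completions.create(
                stream_options={"include_usage": True},
                extra_body={"enable_thinking": False},
                **base_kwargs,
            )
        except Exception:
            # Retry without parameters the server may not understand.
            return await client.chat.completions.create(**base_kwargs)

    async def main():
        stream = await create_stream([{"role": "user", "content": "hi"}])
        async for chunk in stream:
            if chunk.choices and chunk.choices[0].delta.content:
                print(chunk.choices[0].delta.content, end="")

    asyncio.run(main())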
diff --git a/data_chain/manager/chunk_manager.py b/data_chain/manager/chunk_manager.py
index 602eb165..4ece74bb 100644
--- a/data_chain/manager/chunk_manager.py
+++ b/data_chain/manager/chunk_manager.py
@@ -466,12 +466,19 @@ class ChunkManager():

     @staticmethod
     async def get_top_k_chunk_by_kb_id_jieba(kb_id: uuid.UUID, query: str,  # 关键词列表改为单查询文本
-                                             top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = [],
-                                             chunk_to_type: str = None, pre_ids: list[uuid.UUID] = None) -> List[ChunkEntity]:
+                                                top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = [],
+                                                chunk_to_type: str = None, pre_ids: list[uuid.UUID] = None) -> List[ChunkEntity]:
         """根据知识库ID和关键词权重查询文档解析结果(修复NoneType报错+强制索引)"""
         try:
             keywords, weights = TokenTool.get_top_k_keywords_and_weights(
-                query, top_k=20)
+                query)
+            if len(keywords) == 0:
+                return []
+            if len(keywords) != len(weights):
+                return []
+            if len(keywords) > 50:
+                keywords = keywords[:50]
+                weights = weights[:50]
             st = datetime.now()
             async with await DataBase.get_session() as session:
                 # 1. 分词器选择(保留原逻辑)
diff --git a/data_chain/manager/document_manager.py b/data_chain/manager/document_manager.py
index ca271fbd..567d5de7 100644
--- a/data_chain/manager/document_manager.py
+++ b/data_chain/manager/document_manager.py
@@ -280,6 +280,13 @@ class DocumentManager():
                                               top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = []) -> List[DocumentEntity]:
         try:
             keywords, weights = TokenTool.get_top_k_keywords_and_weights(query)
+            if len(keywords) == 0:
+                return []
+            if len(keywords) != len(weights):
+                return []
+            if len(keywords) > 50:
+                keywords = keywords[:50]
+                weights = weights[:50]
             st = datetime.now()  # 新增计时日志
             async with await DataBase.get_session() as session:
                 # 1. 分词器选择(与第一个方法保持一致)
diff --git a/data_chain/parser/tools/token_tool.py b/data_chain/parser/tools/token_tool.py
index b71ef981..cca2c2df 100644
--- a/data_chain/parser/tools/token_tool.py
+++ b/data_chain/parser/tools/token_tool.py
@@ -7,6 +7,7 @@ import json
 import re
 import uuid
 import numpy as np
+import json_repair
 from pydantic import BaseModel, Field
 from data_chain.llm.llm import LLM
 from data_chain.embedding.embedding import Embedding
@@ -24,6 +25,31 @@ class TokenTool:
     with open(stop_words_path, 'r', encoding='utf-8') as f:
         stopwords = set(line.strip() for line in f)

+    @staticmethod
+    def repair_json_string(json_string: str) -> str:
+        """基于json_repair修复json字符串"""
+        try:
+            repaired_json = json_repair.repair_json(
+                json_string, ensure_ascii=False)
+            return repaired_json
+        except Exception as e:
+            err = f"[TokenTool] 修复json字符串失败 {e}"
+            logging.exception("[TokenTool] %s", err)
+            return json_string
+
+    @staticmethod
+    def loads_json_string(json_string: str):
+        """
+        加载json字符串,若失败则尝试修复
+        """
+        try:
+            return json.loads(json_string)
+        except json.JSONDecodeError as e:
+            warning_msg = f"[TokenTool] 解析JSON字符串失败,尝试修复: {e}"
+            logging.warning(warning_msg)
+            repaired_json_string = TokenTool.repair_json_string(json_string)
+            return json.loads(repaired_json_string)
+
     @staticmethod
     def filter_stopwords(content: str) -> str:
         """
@@ -316,6 +342,33 @@ class TokenTool:
             err = f"[TokenTool] 获取标题失败 {e}"
             logging.exception("[TokenTool] %s", err)

+    @staticmethod
+    def fullwidth_to_halfwidth(s: str) -> str:
+        result = ""
+        for char in s:
+            code_point = ord(char)
+            # 全角空格
+            if code_point == 0x3000:
+                code_point = 0x0020
+            # 其他全角字符
+            elif 0xFF01 <= code_point <= 0xFF5E:
+                code_point -= 0xFEE0
+            result += chr(code_point)
+        return result
+
+    @staticmethod
+    def extract_number_from_string(s: str) -> float:
+        try:
+            match = re.search(r"[-+]?(?:\d*\.\d+|\d+)", s)
+            if match:
+                return float(match.group())
+            else:
+                return -1
+        except Exception as e:
+            err = f"[TokenTool] 提取数字失败 {e}"
+            logging.exception("[TokenTool] %s", err)
+            return -1
+
     @staticmethod
     async def cal_recall(answer: str, bac_info: str, llm: LLM, language: str) -> float:
         """
@@ -336,9 +389,15 @@ class TokenTool:
                 bac_info, llm.max_tokens-llm.max_tokens//8)
             prompt = prompt_template.format(fragment=bac_info, answer=answer)
             sys_call = prompt
-            user_call = '请输出相似度'
+            if language == 'en':
+                user_call = 'Please output the recall score, do not output other content'
+            else:
+                user_call = '请输出分数,不要输出其他内容'
             similarity = await llm.nostream([], sys_call, user_call)
-            return eval(similarity)
+            similarity = TokenTool.fullwidth_to_halfwidth(similarity)
+            similarity = similarity.strip()
+            similarity = TokenTool.extract_number_from_string(similarity)
+            return similarity
         except Exception as e:
             err = f"[TokenTool] 计算recall失败 {e}"
             logging.exception("[TokenTool] %s", err)
@@ -487,7 +546,7 @@ class TokenTool:
             user_call = '请结合文本输出问题列表'
             question_vector = await Embedding.vectorize_embedding(question)
             qs = await llm.nostream([], sys_call, user_call)
-            qs = json.loads(qs)
+            qs = TokenTool.loads_json_string(qs)
             if len(qs) == 0:
                 return 0
             score = 0
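Note on the cal_recall change above: instead of calling eval() on raw model
output, the score now goes through fullwidth normalization and regex-based
number extraction, which also removes the code-execution risk of eval() on
untrusted text. A standalone sketch of that chain; the sample reply string is
invented:

    import re

    def fullwidth_to_halfwidth(s: str) -> str:
        """Map the ideographic space (U+3000) and fullwidth ASCII
        variants (U+FF01..U+FF5E) to their halfwidth counterparts."""
        out = []
        for ch in s:
            cp = ord(ch)
            if cp == 0x3000:
                cp = 0x0020
            elif 0xFF01 <= cp <= 0xFF5E:
                cp -= 0xFEE0
            out.append(chr(cp))
        return "".join(out)

    def extract_number_from_string(s: str) -> float:
        """Return the first signed int/float found in s, or -1 if none."""
        match = re.search(r"[-+]?(?:\d*\.\d+|\d+)", s)
        return float(match.group()) if match else -1

    raw = "分数:85。"  # fullwidth digits and punctuation from a model
    print(extract_number_from_string(fullwidth_to_halfwidth(raw)))  # 85.0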
diff --git a/requirements.txt b/requirements.txt
index f97dae85..e6155d58 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -47,4 +47,5 @@ xlrd==2.0.1
 py-cpuinfo==9.0.0
 opengauss-sqlalchemy==2.4.0
 #marker-pdf==1.8.0
-motor==3.7.1
\ No newline at end of file
+motor==3.7.1
+json_repair==0.52.5
\ No newline at end of file
-- 
Gitee
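Note on the keyword guards added in chunk_manager.py and document_manager.py:
TokenTool.get_top_k_keywords_and_weights itself is not part of this patch, so
the sketch below only assumes a plausible jieba TF-IDF implementation, to
illustrate why the guards (empty result, keyword/weight length mismatch, cap
at 50) matter before the keywords reach the SQL ranking expression:

    import jieba.analyse

    def get_top_k_keywords_and_weights(query: str, top_k: int = 20):
        # jieba returns (word, weight) pairs sorted by TF-IDF weight.
        pairs = jieba.analyse.extract_tags(query, topK=top_k, withWeight=True)
        keywords = [word for word, _ in pairs]
        weights = [weight for _, weight in pairs]
        return keywords, weights

    def guarded_keywords(query: str, cap: int = 50):
        keywords, weights = get_top_k_keywords_and_weights(query)
        # An empty or mismatched result would otherwise surface later as
        # NoneType/zip errors inside the ranking query.
        if not keywords or len(keywords) != len(weights):
            return [], []
        return keywords[:cap], weights[:cap]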