diff --git a/data_chain/apps/base/task/worker/generate_dataset_worker.py b/data_chain/apps/base/task/worker/generate_dataset_worker.py
index 552cdac8b4ca8947e0f0e53dcbe9edafa40cceb1..fb39656cd92a32c2b33dd2a920719e1ba48647f7 100644
--- a/data_chain/apps/base/task/worker/generate_dataset_worker.py
+++ b/data_chain/apps/base/task/worker/generate_dataset_worker.py
@@ -135,18 +135,25 @@ class GenerateDataSetWorker(BaseWorker):
             chunk_cnt = len(chunk_index_list)
             division = data_cnt // chunk_cnt
             remainder = data_cnt % chunk_cnt
-            logging.error(f"数据集总条目 {dataset_entity.data_cnt}, 分块数量: {chunk_cnt}, 每块数据量: {division}, 余数: {remainder}")
+            logging.error(
+                f"数据集总条目 {dataset_entity.data_cnt}, 分块数量: {chunk_cnt}, 每块数据量: {division}, 余数: {remainder}")
             index = 0
             d_index = 0
             random.shuffle(doc_chunks)
             with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f:
                 prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
-            q_generate_prompt_template = prompt_dict.get('GENERATE_QUESTION_FROM_CONTENT_PROMPT', {})
-            q_generate_prompt_template = q_generate_prompt_template.get(language, '')
-            answer_generate_prompt_template = prompt_dict.get('GENERATE_ANSWER_FROM_QUESTION_AND_CONTENT_PROMPT', {})
-            answer_generate_prompt_template = answer_generate_prompt_template.get(language, '')
-            cal_qa_score_prompt_template = prompt_dict.get('CAL_QA_SCORE_PROMPT', {})
-            cal_qa_score_prompt_template = cal_qa_score_prompt_template.get(language, '')
+            q_generate_prompt_template = prompt_dict.get(
+                'GENERATE_QUESTION_FROM_CONTENT_PROMPT', {})
+            q_generate_prompt_template = q_generate_prompt_template.get(
+                language, '')
+            answer_generate_prompt_template = prompt_dict.get(
+                'GENERATE_ANSWER_FROM_QUESTION_AND_CONTENT_PROMPT', {})
+            answer_generate_prompt_template = answer_generate_prompt_template.get(
+                language, '')
+            cal_qa_score_prompt_template = prompt_dict.get(
+                'CAL_QA_SCORE_PROMPT', {})
+            cal_qa_score_prompt_template = cal_qa_score_prompt_template.get(
+                language, '')
             dataset_score = 0
             logging.error(f"{chunk_index_list}")
             exist_q_set = set()
@@ -168,20 +175,24 @@ class GenerateDataSetWorker(BaseWorker):
                         break
                     if tokens_sub > 0:
                         if l >= 0:
-                            tokens_sub -= TokenTool.get_tokens(doc_chunk.chunks[l].text)
+                            tokens_sub -= TokenTool.get_tokens(
+                                doc_chunk.chunks[l].text)
                             chunk = doc_chunk.chunks[l].text+chunk
                             l -= 1
                         else:
-                            tokens_sub += TokenTool.get_tokens(doc_chunk.chunks[r].text)
+                            tokens_sub += TokenTool.get_tokens(
+                                doc_chunk.chunks[r].text)
                             chunk += doc_chunk.chunks[r].text
                             r += 1
                     else:
                         if r < len(doc_chunk.chunks):
-                            tokens_sub += TokenTool.get_tokens(doc_chunk.chunks[r].text)
+                            tokens_sub += TokenTool.get_tokens(
+                                doc_chunk.chunks[r].text)
                             chunk += doc_chunk.chunks[r].text
                             r += 1
                         else:
-                            tokens_sub -= TokenTool.get_tokens(doc_chunk.chunks[l].text)
+                            tokens_sub -= TokenTool.get_tokens(
+                                doc_chunk.chunks[l].text)
                             chunk = doc_chunk.chunks[l].text+chunk
                             l -= 1
                 qa_cnt = division+(d_index <= remainder)
@@ -193,11 +204,12 @@
                 try:
                     sys_call = q_generate_prompt_template.format(
                         k=5*(qa_cnt-len(qs)),
-                        content=TokenTool.get_k_tokens_words_from_content(chunk, llm.max_tokens)
+                        content=TokenTool.get_k_tokens_words_from_content(
+                            chunk, llm.max_tokens)
                     )
                     usr_call = '请输出问题的列表'
                     sub_qs = await llm.nostream([], sys_call, usr_call, st_str='[', en_str=']')
-                    sub_qs = json.loads(sub_qs)
+                    sub_qs = TokenTool.loads_json_string(sub_qs)
                 except Exception as e:
                     err = f"[GenerateDataSetWorker] 生成问题失败,错误信息: {e}"
                     logging.exception(err)
@@ -214,8 +226,10 @@
                 try:
                     for q in sub_qs:
                         sys_call = answer_generate_prompt_template.format(
-                            content=TokenTool.get_k_tokens_words_from_content(chunk, llm.max_tokens//8*7),
-                            question=TokenTool.get_k_tokens_words_from_content(q, llm.max_tokens//8)
+                            content=TokenTool.get_k_tokens_words_from_content(
+                                chunk, llm.max_tokens//8*7),
+                            question=TokenTool.get_k_tokens_words_from_content(
+                                q, llm.max_tokens//8)
                         )
                         usr_call = '请输出答案'
                         sub_answer = await llm.nostream([], sys_call, usr_call)
@@ -230,9 +244,12 @@
                     try:
                         if dataset_entity.is_data_cleared:
                             sys_call = cal_qa_score_prompt_template.format(
-                                fragment=TokenTool.get_k_tokens_words_from_content(chunk, llm.max_tokens//9*4),
-                                question=TokenTool.get_k_tokens_words_from_content(q, llm.max_tokens//9),
-                                answer=TokenTool.get_k_tokens_words_from_content(answer, llm.max_tokens//9*4)
+                                fragment=TokenTool.get_k_tokens_words_from_content(
+                                    chunk, llm.max_tokens//9*4),
+                                question=TokenTool.get_k_tokens_words_from_content(
+                                    q, llm.max_tokens//9),
+                                answer=TokenTool.get_k_tokens_words_from_content(
+                                    answer, llm.max_tokens//9*4)
                             )
                             usr_call = '请输出分数'
                             score = await llm.nostream([], sys_call, usr_call)
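
The only behavioral change in this file is the switch from json.loads to TokenTool.loads_json_string (added in the token_tool.py hunks further down); everything else is line-wrapping. The effect is that an almost-JSON model reply no longer aborts question generation. A minimal sketch of the difference, where the malformed payload is an assumption about typical model output, not captured data:

    from data_chain.parser.tools.token_tool import TokenTool

    raw = "['问题一', '问题二',]"  # single quotes + trailing comma: json.loads raises
    sub_qs = TokenTool.loads_json_string(raw)
    print(sub_qs)  # ['问题一', '问题二'] once json_repair normalizes the quoting
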
diff --git a/data_chain/llm/llm.py b/data_chain/llm/llm.py
index 031af5bbde694a5d4492174f55a3a65138c5bc3a..bad100b5815dc1d85dcf46feb22eff745fd6a26c 100644
--- a/data_chain/llm/llm.py
+++ b/data_chain/llm/llm.py
@@ -26,16 +26,28 @@ class LLM:
     async def create_stream(
             self,
             message):
-        return await self._client.chat.completions.create(
-            model=self.model_name,
-            messages=message,  # type: ignore[]
-            max_completion_tokens=self.max_tokens,
-            temperature=self.temperature,
-            stream=True,
-            stream_options={"include_usage": True},
-            timeout=300,
-            extra_body={"enable_thinking": False}
-        )  # type: ignore[]
+        try:
+            return await self._client.chat.completions.create(
+                model=self.model_name,
+                messages=message,  # type: ignore[]
+                max_completion_tokens=self.max_tokens,
+                temperature=self.temperature,
+                stream=True,
+                stream_options={"include_usage": True},
+                timeout=300,
+                extra_body={"enable_thinking": False}
+            )  # type: ignore[]
+        except Exception as e:
+            warning = f"[LLM] create_stream 出现异常: {e}"
+            logger.warning(warning)
+            return await self._client.chat.completions.create(
+                model=self.model_name,
+                messages=message,  # type: ignore[]
+                max_completion_tokens=self.max_tokens,
+                temperature=self.temperature,
+                stream=True,
+                timeout=300
+            )  # type: ignore[]

     async def data_producer(self, q: asyncio.Queue, history, system_call, user_call):
         message = self.assemble_chat(history, system_call, user_call)
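
create_stream now degrades gracefully: the first attempt passes extra_body={"enable_thinking": False} (honored by OpenAI-compatible backends that expose a thinking switch), and on any exception it retries once without the backend-specific fields. A condensed sketch of the pattern outside the class; the AsyncOpenAI client and function name here are illustrative assumptions:

    from openai import AsyncOpenAI

    async def create_stream_with_fallback(client: AsyncOpenAI, model: str, messages: list):
        kwargs = dict(model=model, messages=messages, stream=True, timeout=300)
        try:
            # Servers that reject unknown request fields typically fail at
            # request time, so catching around create() is enough.
            return await client.chat.completions.create(
                **kwargs, extra_body={"enable_thinking": False})
        except Exception:
            # Retry with plain arguments only.
            return await client.chat.completions.create(**kwargs)

Note that the fallback call in the diff also drops stream_options={"include_usage": True}, so callers that read the final usage chunk should tolerate its absence.
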
diff --git a/data_chain/manager/chunk_manager.py b/data_chain/manager/chunk_manager.py
index 602eb165bbe017abd3fab9b41379039fde5c6efa..4ece74bba9378cdade6f27475cf76a1e2a78130d 100644
--- a/data_chain/manager/chunk_manager.py
+++ b/data_chain/manager/chunk_manager.py
@@ -466,12 +466,19 @@ class ChunkManager():
     @staticmethod
     async def get_top_k_chunk_by_kb_id_jieba(kb_id: uuid.UUID,
                                              query: str,  # 关键词列表改为单查询文本
-                                                top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = [],
-                                                chunk_to_type: str = None, pre_ids: list[uuid.UUID] = None) -> List[ChunkEntity]:
+                                             top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = [],
+                                             chunk_to_type: str = None, pre_ids: list[uuid.UUID] = None) -> List[ChunkEntity]:
         """根据知识库ID和关键词权重查询文档解析结果(修复NoneType报错+强制索引)"""
         try:
             keywords, weights = TokenTool.get_top_k_keywords_and_weights(
-                query, top_k=20)
+                query)
+            if len(keywords) == 0:
+                return []
+            if len(keywords) != len(weights):
+                return []
+            if len(keywords) > 50:
+                keywords = keywords[:50]
+                weights = weights[:50]
             st = datetime.now()
             async with await DataBase.get_session() as session:
                 # 1. 分词器选择(保留原逻辑)
diff --git a/data_chain/manager/document_manager.py b/data_chain/manager/document_manager.py
index ca271fbd84fa8b175d06d454294935d4191e3d73..567d5de7752a93cdb9b3e29e8267565339a5b255 100644
--- a/data_chain/manager/document_manager.py
+++ b/data_chain/manager/document_manager.py
@@ -280,6 +280,13 @@ class DocumentManager():
                                             top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = []) -> List[DocumentEntity]:
         try:
             keywords, weights = TokenTool.get_top_k_keywords_and_weights(query)
+            if len(keywords) == 0:
+                return []
+            if len(keywords) != len(weights):
+                return []
+            if len(keywords) > 50:
+                keywords = keywords[:50]
+                weights = weights[:50]
             st = datetime.now()  # 新增计时日志
             async with await DataBase.get_session() as session:
                 # 1. 分词器选择(与第一个方法保持一致)
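
Both managers now guard the jieba keyword path the same way: empty or inconsistent extraction returns [] immediately (avoiding the NoneType error the docstring mentions, and a pointless round trip to the database), and at most 50 keyword/weight pairs reach the SQL that scores them. The guard is small enough to read as a standalone helper; this refactor is a sketch only, not part of the diff:

    def normalize_keywords(keywords: list[str], weights: list[float],
                           cap: int = 50) -> tuple[list[str], list[float]]:
        # No usable terms, or keywords and weights out of sync: signal "no match".
        if not keywords or len(keywords) != len(weights):
            return [], []
        # Cap the pair count so the generated scoring SQL stays bounded.
        return keywords[:cap], weights[:cap]
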
diff --git a/data_chain/parser/tools/token_tool.py b/data_chain/parser/tools/token_tool.py
index b71ef981a38216e4fd1b0de1553cc8db3e28034e..cca2c2df846983ce908b92c25b87d5defdabd95f 100644
--- a/data_chain/parser/tools/token_tool.py
+++ b/data_chain/parser/tools/token_tool.py
@@ -7,6 +7,7 @@ import json
 import re
 import uuid
 import numpy as np
+import json_repair
 from pydantic import BaseModel, Field
 from data_chain.llm.llm import LLM
 from data_chain.embedding.embedding import Embedding
@@ -24,6 +25,31 @@ class TokenTool:
         with open(stop_words_path, 'r', encoding='utf-8') as f:
             stopwords = set(line.strip() for line in f)

+    @staticmethod
+    # 基于json_repair修复json字符串
+    def repair_json_string(json_string: str) -> str:
+        try:
+            repaired_json = json_repair.json_repair.repair_json(
+                json_string, ensure_ascii=False)
+            return repaired_json
+        except Exception as e:
+            err = f"[TokenTool] 修复json字符串失败 {e}"
+            logging.exception("[TokenTool] %s", err)
+            return json_string
+
+    @staticmethod
+    def loads_json_string(json_string: str):
+        """
+        加载json字符串,若失败则尝试修复
+        """
+        try:
+            return json.loads(json_string)
+        except json.JSONDecodeError as e:
+            warning_msg = f"[TokenTool] 解析JSON字符串失败,尝试修复: {e}"
+            logging.warning(warning_msg)
+            repaired_json_string = TokenTool.repair_json_string(json_string)
+            return json.loads(repaired_json_string)
+
     @staticmethod
     def filter_stopwords(content: str) -> str:
         """
@@ -316,6 +342,33 @@ class TokenTool:
             err = f"[TokenTool] 获取标题失败 {e}"
             logging.exception("[TokenTool] %s", err)

+    @staticmethod
+    def fullwidth_to_halfwidth(s: str) -> str:
+        result = ""
+        for char in s:
+            code_point = ord(char)
+            # 全角空格
+            if code_point == 0x3000:
+                code_point = 0x0020
+            # 其他全角字符
+            elif 0xFF01 <= code_point <= 0xFF5E:
+                code_point -= 0xFEE0
+            result += chr(code_point)
+        return result
+
+    @staticmethod
+    def extract_number_from_string(s: str) -> float:
+        try:
+            match = re.search(r"[-+]?\d*\.\d+|\d+", s)
+            if match:
+                return float(match.group())
+            else:
+                return -1
+        except Exception as e:
+            err = f"[TokenTool] 提取数字失败 {e}"
+            logging.exception("[TokenTool] %s", err)
+            return -1
+
     @staticmethod
     async def cal_recall(answer: str, bac_info: str, llm: LLM, language: str) -> float:
         """
@@ -336,9 +389,15 @@ class TokenTool:
                 bac_info, llm.max_tokens-llm.max_tokens//8)
             prompt = prompt_template.format(fragment=bac_info, answer=answer)
             sys_call = prompt
-            user_call = '请输出相似度'
+            if language == 'en':
+                user_call = 'Please output the recall score, do not output other content'
+            else:
+                user_call = '请输出分数,不要输出其他内容'
             similarity = await llm.nostream([], sys_call, user_call)
-            return eval(similarity)
+            similarity = TokenTool.fullwidth_to_halfwidth(similarity)
+            similarity = similarity.strip()
+            similarity = TokenTool.extract_number_from_string(similarity)
+            return similarity
         except Exception as e:
             err = f"[TokenTool] 计算recall失败 {e}"
             logging.exception("[TokenTool] %s", err)
@@ -487,7 +546,7 @@ class TokenTool:
             user_call = '请结合文本输出问题列表'
             question_vector = await Embedding.vectorize_embedding(question)
             qs = await llm.nostream([], sys_call, user_call)
-            qs = json.loads(qs)
+            qs = TokenTool.loads_json_string(qs)
             if len(qs) == 0:
                 return 0
             score = 0
diff --git a/requirements.txt b/requirements.txt
index f97dae85049f6f64d8260f48fc241e046838ac26..e6155d58a5cd0d82bb6e14cf09b99d21627331ba 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -47,4 +47,5 @@ xlrd==2.0.1
 py-cpuinfo==9.0.0
 opengauss-sqlalchemy==2.4.0
 #marker-pdf==1.8.0
-motor==3.7.1
\ No newline at end of file
+motor==3.7.1
+json_repair==0.52.5
\ No newline at end of file
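
cal_recall no longer eval()s the model reply, which was both fragile and unsafe; the reply is normalized to halfwidth, stripped, and mined for the first number. A quick illustration with the two helpers this diff adds; the sample replies are assumptions about typical model output, not captured data:

    from data_chain.parser.tools.token_tool import TokenTool

    for reply in ["0.85", "分数:0.85", "８５分"]:
        text = TokenTool.fullwidth_to_halfwidth(reply).strip()
        print(TokenTool.extract_number_from_string(text))
    # -> 0.85, then 0.85, then 85.0

Unlike eval(), a reply with no digits now yields -1 instead of raising, so the except branch in cal_recall should be reached far less often.
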