diff --git a/Dockerfile-base b/Dockerfile-base
index 7d2277e9576379d599d9d1c4dfaddc44de5aa06b..8222781d900b063371cb8413c696ce050c444213 100644
--- a/Dockerfile-base
+++ b/Dockerfile-base
@@ -1,4 +1,4 @@
-FROM openeuler/openeuler:22.03-lts-sp1
+FROM openeuler/openeuler:24.03-lts-sp2
 
 # Set environment variables
 ENV PATH /rag-service/.local/bin:$PATH
@@ -13,10 +13,6 @@ RUN sed -i 's|http://repo.openeuler.org/|https://mirrors.huaweicloud.com/openeul
 # Create the /rag-service directory and set its permissions
 RUN mkdir -p /rag-service && chown -R 1001:1001 /rag-service
 
-COPY --chown=1001:1001 openGauss-sqlalchemy.tar.gz .
-RUN tar -xvf openGauss-sqlalchemy.tar.gz && \
-    cd openGauss-sqlalchemy && \
-    pip3 install . --index-url https://pypi.tuna.tsinghua.edu.cn/simple
 
 # Switch to the eulercopilot user
 USER eulercopilot
diff --git a/data_chain/apps/app.py b/data_chain/apps/app.py
index f5ea0b0e0560964e5146b167a10d4653f362a330..cb4bf7e5439c023c0f2c87d953d0cadc8c0b2689 100644
--- a/data_chain/apps/app.py
+++ b/data_chain/apps/app.py
@@ -46,7 +46,8 @@ from data_chain.parser.handler import (
     xlsx_parser,
     yaml_parser,
     picture_parser,
-    deep_pdf_parser
+    deep_pdf_parser,
+    fine_pdf_parser
 )
 from data_chain.rag import (
     base_searcher,
diff --git a/data_chain/apps/base/task/worker/generate_dataset_worker.py b/data_chain/apps/base/task/worker/generate_dataset_worker.py
index b0dd0dd4430b0bd11874a3a67f8ed18118778f7a..eb01850191b85f987584b370898b704d9e56f832 100644
--- a/data_chain/apps/base/task/worker/generate_dataset_worker.py
+++ b/data_chain/apps/base/task/worker/generate_dataset_worker.py
@@ -157,7 +157,9 @@ class GenerateDataSetWorker(BaseWorker):
                 l = j-1
                 r = j+1
                 tokens_sub = 0
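+                # Cap the assembled context at a quarter of the model window (floor 2048 tokens);
+                # presumably this leaves more headroom for the prompt template and generated questions.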
-                while TokenTool.get_tokens(chunk) < max(llm.max_tokens//2, 2048):
+                while TokenTool.get_tokens(chunk) < max(llm.max_tokens//4, 2048):
                     if l < 0 and r >= len(doc_chunk.chunks):
                         break
                     if tokens_sub > 0:
@@ -186,7 +186,7 @@ class GenerateDataSetWorker(BaseWorker):
                 rd -= 1
             try:
                 sys_call = q_generate_prompt_template.format(
-                    k=2*(qa_cnt-len(qs)),
+                    k=5*(qa_cnt-len(qs)),
                     content=TokenTool.get_k_tokens_words_from_content(chunk, llm.max_tokens)
                 )
                 usr_call = 'Please output the list of questions'
diff --git a/data_chain/apps/base/task/worker/parse_document_worker.py b/data_chain/apps/base/task/worker/parse_document_worker.py
index 1e9ddf8a0b8ad0713526c066d3d4fc12a30f5880..a1b568a504089ddeb0569215119c6a90de8590d5 100644
--- a/data_chain/apps/base/task/worker/parse_document_worker.py
+++ b/data_chain/apps/base/task/worker/parse_document_worker.py
@@ -126,8 +126,8 @@ class ParseDocumentWorker(BaseWorker):
     async def parse_doc(doc_entity: DocumentEntity, file_path: str) -> ParseResult:
         '''Parse a document'''
         extension = doc_entity.extension
-        if doc_entity.parse_method == ParseMethod.DEEP.value:
-            extension += '.deep'
+        if doc_entity.parse_method == ParseMethod.DEEP.value or doc_entity.parse_method == ParseMethod.FINE.value:
+            extension += '.' + doc_entity.parse_method
         parse_result = await BaseParser.parser(extension, file_path)
         return parse_result
 
@@ -380,11 +380,11 @@ class ParseDocumentWorker(BaseWorker):
             if len(node.title) == 0:
                 if llm is not None:
                     content = ''
-                    for node in node.link_nodes:
-                        if node.title:
-                            content += node.title + '\n'
+                    for cnode in node.link_nodes:
+                        if cnode.title:
+                            content += cnode.title + '\n'
                         else:
-                            sentences = TokenTool.get_top_k_keysentence(node.content, 1)
+                            sentences = TokenTool.get_top_k_keysentence(cnode.content, 1)
                             if sentences:
                                 content += sentences[0] + '\n'
                     if content:
@@ -444,6 +444,12 @@ class ParseDocumentWorker(BaseWorker):
         for node in parse_result.nodes:
             if not node.content:
                 continue
+            new_content = node.content.strip()
+            # Drop characters that cannot be encoded as valid UTF-8
+            new_content = new_content.encode('utf-8', 'ignore').decode('utf-8')
+            if not new_content:
+                continue
+            node.content = new_content
             chunk_entity = ChunkEntity(
                 id=node.id,
                 team_id=doc_entity.team_id,
@@ -490,7 +496,7 @@ class ParseDocumentWorker(BaseWorker):
             raise Exception(err)
         await DocumentManager.update_document_by_doc_id(task_entity.op_id, {"status": DocumentStatus.RUNNING.value})
         try:
-            if doc_entity.parse_method == ParseMethod.EHANCED.value or doc_entity.parse_method == ParseMethod.DEEP.value:
+            if doc_entity.parse_method == ParseMethod.EHANCED.value or doc_entity.parse_method == ParseMethod.DEEP.value or doc_entity.parse_method == ParseMethod.FINE.value:
                 llm = LLM(
                     openai_api_key=config['OPENAI_API_KEY'],
                     openai_api_base=config['OPENAI_API_BASE'],
diff --git a/data_chain/apps/base/zip_handler.py b/data_chain/apps/base/zip_handler.py
index 3285eb8ea44e4a59b1fa56803fd79bda9b3707dd..c97e74e7dbed5cae114b2168e125acc5eb462ce7 100644
--- a/data_chain/apps/base/zip_handler.py
+++ b/data_chain/apps/base/zip_handler.py
@@ -8,6 +8,17 @@ from data_chain.logger.logger import logger as logging
 
 
 class ZipHandler():
+    '''Class for handling zip files'''
+    @staticmethod
+    def is_zip_file(file_path: str) -> bool:
+        '''Check whether a file is a valid zip file'''
+        if not os.path.exists(file_path):
+            logging.error("[ZipHandler] File %s does not exist", file_path)
+            return False
+        if not zipfile.is_zipfile(file_path):
+            logging.error("[ZipHandler] File %s is not a valid zip file", file_path)
+            return False
+        return True
 
     @staticmethod
     def check_zip_file(zip_file_path: str, max_file_num: int = 4096, max_file_size: int = 10 * 1024 * 1024 * 1024) -> bool:
diff --git a/data_chain/apps/service/session_service.py b/data_chain/apps/service/session_service.py
index 320e48775a52769ec53fb9b45416cc976eae218f..1f62060aaebb9cb457533c2996bfc0450df87788 100644
--- a/data_chain/apps/service/session_service.py
+++ b/data_chain/apps/service/session_service.py
@@ -37,7 +37,9 @@ async def verify_user(request: HTTPConnection):
     except:
         raise UserHTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
                                 retcode=401, rtmsg="Authentication Error.", data="")
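+    # SessionManager.verify_user is async; the previous un-awaited call returned a coroutine
+    # object that is always truthy, so this check could never fail. Awaiting it fixes that.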
-    if not SessionManager.verify_user(session_id):
+    if not (await SessionManager.verify_user(session_id)):
         raise UserHTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
                                 retcode=401, rtmsg="Authentication Error.", data="")
 
diff --git a/data_chain/common/prompt.yaml b/data_chain/common/prompt.yaml
index 691b2b4b7d0352464796dafe39024442ec3c3862..70ba9e2b047707aac037375fef93b7d51956a42c 100644
--- a/data_chain/common/prompt.yaml
+++ b/data_chain/common/prompt.yaml
@@ -159,7 +159,8 @@ GENREATE_QUESTION_FROM_CONTENT_PROMPT: 'You are a text analysis expert; your
     #01 Questions must be drawn from the content of the text
     #02 A single question must not exceed 50 characters
     #03 Do not output duplicate questions
-    #04 Output only the list of questions and nothing else
+    #04 The questions should be diverse and cover different aspects of the text
+    #05 Output only the list of questions and nothing else
     Example:
     Input: openEuler is an open-source operating system designed to support cloud and edge computing. It features high performance, high security, and high reliability.
     Output:
diff --git a/data_chain/entities/enum.py b/data_chain/entities/enum.py
index 86cb2e85f3523ad83c223eca6cb5bcdd68bb1ed8..5866976ef61ba05ffdd9b50f13d323bcac93e40f 100644
--- a/data_chain/entities/enum.py
+++ b/data_chain/entities/enum.py
@@ -41,6 +41,7 @@ class ParseMethod(str, Enum):
     EHANCED = "enhanced"
     QA = "qa"
     DEEP = "deep"
+    FINE = "fine"
 
 
 class UserStatus(str, Enum):
diff --git a/data_chain/parser/handler/fine_pdf_parser.py b/data_chain/parser/handler/fine_pdf_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..707b729245bcb1e7419c7a204b6864ae2ad91027
--- /dev/null
+++ b/data_chain/parser/handler/fine_pdf_parser.py
@@ -0,0 +1,51 @@
+import asyncio
+import os
+import shutil
+from uuid import uuid4
+from data_chain.entities.enum import DocParseRelutTopology, ChunkParseTopology, ChunkType
+from data_chain.parser.parse_result import ParseNode, ParseResult
+from data_chain.parser.handler.base_parser import BaseParser
+from data_chain.parser.tools.ocr_tool import OcrTool
+from data_chain.logger.logger import logger as logging
+from data_chain.parser.handler.md_zip_parser import MdZipParser
+from data_chain.parser.tools.instruct_scan_tool import InstructScanTool
+
+
+class FinePdfParser(BaseParser):
+    name = 'pdf.fine'
+
+    @staticmethod
+    async def parser(file_path: str) -> ParseResult:
+        if InstructScanTool.check_avx512_support():
+            from marker.converters.pdf import PdfConverter
+            from marker.models import create_model_dict
+            from marker.config.parser import ConfigParser
+            from marker.output import text_from_rendered, save_output
+            fname_base = os.path.splitext(os.path.basename(file_path))[0]
+            output_dir = os.path.dirname(file_path)
+            md_path = os.path.join(output_dir, fname_base)
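+            # The "ADDITIONAL_KEY": "VALUE" entry below appears to be a placeholder carried over
+            # from the marker-pdf sample config; presumably only "output_format" takes effect.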
+            config = {
+                "output_format": "markdown",
+                "ADDITIONAL_KEY": "VALUE"
+            }
+            config_parser = ConfigParser(config)
+
+            converter = PdfConverter(
+                config=config_parser.generate_config_dict(),
+                artifact_dict=create_model_dict(),
+                processor_list=config_parser.get_processors(),
+                renderer=config_parser.get_renderer(),
+                llm_service=config_parser.get_llm_service()
+            )
+            rendered = converter(file_path)
+            if os.path.exists(md_path):
+                shutil.rmtree(md_path)
+            os.makedirs(md_path, exist_ok=True)
+            save_output(rendered, md_path, fname_base)
+            result = await MdZipParser.parser(md_path)
+            return result
+        else:
+            logging.error("[FinePdfParser] This machine does not support AVX-512; PDF parsing is unavailable")
+            raise Exception("[FinePdfParser] This machine does not support AVX-512; PDF parsing is unavailable")
diff --git a/data_chain/parser/handler/md_zip_parser.py b/data_chain/parser/handler/md_zip_parser.py
index 3804f9fe81615ac31a134065a222264f76d16f93..c80dc860ac98f89173766ed97d6374ba831f0382 100644
--- a/data_chain/parser/handler/md_zip_parser.py
+++ b/data_chain/parser/handler/md_zip_parser.py
@@ -166,7 +166,6 @@ class MdZipParser(BaseParser):
             code_text = element.get_text().strip()
             node = ParseNode(
                 id=uuid.uuid4(),
-                lv=current_level,
                 parse_topology_type=ChunkParseTopology.TREELEAF,
                 content=code_text,
@@ -179,7 +178,6 @@ class MdZipParser(BaseParser):
         if para_text:
             node = ParseNode(
                 id=uuid.uuid4(),
-                lv=current_level,
                 parse_topology_type=ChunkParseTopology.TREELEAF,
                 content=para_text,
@@ -193,7 +191,6 @@ class MdZipParser(BaseParser):
         if img_blob:
             node = ParseNode(
                 id=uuid.uuid4(),
-                lv=current_level,
                 parse_topology_type=ChunkParseTopology.TREELEAF,
                 content=img_blob,
@@ -206,7 +203,6 @@ class MdZipParser(BaseParser):
         for row in table_array:
             node = ParseNode(
                 id=uuid.uuid4(),
-                lv=current_level,
                 parse_topology_type=ChunkParseTopology.TREELEAF,
                 content=row,
@@ -226,7 +222,6 @@ class MdZipParser(BaseParser):
     @staticmethod
     async def markdown_to_tree(file_path: str, markdown_text: str) -> ParseNode:
         html = markdown.markdown(markdown_text, extensions=['tables'])
-        logging.error(html)
         root = ParseNode(
             id=uuid.uuid4(),
             title="",
@@ -243,8 +238,13 @@ class MdZipParser(BaseParser):
 
     @staticmethod
     async def parser(file_path: str) -> ParseResult:
-        target_file_path = os.path.join(os.path.dirname(file_path), 'temp')
-        await ZipHandler.unzip_file(file_path, target_file_path)
+        if ZipHandler.is_zip_file(file_path):
+            target_file_path = os.path.join(os.path.dirname(file_path), 'temp')
+            await ZipHandler.unzip_file(file_path, target_file_path)
+        elif os.path.isdir(file_path):
+            target_file_path = file_path
+        else:
+            target_file_path = None
         # Recursively search for markdown files
         markdown_file_path_list = []
         for root, dirs, files in os.walk(target_file_path):
diff --git a/data_chain/parser/tools/instruct_scan_tool.py b/data_chain/parser/tools/instruct_scan_tool.py
index 9a849f40ae7b67e9473f0adb409241ebdeb7ed67..0dd6a4750f5675e15695e7dc2c9e7ac01b904459 100644
--- a/data_chain/parser/tools/instruct_scan_tool.py
+++ b/data_chain/parser/tools/instruct_scan_tool.py
@@ -13,6 +13,10 @@ class InstructScanTool:
         "Maybe": unable to determine support
         """
         try:
+            # AVX-512 is x86-only: on ARM/aarch64, skip the check and return True
+            machine = platform.machine().lower()
+            if machine.startswith('arm') or machine.startswith('aarch64'):
+                return True
             # Prefer the cpuinfo library for precise information
             info = cpuinfo.get_cpu_info()
             flags = info.get('flags', [])
@@ -46,6 +50,9 @@ class InstructScanTool:
         """
         Fall back to detection based on platform commands (original implementation)
         """
+        machine = platform.machine().lower()
+        if machine.startswith('arm') or machine.startswith('aarch64'):
+            return True
         system = platform.system()
 
         if system == "Linux":
diff --git a/requirements.txt b/requirements.txt
index 81e5e5c01eb7ceb028d9970c563f36df74fead7e..e2b1e3981564323a264e26273bca90b2c45f447d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,11 +14,11 @@ jieba==0.42.1
 langchain==0.3.7
 langchain-openai==0.2.5
 minio==7.2.4
-markdown2==2.4.13
+markdown2==2.5.2
 markdown==3.3.4
 more-itertools==10.1.0
 numpy==1.26.4
-openai==1.53.0
+openai==1.65.2
 opencv-python==4.9.0.80
 openpyxl==3.1.2
 paddleocr==2.9.1
@@ -38,7 +38,7 @@ pyyaml==6.0.1
 pymongo==4.12.1
 redis==5.0.3
 requests==2.32.2
-scikit-learn==1.5.0
+scikit-learn==1.6.1
 sqlalchemy==2.0.23
 starlette==0.37.2
 tika==2.6.0
@@ -46,4 +46,6 @@ tiktoken==0.8.0
 urllib3==2.2.1
 uvicorn==0.21.0
 xlrd==2.0.1
-py-cpuinfo==9.0.0
\ No newline at end of file
+py-cpuinfo==9.0.0
+opengauss-sqlalchemy==2.4.0
+#marker-pdf==1.8.0
\ No newline at end of file