From e329cdaca8bcf6b6005b58e1f89060ffd55fa868 Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 27 Jun 2025 10:40:21 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E5=96=84deep=20pdf=E8=A7=A3=E6=9E=90?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F=E4=B8=8B=E6=A8=A1=E5=9E=8B=E8=B7=AF=E5=BE=84?= =?UTF-8?q?=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../base/task/worker/parse_document_worker.py | 2 ++ data_chain/manager/chunk_manager.py | 3 +- data_chain/parser/handler/deep_pdf_parser.py | 35 +++++++++++++++---- data_chain/parser/handler/pdf_parser.py | 7 ++-- data_chain/parser/parse_result.py | 1 + data_chain/rag/doc2chunk_bfs_searcher.py | 4 +-- data_chain/rag/doc2chunk_searcher.py | 4 +-- data_chain/rag/keyword_and_vector_searcher.py | 4 +-- data_chain/rag/query_extend_searcher.py | 4 +-- data_chain/rag/vector_searcher.py | 4 +-- download_model.py | 2 -- 11 files changed, 44 insertions(+), 26 deletions(-) diff --git a/data_chain/apps/base/task/worker/parse_document_worker.py b/data_chain/apps/base/task/worker/parse_document_worker.py index 8073d90d..1e9ddf8a 100644 --- a/data_chain/apps/base/task/worker/parse_document_worker.py +++ b/data_chain/apps/base/task/worker/parse_document_worker.py @@ -304,6 +304,8 @@ class ParseDocumentWorker(BaseWorker): else: if node.is_need_newline: nodes[-1].content += '\n' + elif node.is_need_space: + nodes[-1].content += ' ' nodes[-1].content += node.content else: nodes.append(node) diff --git a/data_chain/manager/chunk_manager.py b/data_chain/manager/chunk_manager.py index 0ad089c6..a45c9fb3 100644 --- a/data_chain/manager/chunk_manager.py +++ b/data_chain/manager/chunk_manager.py @@ -176,7 +176,7 @@ class ChunkManager(): async with await DataBase.get_session() as session: fetch_cnt = top_k chunk_entities = [] - while fetch_cnt <= max(top_k, 8192): + while True: # 计算相似度分数 similarity_score = ChunkEntity.text_vector.cosine_distance(vector).label("similarity_score") @@ -213,6 +213,7 @@ class ChunkManager(): if chunk_entities: break fetch_cnt *= 2 + fetch_cnt = min(fetch_cnt, max(8192, top_k)+1) chunk_entities = chunk_entities[:top_k] # 确保返回的结果不超过 top_k return chunk_entities except Exception as e: diff --git a/data_chain/parser/handler/deep_pdf_parser.py b/data_chain/parser/handler/deep_pdf_parser.py index 8a3facf3..51c3ef25 100644 --- a/data_chain/parser/handler/deep_pdf_parser.py +++ b/data_chain/parser/handler/deep_pdf_parser.py @@ -51,7 +51,16 @@ class ParseNodeWithBbox(BaseModel): class DeepPdfParser(BaseParser): name = 'pdf.deep' - ocr = PaddleOCR(use_angle_cls=True, lang="ch") # 使用中文语言模型 + det_model_dir = 'data_chain/parser/model/ocr/ch_PP-OCRv4_det_infer' + rec_model_dir = 'data_chain/parser/model/ocr/ch_PP-OCRv4_rec_infer' + cls_model_dir = 'data_chain/parser/model/ocr/ch_ppocr_mobile_v2.0_cls_infer' + ocr = PaddleOCR( + det_model_dir=det_model_dir, + rec_model_dir=rec_model_dir, + cls_model_dir=cls_model_dir, + use_angle_cls=True, + lang="ch" + ) # 使用中文语言模型 @staticmethod async def extract_text_from_page( @@ -587,10 +596,21 @@ class DeepPdfParser(BaseParser): exclude_regions = table_regions + image_regions # 提取文本时排除表格和图片区域 - text_nodes_with_bbox = await DeepPdfParser.extract_text_from_page(page, exclude_regions) - if not text_nodes_with_bbox: - text_nodes_with_bbox = await DeepPdfParser.extract_text_from_page_by_ocr( + text_nodes_with_bbox_1 = await DeepPdfParser.extract_text_from_page(page, exclude_regions) + text_nodes_with_bbox_2 = [] + text_len_1 = 0 + text_len_2 = 0 + for node in text_nodes_with_bbox_1: + text_len_1 += len(node.node.content) + if text_len_1 < 100: + text_nodes_with_bbox_2 = await DeepPdfParser.extract_text_from_page_by_ocr( image_path, exclude_regions) + for node in text_nodes_with_bbox_2: + text_len_2 += len(node.node.content) + if text_len_1 > text_len_2: + text_nodes_with_bbox = text_nodes_with_bbox_1 + else: + text_nodes_with_bbox = text_nodes_with_bbox_2 # 合并所有节点 sub_nodes_with_bbox = await DeepPdfParser.merge_nodes_with_bbox( text_nodes_with_bbox, table_nodes_with_bbox) @@ -602,8 +622,11 @@ class DeepPdfParser(BaseParser): for i in range(1, len(nodes_with_bbox)): '''根据bbox判断是否要进行换行''' if nodes_with_bbox[i].bbox.y0 > nodes_with_bbox[i-1].bbox.y1 + 1: - nodes_with_bbox[i].node.is_need_newline = True - + nodes_with_bbox[i-1].node.is_need_newline = True + for i in range(1, len(nodes_with_bbox)): + '''根据bbox判断是否要进行空格''' + if i > 0 and nodes_with_bbox[i].bbox.x0 > nodes_with_bbox[i-1].bbox.x1 + 1: + nodes_with_bbox[i-1].node.is_need_space = True nodes = [node_with_bbox.node for node_with_bbox in nodes_with_bbox] DeepPdfParser.image_related_node_in_link_nodes(nodes) # 假设这个方法在别处定义 parse_result = ParseResult( diff --git a/data_chain/parser/handler/pdf_parser.py b/data_chain/parser/handler/pdf_parser.py index 65248a01..a492651c 100644 --- a/data_chain/parser/handler/pdf_parser.py +++ b/data_chain/parser/handler/pdf_parser.py @@ -302,8 +302,11 @@ class PdfParser(BaseParser): for i in range(1, len(nodes_with_bbox)): '''根据bbox判断是否要进行换行''' if nodes_with_bbox[i].bbox.y0 > nodes_with_bbox[i-1].bbox.y1 + 1: - nodes_with_bbox[i].node.is_need_newline = True - + nodes_with_bbox[i-1].node.is_need_newline = True + for i in range(1, len(nodes_with_bbox)): + '''根据bbox判断是否要进行空格''' + if nodes_with_bbox[i].bbox.x0 > nodes_with_bbox[i-1].bbox.x1 + 1: + nodes_with_bbox[i-1].node.is_need_space = True nodes = [node_with_bbox.node for node_with_bbox in nodes_with_bbox] PdfParser.image_related_node_in_link_nodes(nodes) # 假设这个方法在别处定义 parse_result = ParseResult( diff --git a/data_chain/parser/parse_result.py b/data_chain/parser/parse_result.py index 965aa921..b69e3f45 100644 --- a/data_chain/parser/parse_result.py +++ b/data_chain/parser/parse_result.py @@ -18,6 +18,7 @@ class ParseNode(BaseModel): content: Any = Field(..., description="节点内容") type: ChunkType = Field(..., description="节点类型") link_nodes: list = Field(..., description="链接节点") + is_need_space: bool = Field(default=False, description="是否需要空格") is_need_newline: bool = Field(default=False, description="是否需要换行") diff --git a/data_chain/rag/doc2chunk_bfs_searcher.py b/data_chain/rag/doc2chunk_bfs_searcher.py index 7b5d9cc6..c72e7bd1 100644 --- a/data_chain/rag/doc2chunk_bfs_searcher.py +++ b/data_chain/rag/doc2chunk_bfs_searcher.py @@ -37,9 +37,7 @@ class Doc2ChunkBfsSearcher(BaseSearcher): root_chunk_entities_vector = [] for _ in range(3): try: - root_chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(root_chunk_entities_keyword), doc_ids, banned_ids, ChunkParseTopology.TREEROOT.value), timeout=10) - if not root_chunk_entities_vector: - root_chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(root_chunk_entities_keyword), doc_ids, banned_ids, ChunkParseTopology.TREEROOT.value), timeout=10) + root_chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(root_chunk_entities_keyword), doc_ids, banned_ids, ChunkParseTopology.TREEROOT.value), timeout=3) break except Exception as e: err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}" diff --git a/data_chain/rag/doc2chunk_searcher.py b/data_chain/rag/doc2chunk_searcher.py index f88d478c..40510d51 100644 --- a/data_chain/rag/doc2chunk_searcher.py +++ b/data_chain/rag/doc2chunk_searcher.py @@ -53,9 +53,7 @@ class Doc2ChunkSearcher(BaseSearcher): chunk_entities_vector = [] for _ in range(3): try: - chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_keyword), use_doc_ids, banned_ids), timeout=10) - if not chunk_entities_vector: - chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_keyword), use_doc_ids, banned_ids), timeout=10) + chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_keyword), use_doc_ids, banned_ids), timeout=3) break except Exception as e: err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}" diff --git a/data_chain/rag/keyword_and_vector_searcher.py b/data_chain/rag/keyword_and_vector_searcher.py index c4d14ac2..c3bbe937 100644 --- a/data_chain/rag/keyword_and_vector_searcher.py +++ b/data_chain/rag/keyword_and_vector_searcher.py @@ -42,11 +42,9 @@ class KeywordVectorSearcher(BaseSearcher): try: import time start_time = time.time() - chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword)-len(chunk_entities_get_by_dynamic_weighted_keyword), doc_ids, banned_ids), timeout=10) + chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword)-len(chunk_entities_get_by_dynamic_weighted_keyword), doc_ids, banned_ids), timeout=3) end_time = time.time() logging.info(f"[KeywordVectorSearcher] 向量检索成功完成,耗时: {end_time - start_time:.2f}秒") - if not chunk_entities_get_by_vector: - chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword)-len(chunk_entities_get_by_dynamic_weighted_keyword), doc_ids, banned_ids), timeout=10) break except Exception as e: err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}" diff --git a/data_chain/rag/query_extend_searcher.py b/data_chain/rag/query_extend_searcher.py index 67b06810..a09f660b 100644 --- a/data_chain/rag/query_extend_searcher.py +++ b/data_chain/rag/query_extend_searcher.py @@ -61,9 +61,7 @@ class QueryExtendSearcher(BaseSearcher): chunk_entities_get_by_vector = [] for _ in range(3): try: - chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword), doc_ids, banned_ids), timeout=10) - if not chunk_entities_get_by_vector: - chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword), doc_ids, banned_ids), timeout=10) + chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword), doc_ids, banned_ids), timeout=3) break except Exception as e: err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}" diff --git a/data_chain/rag/vector_searcher.py b/data_chain/rag/vector_searcher.py index 9716471a..dad5e867 100644 --- a/data_chain/rag/vector_searcher.py +++ b/data_chain/rag/vector_searcher.py @@ -29,9 +29,7 @@ class VectorSearcher(BaseSearcher): chunk_entities = [] for _ in range(3): try: - chunk_entities = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k, doc_ids, banned_ids), timeout=10) - if not chunk_entities: - chunk_entities = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k, doc_ids, banned_ids), timeout=10) + chunk_entities = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k, doc_ids, banned_ids), timeout=3) break except Exception as e: err = f"[VectorSearcher] 向量检索失败,error: {e}" diff --git a/download_model.py b/download_model.py index a18ef347..bcf7f39f 100644 --- a/download_model.py +++ b/download_model.py @@ -1,5 +1,3 @@ import tiktoken -from paddleocr import PaddleOCR -ocr = PaddleOCR(use_angle_cls=True, lang="ch") # 使用中文语言模型 enc = tiktoken.encoding_for_model("gpt-4") print(len(enc.encode('hello world'))) \ No newline at end of file -- Gitee