diff --git a/data_chain/apps/base/task/worker/parse_document_worker.py b/data_chain/apps/base/task/worker/parse_document_worker.py
index 8073d90da07a1f81aeae21d1d19f86a47d7ef20d..1e9ddf8a0b8ad0713526c066d3d4fc12a30f5880 100644
--- a/data_chain/apps/base/task/worker/parse_document_worker.py
+++ b/data_chain/apps/base/task/worker/parse_document_worker.py
@@ -304,6 +304,8 @@ class ParseDocumentWorker(BaseWorker):
                 else:
                     if node.is_need_newline:
                         nodes[-1].content += '\n'
+                    elif node.is_need_space:
+                        nodes[-1].content += ' '
                     nodes[-1].content += node.content
             else:
                 nodes.append(node)
diff --git a/data_chain/manager/chunk_manager.py b/data_chain/manager/chunk_manager.py
index 0ad089c63c8f44307e1f2bf017ce25cb9d744761..a45c9fb3a552e9ba91e48a99ab1cd2c85cdda407 100644
--- a/data_chain/manager/chunk_manager.py
+++ b/data_chain/manager/chunk_manager.py
@@ -176,7 +176,7 @@ class ChunkManager():
             async with await DataBase.get_session() as session:
                 fetch_cnt = top_k
                 chunk_entities = []
-                while fetch_cnt <= max(top_k, 8192):
+                while True:
                     # 计算相似度分数
                     similarity_score = ChunkEntity.text_vector.cosine_distance(vector).label("similarity_score")
@@ -213,6 +213,7 @@ class ChunkManager():
                     if chunk_entities:
                         break
                     fetch_cnt *= 2
+                    fetch_cnt = min(fetch_cnt, max(8192, top_k)+1)
                 chunk_entities = chunk_entities[:top_k]  # 确保返回的结果不超过 top_k
                 return chunk_entities
         except Exception as e:
diff --git a/data_chain/parser/handler/deep_pdf_parser.py b/data_chain/parser/handler/deep_pdf_parser.py
index 8a3facf3340f0686038e3b7f1aa537faa92c9f2c..51c3ef25271fba2f58a9be6802acbd11cbc907d6 100644
--- a/data_chain/parser/handler/deep_pdf_parser.py
+++ b/data_chain/parser/handler/deep_pdf_parser.py
@@ -51,7 +51,16 @@ class ParseNodeWithBbox(BaseModel):
 class DeepPdfParser(BaseParser):
     name = 'pdf.deep'
-    ocr = PaddleOCR(use_angle_cls=True, lang="ch")  # 使用中文语言模型
+    det_model_dir = 'data_chain/parser/model/ocr/ch_PP-OCRv4_det_infer'
+    rec_model_dir = 'data_chain/parser/model/ocr/ch_PP-OCRv4_rec_infer'
+    cls_model_dir = 'data_chain/parser/model/ocr/ch_ppocr_mobile_v2.0_cls_infer'
+    ocr = PaddleOCR(
+        det_model_dir=det_model_dir,
+        rec_model_dir=rec_model_dir,
+        cls_model_dir=cls_model_dir,
+        use_angle_cls=True,
+        lang="ch"
+    )  # 使用中文语言模型
 
     @staticmethod
     async def extract_text_from_page(
@@ -587,10 +596,21 @@ class DeepPdfParser(BaseParser):
 
             exclude_regions = table_regions + image_regions
             # 提取文本时排除表格和图片区域
-            text_nodes_with_bbox = await DeepPdfParser.extract_text_from_page(page, exclude_regions)
-            if not text_nodes_with_bbox:
-                text_nodes_with_bbox = await DeepPdfParser.extract_text_from_page_by_ocr(
+            text_nodes_with_bbox_1 = await DeepPdfParser.extract_text_from_page(page, exclude_regions)
+            text_nodes_with_bbox_2 = []
+            text_len_1 = 0
+            text_len_2 = 0
+            for node in text_nodes_with_bbox_1:
+                text_len_1 += len(node.node.content)
+            if text_len_1 < 100:
+                text_nodes_with_bbox_2 = await DeepPdfParser.extract_text_from_page_by_ocr(
                     image_path, exclude_regions)
+                for node in text_nodes_with_bbox_2:
+                    text_len_2 += len(node.node.content)
+            if text_len_1 > text_len_2:
+                text_nodes_with_bbox = text_nodes_with_bbox_1
+            else:
+                text_nodes_with_bbox = text_nodes_with_bbox_2
             # 合并所有节点
             sub_nodes_with_bbox = await DeepPdfParser.merge_nodes_with_bbox(
                 text_nodes_with_bbox, table_nodes_with_bbox)
@@ -602,8 +622,11 @@ class DeepPdfParser(BaseParser):
         for i in range(1, len(nodes_with_bbox)):
             '''根据bbox判断是否要进行换行'''
             if nodes_with_bbox[i].bbox.y0 > nodes_with_bbox[i-1].bbox.y1 + 1:
-                nodes_with_bbox[i].node.is_need_newline = True
-
+                nodes_with_bbox[i-1].node.is_need_newline = True
+        for i in range(1, len(nodes_with_bbox)):
+            '''根据bbox判断是否要进行空格'''
+            if i > 0 and nodes_with_bbox[i].bbox.x0 > nodes_with_bbox[i-1].bbox.x1 + 1:
+                nodes_with_bbox[i-1].node.is_need_space = True
         nodes = [node_with_bbox.node for node_with_bbox in nodes_with_bbox]
         DeepPdfParser.image_related_node_in_link_nodes(nodes)  # 假设这个方法在别处定义
         parse_result = ParseResult(
diff --git a/data_chain/parser/handler/pdf_parser.py b/data_chain/parser/handler/pdf_parser.py
index 65248a01a3382f5c9a53350c114734d8fa302717..a492651c4732d3c044dae5844fceb4c3957f39ff 100644
--- a/data_chain/parser/handler/pdf_parser.py
+++ b/data_chain/parser/handler/pdf_parser.py
@@ -302,8 +302,11 @@ class PdfParser(BaseParser):
         for i in range(1, len(nodes_with_bbox)):
             '''根据bbox判断是否要进行换行'''
             if nodes_with_bbox[i].bbox.y0 > nodes_with_bbox[i-1].bbox.y1 + 1:
-                nodes_with_bbox[i].node.is_need_newline = True
-
+                nodes_with_bbox[i-1].node.is_need_newline = True
+        for i in range(1, len(nodes_with_bbox)):
+            '''根据bbox判断是否要进行空格'''
+            if nodes_with_bbox[i].bbox.x0 > nodes_with_bbox[i-1].bbox.x1 + 1:
+                nodes_with_bbox[i-1].node.is_need_space = True
         nodes = [node_with_bbox.node for node_with_bbox in nodes_with_bbox]
         PdfParser.image_related_node_in_link_nodes(nodes)  # 假设这个方法在别处定义
         parse_result = ParseResult(
diff --git a/data_chain/parser/parse_result.py b/data_chain/parser/parse_result.py
index 965aa921b4221030a416025248cdb030d2c84d6b..b69e3f451ae5c3335e2a2c4578ae5d4535f023b1 100644
--- a/data_chain/parser/parse_result.py
+++ b/data_chain/parser/parse_result.py
@@ -18,6 +18,7 @@ class ParseNode(BaseModel):
     content: Any = Field(..., description="节点内容")
     type: ChunkType = Field(..., description="节点类型")
     link_nodes: list = Field(..., description="链接节点")
+    is_need_space: bool = Field(default=False, description="是否需要空格")
     is_need_newline: bool = Field(default=False, description="是否需要换行")
diff --git a/data_chain/rag/doc2chunk_bfs_searcher.py b/data_chain/rag/doc2chunk_bfs_searcher.py
index 7b5d9cc6c7e17bac749b02ff5eb877e6c7a24c48..c72e7bd1c1a05b2c9ea68e027dac522714871210 100644
--- a/data_chain/rag/doc2chunk_bfs_searcher.py
+++ b/data_chain/rag/doc2chunk_bfs_searcher.py
@@ -37,9 +37,7 @@ class Doc2ChunkBfsSearcher(BaseSearcher):
         root_chunk_entities_vector = []
         for _ in range(3):
             try:
-                root_chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(root_chunk_entities_keyword), doc_ids, banned_ids, ChunkParseTopology.TREEROOT.value), timeout=10)
-                if not root_chunk_entities_vector:
-                    root_chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(root_chunk_entities_keyword), doc_ids, banned_ids, ChunkParseTopology.TREEROOT.value), timeout=10)
+                root_chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(root_chunk_entities_keyword), doc_ids, banned_ids, ChunkParseTopology.TREEROOT.value), timeout=3)
                 break
             except Exception as e:
                 err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}"
diff --git a/data_chain/rag/doc2chunk_searcher.py b/data_chain/rag/doc2chunk_searcher.py
index f88d478c2175126bb9f69d8d284a727bafd66b0c..40510d513c9f2355a87a3ac3efe64f1fb008794c 100644
--- a/data_chain/rag/doc2chunk_searcher.py
+++ b/data_chain/rag/doc2chunk_searcher.py
@@ -53,9 +53,7 @@ class Doc2ChunkSearcher(BaseSearcher):
         chunk_entities_vector = []
         for _ in range(3):
             try:
-                chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_keyword), use_doc_ids, banned_ids), timeout=10)
-                if not chunk_entities_vector:
-                    chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_keyword), use_doc_ids, banned_ids), timeout=10)
+                chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_keyword), use_doc_ids, banned_ids), timeout=3)
                 break
             except Exception as e:
                 err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}"
diff --git a/data_chain/rag/keyword_and_vector_searcher.py b/data_chain/rag/keyword_and_vector_searcher.py
index c4d14ac2c3d133c2b362b2e9f4d0d18fed251c02..c3bbe937a70e6c5aa0f3ded47dd8dc22bc9b742e 100644
--- a/data_chain/rag/keyword_and_vector_searcher.py
+++ b/data_chain/rag/keyword_and_vector_searcher.py
@@ -42,11 +42,9 @@ class KeywordVectorSearcher(BaseSearcher):
             try:
                 import time
                 start_time = time.time()
-                chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword)-len(chunk_entities_get_by_dynamic_weighted_keyword), doc_ids, banned_ids), timeout=10)
+                chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword)-len(chunk_entities_get_by_dynamic_weighted_keyword), doc_ids, banned_ids), timeout=3)
                 end_time = time.time()
                 logging.info(f"[KeywordVectorSearcher] 向量检索成功完成,耗时: {end_time - start_time:.2f}秒")
-                if not chunk_entities_get_by_vector:
-                    chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword)-len(chunk_entities_get_by_dynamic_weighted_keyword), doc_ids, banned_ids), timeout=10)
                 break
             except Exception as e:
                 err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}"
diff --git a/data_chain/rag/query_extend_searcher.py b/data_chain/rag/query_extend_searcher.py
index 67b0681087d3a87517d8197d9cc7a9be986330d0..a09f660baeebd3e16ce61b31ae85e9723bb2fd3f 100644
--- a/data_chain/rag/query_extend_searcher.py
+++ b/data_chain/rag/query_extend_searcher.py
@@ -61,9 +61,7 @@ class QueryExtendSearcher(BaseSearcher):
         chunk_entities_get_by_vector = []
         for _ in range(3):
             try:
-                chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword), doc_ids, banned_ids), timeout=10)
-                if not chunk_entities_get_by_vector:
-                    chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword), doc_ids, banned_ids), timeout=10)
+                chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword), doc_ids, banned_ids), timeout=3)
                 break
             except Exception as e:
                 err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}"
diff --git a/data_chain/rag/vector_searcher.py b/data_chain/rag/vector_searcher.py
index 9716471af816f08cbbc25286e3f4e9b160720f70..dad5e8676792927fa28f27a0ec9b8ac0cb08a079 100644
--- a/data_chain/rag/vector_searcher.py
+++ b/data_chain/rag/vector_searcher.py
@@ -29,9 +29,7 @@ class VectorSearcher(BaseSearcher):
         chunk_entities = []
         for _ in range(3):
             try:
-                chunk_entities = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k, doc_ids, banned_ids), timeout=10)
-                if not chunk_entities:
-                    chunk_entities = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k, doc_ids, banned_ids), timeout=10)
+                chunk_entities = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k, doc_ids, banned_ids), timeout=3)
                 break
             except Exception as e:
                 err = f"[VectorSearcher] 向量检索失败,error: {e}"
diff --git a/download_model.py b/download_model.py
index a18ef347208566b13fc2f62f0cd24e4043730faa..bcf7f39fdcfec3dcd9413d15aae9bbb8581852e5 100644
--- a/download_model.py
+++ b/download_model.py
@@ -1,5 +1,3 @@
 import tiktoken
-from paddleocr import PaddleOCR
-ocr = PaddleOCR(use_angle_cls=True, lang="ch")  # 使用中文语言模型
 enc = tiktoken.encoding_for_model("gpt-4")
 print(len(enc.encode('hello world')))
\ No newline at end of file
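
As a reading aid for the parser and worker hunks above: a layout gap now sets a flag on the preceding node whenever the next bounding box starts more than one unit below it (newline) or to its right (space). Below is a minimal standalone sketch of that flagging heuristic; Bbox, Node and flag_gaps are simplified stand-ins for illustration only, not the project's ParseNodeWithBbox/ParseNode classes.

from dataclasses import dataclass


@dataclass
class Bbox:
    x0: float
    y0: float
    x1: float
    y1: float


@dataclass
class Node:
    content: str
    is_need_newline: bool = False  # set when the next node starts on a lower line
    is_need_space: bool = False    # set when the next node starts further to the right


def flag_gaps(nodes: list[Node], bboxes: list[Bbox]) -> None:
    """Mirror the diff's heuristic: the flag lands on node i-1, not node i."""
    for i in range(1, len(nodes)):
        if bboxes[i].y0 > bboxes[i - 1].y1 + 1:   # vertical gap -> newline flag
            nodes[i - 1].is_need_newline = True
    for i in range(1, len(nodes)):
        if bboxes[i].x0 > bboxes[i - 1].x1 + 1:   # horizontal gap -> space flag
            nodes[i - 1].is_need_space = True


nodes = [Node("Hello"), Node("world"), Node("Next line")]
bboxes = [Bbox(0, 0, 50, 10), Bbox(60, 0, 110, 10), Bbox(0, 20, 80, 30)]
flag_gaps(nodes, bboxes)
print([(n.content, n.is_need_newline, n.is_need_space) for n in nodes])
# [('Hello', False, True), ('world', True, False), ('Next line', False, False)]

parse_document_worker.py then consumes these flags when concatenating text nodes: '\n' is inserted when is_need_newline is set, and a single space when only is_need_space is set.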
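
For the chunk_manager.py change: instead of stopping once fetch_cnt exceeds max(top_k, 8192), the loop now doubles the candidate-pool size on each empty round but saturates it at max(8192, top_k) + 1, and only exits once rows come back; a query that never matches appears to rely on the callers' asyncio.wait_for (now 3 s in the searchers) to bound it. A rough sketch of the resulting fetch-size schedule; fetch_sizes is an illustrative helper, not a function in the repo.

def fetch_sizes(top_k: int, rounds: int = 6) -> list[int]:
    """Candidate-pool sizes the reworked retrieval loop would try, round by round."""
    cap = max(8192, top_k) + 1               # saturation point from the diff
    fetch_cnt = top_k                        # first attempt fetches exactly top_k rows
    sizes = []
    for _ in range(rounds):
        sizes.append(fetch_cnt)
        fetch_cnt = min(fetch_cnt * 2, cap)  # double, then clamp to the cap
    return sizes


print(fetch_sizes(5))     # [5, 10, 20, 40, 80, 160]
print(fetch_sizes(3000))  # [3000, 6000, 8193, 8193, 8193, 8193]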