diff --git a/data_chain/common/prompt.yaml b/data_chain/common/prompt.yaml
index 43e5ea321a563cfed7cf7c274d019d35abee8203..691b2b4b7d0352464796dafe39024442ec3c3862 100644
--- a/data_chain/common/prompt.yaml
+++ b/data_chain/common/prompt.yaml
@@ -33,6 +33,10 @@ OCR_ENHANCED_PROMPT: '你是一个图片ocr内容总结专家,你的任务是
 
   #9 不要输出坐标等信息,输出每个部分相对位置的描述即可
 
+  #10 如果图片内容为空,请输出“图片内容为空”
+
+  #11 如果图片本身就是一段文字,请直接输出文字内容
+
   上下文:{image_related_text}
 
   当前图片上一部分的ocr内容总结:{pre_part_description}
diff --git a/data_chain/config/config.py b/data_chain/config/config.py
index 742e87d236d4652ca82c844954accd0e17e19c71..4d9678d64259384ccbabace6d8d700fabf555481 100644
--- a/data_chain/config/config.py
+++ b/data_chain/config/config.py
@@ -64,9 +64,9 @@ class ConfigModel(DictBaseModel):
     HALF_KEY2: str = Field(None, description="两层密钥管理组件2")
     HALF_KEY3: str = Field(None, description="两层密钥管理组件3")
     # Prompt file
-    PROMPT_PATH: str = Field(None, description="prompt路径")
+    PROMPT_PATH: str = Field(default="./data_chain/common/prompt.yaml", description="prompt路径")
     # Stop Words PATH
-    STOP_WORDS_PATH: str = Field(None, description="停用词表存放位置")
+    STOP_WORDS_PATH: str = Field(default="./data_chain/common/stopwords.txt", description="停用词表存放位置")
     # CPU Limit
     USE_CPU_LIMIT: int = Field(default=64, description="文档解析器使用CPU核数")
     # Task Retry Time limit
diff --git a/data_chain/parser/handler/deep_pdf_parser.py b/data_chain/parser/handler/deep_pdf_parser.py
index c66fc4070c26f4be6341393df209093a8ac5e6a1..bf4db0867147f29cc1145acc4f0aab66db1ef370 100644
--- a/data_chain/parser/handler/deep_pdf_parser.py
+++ b/data_chain/parser/handler/deep_pdf_parser.py
@@ -150,103 +150,114 @@ class DeepPdfParser(BaseParser):
     @staticmethod
     async def detect_table(image_path: str) -> list[Bbox]:
         """
-        检测图像中的表格,返回表格区域及其内容
+        检测图像中的表格区域,返回列表[Bbox]
         """
         image = cv2.imread(image_path)
-        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+        # 用原图直接提取绿色或蓝色通道,表格线在这些通道中通常更清晰
+        green = image[:, :, 1]  # G通道
+        blue = image[:, :, 0]  # B通道
 
-        # 使用改进的表格检测算法
-        # 二值化
-        binary = cv2.adaptiveThreshold(
-            ~gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 15, -10
-        )
+        # 选一个看起来表格线更清晰的通道来处理
+        channel = green  # 或 blue
+
+        # 然后再二值化
+        binary = cv2.adaptiveThreshold(channel, 255,
+                                       cv2.ADAPTIVE_THRESH_MEAN_C,
+                                       cv2.THRESH_BINARY_INV, 15, 10)
 
-        # 水平和垂直线检测
+        # 提取水平和垂直线
         horizontal = binary.copy()
         vertical = binary.copy()
 
-        # 定义水平线和垂直线的结构元素
-        cols = horizontal.shape[1]
-        horizontal_size = cols // 30
-        horizontalStructure = cv2.getStructuringElement(cv2.MORPH_RECT, (horizontal_size, 1))
-        horizontal = cv2.erode(horizontal, horizontalStructure)
-        horizontal = cv2.dilate(horizontal, horizontalStructure)
+        scale = 30  # 控制提取结构元素大小,值越大越敏感(保留更短的线条)
+        h_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (horizontal.shape[1] // scale, 1))
+        v_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, vertical.shape[0] // scale))
 
-        rows = vertical.shape[0]
-        vertical_size = rows // 30
-        verticalStructure = cv2.getStructuringElement(cv2.MORPH_RECT, (1, vertical_size))
-        vertical = cv2.erode(vertical, verticalStructure)
-        vertical = cv2.dilate(vertical, verticalStructure)
+        horizontal = cv2.erode(horizontal, h_kernel)
+        horizontal = cv2.dilate(horizontal, h_kernel)
 
-        # 合并水平和垂直线
-        mask = horizontal + vertical
+        vertical = cv2.erode(vertical, v_kernel)
+        vertical = cv2.dilate(vertical, v_kernel)
 
-        # 检测轮廓
+        # 合并线条掩码
+        mask = cv2.add(horizontal, vertical)
+
+        # 轮廓检测
         contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
 
         table_bboxes = []
         for contour in contours:
             x, y, w, h = cv2.boundingRect(contour)
-            # 基本过滤:过滤小区域
-            if w < 100 or h < 100:
+            # 基础过滤(小块排除)
+            if w < 80 or h < 80:
                 continue
 
-            # 计算轮廓面积与边界框面积的比率
-            # area = cv2.contourArea(contour)
-            # rect_area = w * h
-            # area_ratio = area / rect_area
-
-            # # 表格通常具有较高的面积比率和适当的宽高比
-            # if area_ratio < 0.2:
-            #     continue
-
-            # 提取候选区域
-            region = mask[y:y+h, x:x+w]
-
-            # 计算网格密度 - 表格应该有明显的网格结构
-            grid_density = np.count_nonzero(region) / (w * h)
-
-            # 表格通常具有较高的网格密度
-            # if grid_density < 0.05:
-            #     continue
-
-            # 计算轮廓的复杂度
+            # 复杂度过滤
             epsilon = 0.02 * cv2.arcLength(contour, True)
             approx = cv2.approxPolyDP(contour, epsilon, True)
-            complexity = len(approx)
+            if len(approx) < 4 or len(approx) > 20:
+                continue
 
-            # 表格轮廓通常较简单,而非表格图形可能更复杂
-            if complexity > 20 or complexity < 4:
+            # 网格密度简单过滤
+            region = mask[y:y+h, x:x+w]
+            density = np.count_nonzero(region) / (w * h)
+            if density < 0.02:
                 continue
 
             table_bboxes.append(Bbox(
-                x0=float(x),
-                y0=float(y),
-                x1=float(x + w),
-                y1=float(y + h)
+                x0=float(x), y0=float(y), x1=float(x + w), y1=float(y + h)
             ))
 
-        # 合并相近的表格区域
-        table_bboxes = sorted(table_bboxes, key=lambda bbox: (bbox.y0, bbox.x0))
+
+        # 按坐标排序
+        table_bboxes = sorted(table_bboxes, key=lambda b: (b.y0, b.x0))
+
+        # 合并相邻或重叠的表格区域
         merged_bboxes = []
         for bbox in table_bboxes:
             if not merged_bboxes:
                 merged_bboxes.append(bbox)
                 continue
-            last_bbox = merged_bboxes[-1]
-            # 检查当前bbox是否与上一个bbox相邻或重叠
-            if (abs(bbox.x0 - last_bbox.x1) < 10 and abs(bbox.y0 - last_bbox.y0) < 10) or \
-               (abs(bbox.y0 - last_bbox.y1) < 10 and abs(bbox.x0 - last_bbox.x0) < 10):
-                # 合并两个bbox
-                merged_bboxes[-1] = Bbox(
-                    x0=min(last_bbox.x0, bbox.x0),
-                    y0=min(last_bbox.y0, bbox.y0),
-                    x1=max(last_bbox.x1, bbox.x1),
-                    y1=max(last_bbox.y1, bbox.y1)
+            last = merged_bboxes[-1]
+            overlap_x = min(last.x1, bbox.x1) - max(last.x0, bbox.x0)
+            overlap_y = min(last.y1, bbox.y1) - max(last.y0, bbox.y0)
+
+            if overlap_x > -20 and overlap_y > -20:
+                merged = Bbox(
+                    x0=min(last.x0, bbox.x0),
+                    y0=min(last.y0, bbox.y0),
+                    x1=max(last.x1, bbox.x1),
+                    y1=max(last.y1, bbox.y1)
                 )
+                merged_bboxes[-1] = merged
             else:
                 merged_bboxes.append(bbox)
+
+        # 可视化调试:保存检测图
+        # for b in merged_bboxes:
+        #     cv2.rectangle(image, (int(b.x0), int(b.y0)), (int(b.x1), int(b.y1)), (0, 255, 0), 2)
+        # cv2.imwrite("debug_detected_tables.png", image)
+        # 自适应扩展每个表格区域的边界框
+        for bbox in merged_bboxes:
+            width = bbox.x1 - bbox.x0
+            height = bbox.y1 - bbox.y0
+            # 计算长宽比例,按比例扩展
+            if width > height:
+                bbox.x0 = max(0, bbox.x0 - 30)
+                bbox.y0 = max(0, bbox.y0 - 20)
+                bbox.x1 += 30
+                bbox.y1 += 20
+            elif width < height:
+                bbox.x0 = max(0, bbox.x0 - 20)
+                bbox.y0 = max(0, bbox.y0 - 30)
+                bbox.x1 += 20
+                bbox.y1 += 30
+            else:
+                bbox.x0 = max(0, bbox.x0 - 20)
+                bbox.y0 = max(0, bbox.y0 - 20)
+                bbox.x1 += 20
+                bbox.y1 += 20
         return merged_bboxes
 
     @staticmethod
@@ -332,6 +342,10 @@ class DeepPdfParser(BaseParser):
             for y in all_y_coords:
                 if not merged_y_coords or y - merged_y_coords[-1] > y_threshold:
                     merged_y_coords.append(y)
+            if len(merged_x_coords) < 2 or len(merged_y_coords) < 2:
+                continue  # 不是标准网格结构
+            if len(cells) < 4:
+                continue  # OCR单元格太少,跳过
 
             def get_id(num, coords):
                 """获取坐标在合并后的列表中的索引"""
@@ -393,6 +407,9 @@ class DeepPdfParser(BaseParser):
                 final_table.append(final_row)
             if not final_table:
                 continue
+            # 计算信息熵,信息比较少的表格可能不是表格
+            entropy = 0
+
             for row in final_table:
                 node = ParseNode(
                     id=uuid.uuid4(),
@@ -600,21 +617,24 @@ class DeepPdfParser(BaseParser):
                 text_nodes_with_bbox, table_nodes_with_bbox)
             sub_nodes_with_bbox = await DeepPdfParser.merge_nodes_with_bbox(
                 sub_nodes_with_bbox, image_nodes_with_bbox)
-
+            sub_nodes_with_bbox = sorted(sub_nodes_with_bbox, key=lambda x: (x.bbox.y0, x.bbox.x0))
+            if sub_nodes_with_bbox:
+                sub_nodes_with_bbox[-1].node.is_need_space = True  # 最后一个节点后面需要空格
             nodes_with_bbox.extend(sub_nodes_with_bbox)
             page_number += 1
         for i in range(1, len(nodes_with_bbox)):
             '''根据bbox判断是否要进行换行'''
             vertical_distance = nodes_with_bbox[i].bbox.y0 - nodes_with_bbox[i-1].bbox.y1
             height = nodes_with_bbox[i].bbox.y1 - nodes_with_bbox[i].bbox.y0
-            if vertical_distance > 0 and vertical_distance > height*0.3:
+            if vertical_distance > 0 and (vertical_distance > height*0.3 or vertical_distance > 2):
                 nodes_with_bbox[i-1].node.is_need_newline = True
         for i in range(1, len(nodes_with_bbox)):
             '''根据bbox判断是否要进行空格'''
             horizontal_distance = nodes_with_bbox[i].bbox.x0 - nodes_with_bbox[i-1].bbox.x1
             width = nodes_with_bbox[i].bbox.x1 - nodes_with_bbox[i].bbox.x0
-            if horizontal_distance > 0 and horizontal_distance > width*0.5:
+            if horizontal_distance > 0 and (horizontal_distance > width*0.3 or horizontal_distance > 2):
                 nodes_with_bbox[i-1].node.is_need_space = True
+
         nodes = [node_with_bbox.node for node_with_bbox in nodes_with_bbox]
         DeepPdfParser.image_related_node_in_link_nodes(nodes)  # 假设这个方法在别处定义
         parse_result = ParseResult(
diff --git a/data_chain/parser/handler/pdf_parser.py b/data_chain/parser/handler/pdf_parser.py
index bce9a9284ad9c58901837c357ca4c7f454d2d986..ecab3f565afc9cea01cd5cd7ae79a4435e77634b 100644
--- a/data_chain/parser/handler/pdf_parser.py
+++ b/data_chain/parser/handler/pdf_parser.py
@@ -303,13 +303,13 @@ class PdfParser(BaseParser):
             '''根据bbox判断是否要进行换行'''
             vertical_distance = nodes_with_bbox[i].bbox.y0 - nodes_with_bbox[i-1].bbox.y1
             height = nodes_with_bbox[i].bbox.y1 - nodes_with_bbox[i].bbox.y0
-            if vertical_distance > 0 and vertical_distance > height*0.3:
+            if vertical_distance > 0 and (vertical_distance > height*0.3 or vertical_distance > 2):
                 nodes_with_bbox[i-1].node.is_need_newline = True
         for i in range(1, len(nodes_with_bbox)):
             '''根据bbox判断是否要进行空格'''
             horizontal_distance = nodes_with_bbox[i].bbox.x0 - nodes_with_bbox[i-1].bbox.x1
             width = nodes_with_bbox[i].bbox.x1 - nodes_with_bbox[i].bbox.x0
-            if horizontal_distance > 0 and horizontal_distance > width*0.3:
+            if horizontal_distance > 0 and (horizontal_distance > width*0.3 or horizontal_distance > 2):
                 nodes_with_bbox[i-1].node.is_need_space = True
         nodes = [node_with_bbox.node for node_with_bbox in nodes_with_bbox]
         PdfParser.image_related_node_in_link_nodes(nodes)  # 假设这个方法在别处定义
diff --git a/data_chain/parser/tools/ocr_tool.py b/data_chain/parser/tools/ocr_tool.py
index da2b44935e2951877b35502e6389e083ea23435e..858517dab74091ccc2f6d9badcf86049021e17cb 100644
--- a/data_chain/parser/tools/ocr_tool.py
+++ b/data_chain/parser/tools/ocr_tool.py
@@ -109,6 +109,8 @@ class OcrTool:
                 text = await OcrTool.merge_text_from_ocr_result(ocr_result)
             else:
                 text = await OcrTool.enhance_ocr_result(ocr_result, image_related_text, llm)
+            if "图片内容为空" in text:
+                return ""
             return text
         except Exception as e:
             err = f"[OCRTool] 图片转文本失败 {e}"
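
The sketch below is not part of the patch; it is a minimal standalone reproduction of the morphology pipeline that the reworked detect_table relies on (channel selection, adaptive threshold, directional erode/dilate, contour filtering), so the thresholds can be tried on a single page image in isolation. It uses only OpenCV 4.x and NumPy; the helper name sketch_detect_table_boxes, the file name page.png, and the default scale are illustrative assumptions, not identifiers from the repository.

# Standalone sketch of the line-based table detection approach used in the patch.
# Assumes OpenCV 4.x; "page.png" and scale=30 are illustrative values only.
import cv2
import numpy as np


def sketch_detect_table_boxes(image_path: str, scale: int = 30) -> list[tuple[int, int, int, int]]:
    image = cv2.imread(image_path)
    if image is None:
        return []
    channel = image[:, :, 1]  # green channel, as in the patch
    binary = cv2.adaptiveThreshold(channel, 255,
                                   cv2.ADAPTIVE_THRESH_MEAN_C,
                                   cv2.THRESH_BINARY_INV, 15, 10)
    # keep only long horizontal / vertical strokes via directional kernels
    h_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (max(1, binary.shape[1] // scale), 1))
    v_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, max(1, binary.shape[0] // scale)))
    horizontal = cv2.dilate(cv2.erode(binary, h_kernel), h_kernel)
    vertical = cv2.dilate(cv2.erode(binary, v_kernel), v_kernel)
    mask = cv2.add(horizontal, vertical)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    boxes = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        if w < 80 or h < 80:  # drop small blobs
            continue
        density = np.count_nonzero(mask[y:y + h, x:x + w]) / (w * h)
        if density < 0.02:  # require some grid structure inside the box
            continue
        boxes.append((x, y, x + w, y + h))
    return boxes


if __name__ == "__main__":
    print(sketch_detect_table_boxes("page.png"))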