diff --git a/data_chain/common/prompt.yaml b/data_chain/common/prompt.yaml index 62a88a6a51ba72eb248d813373742a6bddb45e17..a73bea326ebbac4b4ad8ea76b0497f7aa3cef871 100644 --- a/data_chain/common/prompt.yaml +++ b/data_chain/common/prompt.yaml @@ -18,11 +18,9 @@ OCR_ENHANCED_PROMPT: '你是一个图片ocr内容总结专家,你的任务是 #8 请仅输出图片的总结即可,不要输出其他内容 -上下文:{front_text} +上下文:{image_related_text} -先前图片组描述:{front_image_description} - -当前图片上一次的ocr内容总结:{front_part_description} +当前图片上一部分的ocr内容总结:{pre_part_description} 当前图片部分ocr的结果:{part}' diff --git a/data_chain/parser/handler/docx_parser.py b/data_chain/parser/handler/docx_parser.py index efff7ed8cf48a91413346279412c1e8df9ac2456..120b9931e7f3b1cdc69dc581d3c0ddcf654a9f71 100644 --- a/data_chain/parser/handler/docx_parser.py +++ b/data_chain/parser/handler/docx_parser.py @@ -17,18 +17,17 @@ from data_chain.parser.handler.base_parser import BaseService from data_chain.parser.tools.ocr import BaseOCR - class DocxService(BaseService): def __init__(self): super().__init__() - self.image_model = None + self.ocr_tool = None def open_file(self, file_path): try: doc = docx.Document(file_path) return doc except Exception as e: - logging.error(f"Error opening file {file_path} :{e}") + logging.error(f"Opening docx file {file_path} failed due to:{e}") raise e def is_image(self, graph: Paragraph, doc: Document): @@ -76,14 +75,15 @@ class DocxService(BaseService): if text_part: lines.append((text_part, 'para')) text_part = '' - # 处理图片 for image_part in image_parts: - image_blob = image_part.image.blob - content_type = image_part.content_type + try: + image_blob = image_part.image.blob + content_type = image_part.content_type + except Exception as e: + logging.error(f"Image blob and part get failed due to :{e}") extension = mimetypes.guess_extension(content_type).replace('.', '') - lines.append(([(Image.open(BytesIO(image_blob)), extension)], 'image')) + lines.append(((Image.open(BytesIO(image_blob)), extension), 'image')) else: - # 处理文字 text_part += run.text run_index += 1 @@ -93,7 +93,7 @@ class DocxService(BaseService): lines.append((paragraph.text, 'para')) elif isinstance(child, CT_Tbl): table = Table(child, parent) - rows=self.split_table(table) + rows = self.split_table(table) for row in rows: lines.append((row, 'table')) elif isinstance(child, CT_Picture): @@ -103,23 +103,39 @@ class DocxService(BaseService): try: image_blob = part.image.blob content_type = part.content_type - except: - logging.error(f'Error get image blob and content type due to: {img_id}') + except Exception as e: + logging.error(f'Get image blob and content type failed due to: {e}') continue extension = mimetypes.guess_extension(content_type).replace('.', '') - lines.append(([(Image.open(BytesIO(image_blob)), extension)], 'image')) - new_lines = [] + lines.append(((Image.open(BytesIO(image_blob)), extension), 'image')) + return lines + + async def ocr_from_images_in_lines(self, lines): + # 获取图像相邻文本 + last_para_pre = "" for i in range(len(lines)): - if lines[i][1] == 'image': - if len(new_lines) > 0 and new_lines[-1][1] == 'image': - new_lines[-1][0].append(lines[i][0][0]) - else: - new_lines.append(lines[i]) - else: - new_lines.append(lines[i]) - return new_lines + line = lines[i] + if line['type'] == 'image': + lines[i]['related_text'] = last_para_pre + elif line['type'] == 'para': + last_para_pre = line['text'] + elif line['type'] == 'table': + pass + last_para_bac = "" + for i in range(len(lines) - 1, -1, -1): + line = lines[i] + if line['type'] == 'image': + lines[i]['related_text'] += last_para_bac + elif line['type'] == 'para': + last_para_bac = line['text'] + elif line['type'] == 'table': + pass + for line in lines: + if line['type'] == 'image': + line['text'] = await self.ocr_tool.image_to_text(line['image'], text=line['related_text']) + return lines - async def solve_lines(self, lines, method): + async def change_lines(self, lines): """ 修整处理lines,根据不同的类型(图像、段落、表格)处理每一行,并根据method参数决定处理方式。 @@ -130,74 +146,53 @@ class DocxService(BaseService): 返回: - tuple: 包含处理后的句子列表和图像列表的元组。 """ - sentences = [] + new_lines = [] images = [] - last_para = "" last_para_id = None for line in lines: if line[1] == 'image': # 处理图像 - for image_tuple in line[0]: - image_id = self.get_uuid() - image = image_tuple[0] - image_bytes = image.tobytes() - image_extension = image_tuple[1] - await self.insert_image_to_tmp_folder(image_bytes, image_id, image_extension) - if method in ['ocr', 'enhanced']: - # 将图片关联到图片的描述chunk上 - chunk_id = self.get_uuid() - sentences.append({'id': chunk_id, - 'type': 'image'}) - sentences[-1]['near_text'] = last_para - sentences[-1]['image'] = np.array(image) - images.append({ - 'id': image_id, - 'chunk_id': chunk_id, - 'extension': image_extension, - }) - else: - # 将图片关联到上一个段落chunk上 - images.append({ - 'id': image_id, - 'chunk_id': last_para_id, - 'extension': image_extension, - }) + image_tuple = line[0] + image_id = self.get_uuid() + image = image_tuple[0] + image_bytes = image.tobytes() + image_extension = image_tuple[1] + await self.insert_image_to_tmp_folder(image_bytes, image_id, image_extension) + if self.parser_method in ['ocr', 'enhanced']: + # 将图片关联到图片的描述chunk上 + chunk_id = self.get_uuid() + new_lines.append({'id': chunk_id, + 'type': 'image'}) + new_lines[-1]['image'] = np.array(image) + images.append({ + 'id': image_id, + 'chunk_id': chunk_id, + 'extension': image_extension, + }) + else: + # 将图片关联到上一个段落chunk上 + images.append({ + 'id': image_id, + 'chunk_id': last_para_id, + 'extension': image_extension, + }) elif line[1] == 'para': # 处理段落 - sentences.append({'id': self.get_uuid(), + new_lines.append({'id': self.get_uuid(), 'text': line[0], 'type': line[1]}) - last_para = line[0] - last_para_id = sentences[-1]['id'] + last_para_id = new_lines[-1]['id'] elif line[1] == 'table': # 处理表格 - sentences.append({'id': self.get_uuid(), + new_lines.append({'id': self.get_uuid(), 'text': line[0], 'type': line[1]}) - if method in ['ocr', 'enhanced']: - sentences = await self.get_near_text(sentences) - return sentences, images - - async def get_near_text(self, sentences): - # 获取图像相邻文本 - last_para = "" - len_sentences = len(sentences) - for i in range(len_sentences - 1, -1, -1): - sentence = sentences[i] - if sentence['type'] == 'image': - sentences[i]['near_text'] = sentences[i]['near_text'] + last_para - elif sentence['type'] == 'para': - last_para = sentence['text'] - elif sentence['type'] == 'table': - pass - for sentence in sentences: - if sentence['type'] == 'image': - # 通过ocr/llm-Enhance进行强化 - sentence['text'] = await self.image_model.run(sentence['image'], text=sentence['near_text']) - return sentences + if self.parser_method in ['ocr', 'enhanced']: + new_lines = await self.ocr_from_images_in_lines(new_lines) + return new_lines, images async def parser(self, file_path): """ @@ -213,13 +208,12 @@ class DocxService(BaseService): doc = self.open_file(file_path) if not doc: return None - method = self.parser_method - if method != "general": - self.image_model = BaseOCR(llm=self.llm, llm_max_tokens=self.llm_max_tokens, method=method) + if self.parser_method != "general": + self.ocr_tool = BaseOCR(llm=self.llm, method=self.parser_method) lines = self.get_lines(doc) - sentences, images = await self.solve_lines(lines, method) + lines, images = await self.change_lines(lines) - chunks = self.build_chunks_by_lines(sentences) + chunks = self.build_chunks_by_lines(lines) chunk_links = self.build_chunk_links_by_line(chunks) return chunks, chunk_links, images diff --git a/data_chain/parser/handler/pdf_parser.py b/data_chain/parser/handler/pdf_parser.py index a08f059c666fee732550945d56d8646dfb56647a..4d9bd98673be6222d1a87b0f09e0cfe67d329817 100644 --- a/data_chain/parser/handler/pdf_parser.py +++ b/data_chain/parser/handler/pdf_parser.py @@ -8,8 +8,6 @@ from data_chain.parser.tools.ocr import BaseOCR from data_chain.parser.handler.base_parser import BaseService - - class PdfService(BaseService): def __init__(self): @@ -40,8 +38,7 @@ class PdfService(BaseService): 'text': text, 'type': 'para', }) - sorted_lines = sorted(lines, key=lambda x: (x['bbox'][1], x['bbox'][0])) - return sorted_lines + return lines def extract_table(self, page_number): """ @@ -90,13 +87,13 @@ class PdfService(BaseService): image = Image.open(io.BytesIO(image_bytes)) image_id = self.get_uuid() - await self.insert_image_to_tmp_folder(image_bytes, image_id,image_ext) + await self.insert_image_to_tmp_folder(image_bytes, image_id, image_ext) try: img_np = np.array(image) except Exception as e: logging.error(f"Error converting image to numpy array: {e}") continue - ocr_results = await self.image_model.run(img_np, text=near) + ocr_results = await self.image_model.image_to_text(img_np, text=near) # 获取OCR chunk_id = self.get_uuid() @@ -132,7 +129,6 @@ class PdfService(BaseService): def find_near_words(self, bbox, texts): """寻找相邻文本""" - nearby_text = [] image_x0, image_y0, image_x1, image_y1 = bbox threshold = 100 image_x0 -= threshold @@ -193,7 +189,7 @@ class PdfService(BaseService): sentences = [] all_image_chunks = [] if method != "general": - self.image_model = BaseOCR(llm=self.llm, llm_max_tokens=self.llm_max_tokens, + self.image_model = BaseOCR(llm=self.llm, method=self.parser_method) for page_num in range(self.page_numbers): tables = self.extract_table(page_num) @@ -206,7 +202,7 @@ class PdfService(BaseService): else: merge_list = temp_list sentences.extend(merge_list) - + sentences = sorted(sentences, key=lambda x: (x['bbox'][1], x['bbox'][0])) chunks = self.build_chunks_by_lines(sentences) chunk_links = self.build_chunk_links_by_line(chunks) return chunks, chunk_links, all_image_chunks diff --git a/data_chain/parser/tools/ocr.py b/data_chain/parser/tools/ocr.py index f553008027f5bf6d02eae3b3d5bc4e10faf121ff..baabfe584af1b5532b5363ac842e12d743e3a1f5 100644 --- a/data_chain/parser/tools/ocr.py +++ b/data_chain/parser/tools/ocr.py @@ -1,13 +1,14 @@ import yaml -from data_chain.logger.logger import logger as logging from paddleocr import PaddleOCR + +from data_chain.logger.logger import logger as logging from data_chain.config.config import config from data_chain.parser.tools.split import split_tools class BaseOCR: - def __init__(self, llm=None, llm_max_tokens=None, method='general'): + def __init__(self, llm=None, method='general'): # 指定模型文件的路径 det_model_dir = 'data_chain/parser/model/ocr/ch_PP-OCRv4_det_infer' rec_model_dir = 'data_chain/parser/model/ocr/ch_PP-OCRv4_rec_infer' @@ -25,148 +26,128 @@ class BaseOCR: if llm is None and method == 'enhanced': method = 'ocr' else: - self.max_tokens = llm_max_tokens + self.max_tokens = 1024 self.method = method - def ocr(self, image): + def ocr_from_image(self, image): """ - ocr识别文字 + 图片ocr接口 参数: - image:图像文件 - **kwargs:可选参数,如语言、gpu - 返回: - 一个list,包含了所有识别出的文字以及对应坐标 + image图片 """ try: - # get config - results = self.model.ocr(image) - return results + ocr_result = self.model.ocr(image) + if ocr_result is None or ocr_result[0] is None: + return None + return ocr_result except Exception as e: - logging.error(f"OCR job error {e}") - return [[None]] + logging.error(f"Ocr from image failed due to: {e}") + return None - @staticmethod - def get_text_from_ocr_results(ocr_results): - results = '' - if ocr_results[0] is None or len(ocr_results)==0: - return '' - if ocr_results[0][0] is None or len(ocr_results[0][0])==0: - return '' + def merge_text_from_ocr_result(ocr_result): + """ + ocr结果文字内容合并接口 + 参数: + ocr_result:ocr识别结果,包含了文字坐标、内容、置信度 + """ + text = '' try: - for result in ocr_results[0][0]: - results += result[1][0] - return results + for _ in ocr_result[0][0]: + text += _[1][0] + return text except Exception as e: - logging.error(f'Get text from ocr result failed with {e}') + logging.error(f'Get text from ocr result failed due to: {e}') return '' - @staticmethod - def split_list(image_result, max_tokens): + def cut_ocr_result_in_part(ocr_result, max_tokens=1024): """ - 分句,不超过Tokens数量 + ocr结果切割接口 + 参数: + ocr_result:ocr识别结果,包含了文字坐标、内容、置信度 + max_tokens:最大token数 """ - sum_tokens = 0 - result = [] - temp = [] - for sentences in image_result[0]: - if sentences is not None and len(sentences) > 0: - tokens = split_tools.get_tokens(sentences) - if sum_tokens + tokens > max_tokens: - result.append(temp) - temp = [sentences] - sum_tokens = tokens + tokens = 0 + ocr_result_part = None + ocr_result_parts = [] + for _ in ocr_result[0]: + if _ is not None and len(_) > 0: + sub_tokens = split_tools.get_tokens(str(_)) + if tokens + sub_tokens > max_tokens: + ocr_result_parts.append(ocr_result_part) + ocr_result_part = [_] + tokens += sub_tokens else: - temp.append(sentences) - sum_tokens += tokens - if temp: - result.append(temp) - return result - - @staticmethod - def get_prompt_dict(): - try: - with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: - prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) - return prompt_dict - except Exception as e: - logging.error(f'Get prompt failed : {e}') - raise e + ocr_result_part.append(_) + tokens += sub_tokens + if len(ocr_result_part) > 0: + ocr_result_parts.append(ocr_result_part) + return ocr_result_parts - async def improve(self, image_results, text): + async def enhance_ocr_result(self, ocr_result, image_related_text): """ - llm强化接口 + ocr结果强化接口 参数: - - image_results:ocr识别结果,包含了文字坐标、内容、置信度 - - text:图片组对应的前后文 + ocr_result:ocr识别结果,包含了文字坐标、内容、置信度 + image_related_text:图片组对应的前后文 """ try: + try: + with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: + prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) + prompt_template = prompt_dict.get('OCR_ENHANCED_PROMPT', '') + except Exception as e: + logging.error(f'Get prompt template failed due to :{e}') + return '' + pre_part_description = "" + ocr_result_parts = self.cut_ocr_result_in_part(ocr_result, self.max_tokens // 5*2) user_call = '请详细输出图片的摘要,不要输出其他内容' - split_images = [] - max_tokens = self.max_tokens // 5*2 - for image in image_results: - split_result = self.split_list(image, max_tokens) - split_images.append(split_result) - front_text = text - front_image_description = "" - front_part_description = "" - prompt_dict = self.get_prompt_dict() - for image in split_images: - for part in image: - prompt = prompt_dict.get('OCR_ENHANCED_PROMPT', '') - try: - prompt = prompt.format( - front_text=front_text, - front_image_description=front_image_description, - front_part_description=front_part_description, - part=part) - front_part_description = await self.llm.nostream([], prompt, user_call) - except Exception as e: - raise e - front_image_description = front_part_description - answer = front_image_description - return answer + for part in ocr_result_parts: + pre_part_description_cp = pre_part_description + try: + prompt = prompt_template.format( + image_related_text=image_related_text, + pre_part_description=pre_part_description, + part=part) + pre_part_description = await self.llm.nostream([], prompt, user_call) + except Exception as e: + logging.error(f'OCR resutl part enhance failed due to: {e}') + pre_part_description = pre_part_description_cp + return pre_part_description except Exception as e: - logging.error(f'OCR result improve error due to: {e}') + logging.error(f'OCR result enhance failed due to: {e}') return "" - async def run(self, image, text): + async def get_text_from_image(self, ocr_result, image_related_text): """ - 执行ocr的接口 + 从image中提取文字的接口 输入: - image:图像文件 + ocr_result: ocr结果 + image_related_text: 图片相关文字 """ - method = self.method - if not isinstance(image, list): - image = [image] - - image_results = self.process_images(image) - results = await self.generate_results(method, image_results, text) - - return results - - def process_images(self, images): - image_results = [] - for every_image in images: - try: - ocr_result = self.ocr(every_image) - image_results.append(ocr_result) - except Exception as e: - # 记录异常信息,可以选择日志记录或其他方式 - logging.error(f"Error processing image: {e}") - return image_results - - async def generate_results(self, method, image_results, text): - if method == 'ocr': - results = self.get_text_from_ocr_results(image_results) - return f'{results}' - elif method == 'enhanced': + if self.method == 'ocr': + text = self.merge_text_from_ocr_result(ocr_result) + return text + elif self.method == 'enhanced': try: - results = await self.improve(image_results, text) - if len(results.strip()) == 0: - return self.get_text_from_ocr_results(image_results) - return results + text = await self.enhance_ocr_result(ocr_result, image_related_text) + if len(text) == 0: + text = self.merge_text_from_ocr_result(ocr_result) except Exception as e: logging.error(f"LLM ERROR with: {e}") - return self.get_text_from_ocr_results(image_results) + text = self.merge_text_from_ocr_result(ocr_result) + return text else: return "" + + async def image_to_text(self, image, image_related_text=''): + """ + 执行ocr的接口 + 输入: + image:图像文件 + image_related_text:图像相关的文本 + """ + ocr_result = self.ocr_from_image(image) + if ocr_result is None: + return "" + text = await self.get_text_from_image(ocr_result, image_related_text) + return text diff --git a/data_chain/parser/tools/split.py b/data_chain/parser/tools/split.py index 4f21b9e2353ba3669ba96d1ee84f6872d91d2665..2f6010da54b48af59d158a91e61d82fb5009b73f 100644 --- a/data_chain/parser/tools/split.py +++ b/data_chain/parser/tools/split.py @@ -1,12 +1,18 @@ +import tiktoken import jieba +from data_chain.logger.logger import logger as logging + class SplitTools: def get_tokens(self, content): - sum_tokens = len(self.split_words(content)) - return sum_tokens - - @staticmethod + try: + enc = tiktoken.encoding_for_model("gpt-4") + return len(enc.encode(content)) + except Exception as e: + logging.error(f"Get tokens failed due to: {e}") + return 0 + def split_words(text): return list(jieba.cut(str(text))) diff --git a/utils/parser/handler/docx_parser.py b/utils/parser/handler/docx_parser.py index 5a229cc3632887762d15bb8c7cfe047815a17ea5..b1be756a80a7a62d6239a5bf013e7b34cf9e885f 100644 --- a/utils/parser/handler/docx_parser.py +++ b/utils/parser/handler/docx_parser.py @@ -90,7 +90,7 @@ class DocxService(BaseService): lines.append((paragraph.text, 'para')) elif isinstance(child, CT_Tbl): table = Table(child, parent) - rows=self.split_table(table) + rows = self.split_table(table) for row in rows: lines.append((row, 'table')) elif isinstance(child, CT_Picture): diff --git a/utils/parser/handler/pdf_parser.py b/utils/parser/handler/pdf_parser.py index 749c8f30c3344d6200d364894524207695868fec..61d54cc3d7a859f11f2044ec081feffc2457618f 100644 --- a/utils/parser/handler/pdf_parser.py +++ b/utils/parser/handler/pdf_parser.py @@ -7,8 +7,6 @@ from utils.parser.tools.ocr import BaseOCR from utils.parser.handler.base_parser import BaseService - - class PdfService(BaseService): def __init__(self): @@ -89,7 +87,7 @@ class PdfService(BaseService): image = Image.open(io.BytesIO(image_bytes)) image_id = self.get_uuid() - await self.insert_image_to_tmp_folder(image_bytes, image_id,image_ext) + await self.insert_image_to_tmp_folder(image_bytes, image_id, image_ext) img_np = np.array(image) ocr_results = await self.image_model.run(img_np, text=near) @@ -189,8 +187,7 @@ class PdfService(BaseService): sentences = [] all_image_chunks = [] if method != "general": - self.image_model = BaseOCR(llm=self.llm, llm_max_tokens=self.llm_max_tokens, - method=self.parser_method) + self.image_model = BaseOCR(llm=self.llm, llm_max_tokens=self.llm_max_tokens, method=self.parser_method) for page_num in range(self.page_numbers): tables = self.extract_table(page_num) text = self.extract_text(page_num)