From 4af9ef70a123d408f1ee1a4008a85ccbda18f1f7 Mon Sep 17 00:00:00 2001
From: "Shine.Wang"
Date: Fri, 4 Jul 2025 14:45:33 +0800
Subject: [PATCH 01/12] Fix spelling mistakes

---
 .../base/task/worker/acc_testing_worker.py | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/data_chain/apps/base/task/worker/acc_testing_worker.py b/data_chain/apps/base/task/worker/acc_testing_worker.py
index 45ae9274..08516666 100644
--- a/data_chain/apps/base/task/worker/acc_testing_worker.py
+++ b/data_chain/apps/base/task/worker/acc_testing_worker.py
@@ -180,31 +180,31 @@ class TestingWorker(BaseWorker):
                 bac_info=bac_info
             )
             llm_answer = await llm.nostream([], prompt, question)
-            sub_socres = []
+            sub_scores = []
             pre = await TokenTool.cal_precision(question, answer, llm)
             if pre != -1:
-                sub_socres.append(pre)
+                sub_scores.append(pre)
             rec = await TokenTool.cal_recall(answer, llm_answer, llm)
             if rec != -1:
-                sub_socres.append(rec)
+                sub_scores.append(rec)
             fai = await TokenTool.cal_faithfulness(question, llm_answer, bac_info, llm)
             if fai != -1:
-                sub_socres.append(fai)
+                sub_scores.append(fai)
             rel = await TokenTool.cal_relevance(question, llm_answer, llm)
             if rel != -1:
-                sub_socres.append(rel)
+                sub_scores.append(rel)
             lcs = TokenTool.cal_lcs(answer, llm_answer)
             if lcs != -1:
-                sub_socres.append(lcs)
+                sub_scores.append(lcs)
             leve = TokenTool.cal_leve(answer, llm_answer)
             if leve != -1:
-                sub_socres.append(leve)
+                sub_scores.append(leve)
             jac = TokenTool.cal_jac(answer, llm_answer)
             if jac != -1:
-                sub_socres.append(jac)
+                sub_scores.append(jac)
             score = -1
-            if sub_socres:
-                score = sum(sub_socres) / len(sub_socres)
+            if sub_scores:
+                score = sum(sub_scores) / len(sub_scores)
             test_case_entity = TestCaseEntity(
                 testing_id=testing_entity.id,
                 question=question,
-- 
Gitee

From c2a86bbf9e5aec41cf1405d4a09c06153d93681c Mon Sep 17 00:00:00 2001
From: "Shine.Wang"
Date: Wed, 16 Jul 2025 12:02:39 +0800
Subject: [PATCH 02/12] Migrate the task queue from MongoDB to PostgreSQL
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../base/task/worker/acc_testing_worker.py    |  6 +-
 .../base/task/worker/export_dataset_worker.py |  6 +-
 .../worker/export_knowledge_base_worker.py    |  6 +-
 .../task/worker/generate_dataset_worker.py    |  6 +-
 .../base/task/worker/import_dataset_worker.py |  6 +-
 .../worker/import_knowledge_base_worker.py    |  6 +-
 .../base/task/worker/parse_document_worker.py |  6 +-
 data_chain/apps/service/task_queue_service.py | 39 +++++++------
 data_chain/manager/task_queue_mamanger.py     | 58 +++++++++++--------
 data_chain/stores/database/database.py        | 17 ++++++
 10 files changed, 93 insertions(+), 63 deletions(-)

diff --git a/data_chain/apps/base/task/worker/acc_testing_worker.py b/data_chain/apps/base/task/worker/acc_testing_worker.py
index 08516666..48f74c95 100644
--- a/data_chain/apps/base/task/worker/acc_testing_worker.py
+++ b/data_chain/apps/base/task/worker/acc_testing_worker.py
@@ -28,7 +28,7 @@ from data_chain.manager.testing_manager import TestingManager
 from data_chain.manager.testcase_manager import TestCaseManager
 from data_chain.manager.qa_manager import QAManager
 from data_chain.manager.task_queue_mamanger import TaskQueueManager
-from data_chain.stores.database.database import TaskEntity, QAEntity, DataSetEntity, DataSetDocEntity, TestingEntity, TestCaseEntity
+from data_chain.stores.database.database import TaskEntity, QAEntity, DataSetEntity, DataSetDocEntity, TestingEntity, TestCaseEntity, TaskQueueEntity
 from data_chain.stores.minio.minio import MinIO
 from data_chain.stores.mongodb.mongodb import Task
 from data_chain.config.config import config
@@ -431,11 +431,11 @@ class TestingWorker(BaseWorker):
             await TestingWorker.generate_report_and_upload_to_minio(dataset_entity, testing_entity, testcase_entities, tmp_path)
             current_stage += 1
             await TestingWorker.report(task_id, "生成报告并上传到minio", current_stage, stage_cnt)
-            await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value))
+            await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.SUCCESS.value))
         except Exception as e:
             err = f"[TestingWorker] 任务失败,task_id: {task_id}, 错误信息: {e}"
             logging.exception(err)
-            await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value))
+            await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.FAILED.value))
             await TestingWorker.report(task_id, "任务失败", 0, 1)

     @staticmethod
diff --git a/data_chain/apps/base/task/worker/export_dataset_worker.py b/data_chain/apps/base/task/worker/export_dataset_worker.py
index 8d9b0812..082326a8 100644
--- a/data_chain/apps/base/task/worker/export_dataset_worker.py
+++ b/data_chain/apps/base/task/worker/export_dataset_worker.py
@@ -23,7 +23,7 @@ from data_chain.manager.chunk_manager import ChunkManager
 from data_chain.manager.dataset_manager import DatasetManager
 from data_chain.manager.qa_manager import QAManager
 from data_chain.manager.task_queue_mamanger import TaskQueueManager
-from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, QAEntity, DataSetEntity, DataSetDocEntity
+from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, QAEntity, DataSetEntity, DataSetDocEntity, TaskQueueEntity
 from data_chain.stores.minio.minio import MinIO
 from data_chain.stores.mongodb.mongodb import Task

@@ -190,11 +190,11 @@ class ExportDataSetWorker(BaseWorker):
             await ExportDataSetWorker.upload_file_to_minio(task_id, zip_path)
             current_stage += 1
             await ExportDataSetWorker.report(task_id, "上传文件到minio", current_stage, stage_cnt)
-            await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value))
+            await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.SUCCESS.value))
         except Exception as e:
             err = f"[ExportDataSetWorker] 任务失败,task_id: {task_id}, 错误信息: {e}"
             logging.exception(err)
-            await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value))
+            await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.FAILED.value))
             await ExportDataSetWorker.report(task_id, "任务失败", 0, 1)

     @staticmethod
diff --git a/data_chain/apps/base/task/worker/export_knowledge_base_worker.py b/data_chain/apps/base/task/worker/export_knowledge_base_worker.py
index 5995debc..2e50c555 100644
--- a/data_chain/apps/base/task/worker/export_knowledge_base_worker.py
+++ b/data_chain/apps/base/task/worker/export_knowledge_base_worker.py
@@ -13,7 +13,7 @@ from data_chain.manager.task_manager import TaskManager
 from data_chain.manager.knowledge_manager import KnowledgeBaseManager
 from data_chain.manager.document_manager import DocumentManager
 from data_chain.manager.task_queue_mamanger import TaskQueueManager
-from data_chain.stores.database.database import TaskEntity, DocumentEntity
+from data_chain.stores.database.database import TaskEntity, DocumentEntity, TaskQueueEntity
 from data_chain.stores.minio.minio import MinIO
 from data_chain.stores.mongodb.mongodb import Task

@@ -197,11 +197,11 @@ class ExportKnowledgeBaseWorker(BaseWorker):
             await ExportKnowledgeBaseWorker.upload_zip_to_minio(zip_path, task_id)
             current_stage += 1
             await ExportKnowledgeBaseWorker.report(task_id, "上传压缩包到minio", current_stage, stage_cnt)
-            await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value))
+            await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.SUCCESS.value))
         except Exception as e:
             err = f"[ExportKnowledgeBaseWorker] 运行任务失败,task_id: {task_id},错误信息: {e}"
             logging.exception(err)
-            await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value))
+            await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.FAILED.value))
             await ExportKnowledgeBaseWorker.report(task_id, err, 0, 1)

     @staticmethod
diff --git a/data_chain/apps/base/task/worker/generate_dataset_worker.py b/data_chain/apps/base/task/worker/generate_dataset_worker.py
index eb018501..ade19b6e 100644
--- a/data_chain/apps/base/task/worker/generate_dataset_worker.py
+++ b/data_chain/apps/base/task/worker/generate_dataset_worker.py
@@ -20,7 +20,7 @@ from data_chain.manager.chunk_manager import ChunkManager
 from data_chain.manager.dataset_manager import DatasetManager
 from data_chain.manager.qa_manager import QAManager
 from data_chain.manager.task_queue_mamanger import TaskQueueManager
-from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, QAEntity, DataSetEntity, DataSetDocEntity
+from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, QAEntity, DataSetEntity, DataSetDocEntity, TaskQueueEntity
 from data_chain.stores.minio.minio import MinIO
 from data_chain.stores.mongodb.mongodb import Task

@@ -308,11 +308,11 @@ class GenerateDataSetWorker(BaseWorker):
             await GenerateDataSetWorker.add_qa_to_db(qa_entities)
             current_stage += 1
             await GenerateDataSetWorker.report(task_id, "添加QA到数据库", current_stage, stage_cnt)
-            await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value))
+            await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.SUCCESS.value))
         except Exception as e:
             err = f"[GenerateDataSetWorker] 任务失败,task_id: {task_id},错误信息: {e}"
             logging.exception(err)
-            await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value))
+            await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.FAILED.value))
             await GenerateDataSetWorker.report(task_id, err, 0, 1)

     @staticmethod
diff --git a/data_chain/apps/base/task/worker/import_dataset_worker.py b/data_chain/apps/base/task/worker/import_dataset_worker.py
index 02451f56..e7e89e52 100644
--- a/data_chain/apps/base/task/worker/import_dataset_worker.py
+++ b/data_chain/apps/base/task/worker/import_dataset_worker.py
@@ -22,7 +22,7 @@ from data_chain.manager.chunk_manager import ChunkManager
 from data_chain.manager.dataset_manager import DatasetManager
 from data_chain.manager.qa_manager import QAManager
 from data_chain.manager.task_queue_mamanger import TaskQueueManager
-from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, QAEntity, DataSetEntity, DataSetDocEntity
+from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, QAEntity, DataSetEntity, DataSetDocEntity, TaskQueueEntity
 from data_chain.stores.minio.minio import MinIO
 from data_chain.stores.mongodb.mongodb import Task

@@ -246,11 +246,11 @@ class ImportDataSetWorker(BaseWorker):
             await ImportDataSetWorker.update_dataset_score(dataset_entity.id, qa_entities, llm)
             current_stage += 1
             await ImportDataSetWorker.report(task_id, "更新数据集分数", current_stage, stage_cnt)
-            await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value))
+            await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.SUCCESS.value))
         except Exception as e:
             err = f"[ImportDataSetWorker] 任务失败,task_id: {task_id},错误信息: {e}"
             logging.exception(err)
-            await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value))
+            await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.FAILED.value))
             await ImportDataSetWorker.report(task_id, "任务失败", 0, 1)

     @staticmethod
diff --git a/data_chain/apps/base/task/worker/import_knowledge_base_worker.py b/data_chain/apps/base/task/worker/import_knowledge_base_worker.py
index b7dcc5ea..43f665eb 100644
--- a/data_chain/apps/base/task/worker/import_knowledge_base_worker.py
+++ b/data_chain/apps/base/task/worker/import_knowledge_base_worker.py
@@ -15,7 +15,7 @@ from data_chain.manager.knowledge_manager import KnowledgeBaseManager
 from data_chain.manager.document_type_manager import DocumentTypeManager
 from data_chain.manager.document_manager import DocumentManager
 from data_chain.manager.task_queue_mamanger import TaskQueueManager
-from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity
+from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, TaskQueueEntity
 from data_chain.stores.minio.minio import MinIO
 from data_chain.stores.mongodb.mongodb import Task

@@ -223,11 +223,11 @@ class ImportKnowledgeBaseWorker(BaseWorker):
             await ImportKnowledgeBaseWorker.init_doc_parse_tasks(kb_id)
             current_stage += 1
             await ImportKnowledgeBaseWorker.report(task_id, "初始化文档解析任务", current_stage, stage_cnt)
-            await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value))
+            await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.SUCCESS.value))
         except Exception as e:
             err = f"[ImportKnowledgeBaseWorker] 任务失败,task_id: {task_id},错误信息: {e}"
             logging.exception(err)
-            await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value))
+            await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.FAILED.value))
             await ImportKnowledgeBaseWorker.report(task_id, err, 0, 1)

     @staticmethod
diff --git a/data_chain/apps/base/task/worker/parse_document_worker.py b/data_chain/apps/base/task/worker/parse_document_worker.py
index a1b568a5..745ef1ca 100644
--- a/data_chain/apps/base/task/worker/parse_document_worker.py
+++ b/data_chain/apps/base/task/worker/parse_document_worker.py
@@ -29,7 +29,7 @@ from data_chain.manager.chunk_manager import ChunkManager
 from data_chain.manager.image_manager import ImageManager
 from data_chain.manager.task_report_manager import TaskReportManager
 from data_chain.manager.task_queue_mamanger import TaskQueueManager
-from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, ChunkEntity, ImageEntity
+from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, ChunkEntity, ImageEntity, TaskQueueEntity
 from data_chain.stores.minio.minio import MinIO
 from data_chain.stores.mongodb.mongodb import Task

@@ -539,7 +539,7 @@ class ParseDocumentWorker(BaseWorker):
             await ParseDocumentWorker.add_parse_result_to_db(parse_result, doc_entity)
             current_stage += 1
             await ParseDocumentWorker.report(task_id, '添加解析结果到数据库', current_stage, stage_cnt)
-            await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value))
+            await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.SUCCESS.value))
             task_report = await ParseDocumentWorker.assemble_task_report(task_id)
             report_path = os.path.join(tmp_path, 'task_report.txt')
             with open(report_path, 'w') as f:
@@ -552,7 +552,7 @@ class ParseDocumentWorker(BaseWorker):
         except Exception as e:
             err = f"[DocParseWorker] 任务失败,task_id: {task_id},错误信息: {e}"
             logging.exception(err)
-            await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value))
+            await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.FAILED.value))
             await ParseDocumentWorker.report(task_id, err, 0, 1)
             task_report = await ParseDocumentWorker.assemble_task_report(task_id)
             report_path = os.path.join(tmp_path, 'task_report.txt')
diff --git a/data_chain/apps/service/task_queue_service.py b/data_chain/apps/service/task_queue_service.py
index 2a16ab14..1a3179b3 100644
--- a/data_chain/apps/service/task_queue_service.py
+++ b/data_chain/apps/service/task_queue_service.py
@@ -4,7 +4,8 @@ import uuid
 from typing import Optional
 from data_chain.entities.enum import TaskType, TaskStatus
 from data_chain.apps.base.task.worker.base_worker import BaseWorker
-from data_chain.stores.mongodb.mongodb import MongoDB, Task
+# from data_chain.stores.mongodb.mongodb import MongoDB, Task
+from data_chain.stores.database.database import TaskQueueEntity
 from data_chain.manager.task_manager import TaskManager
 from data_chain.manager.task_queue_mamanger import TaskQueueManager
 from data_chain.logger.logger import logger as logging
@@ -22,7 +23,7 @@ class TaskQueueService:
             if task_entity.status == TaskStatus.RUNNING.value:
                 flag = await BaseWorker.reinit(task_entity.id)
                 if flag:
-                    task = Task(_id=task_entity.id, status=TaskStatus.PENDING.value)
+                    task = TaskQueueEntity(id=task_entity.id, status=TaskStatus.PENDING.value)
                     await TaskQueueManager.update_task_by_id(task_entity.id, task)
                 else:
                     await BaseWorker.stop(task_entity.id)
@@ -30,11 +31,11 @@
             else:
                 task = await TaskQueueManager.get_task_by_id(task_entity.id)
                 if task is None:
-                    task = Task(_id=task_entity.id, status=TaskStatus.PENDING.value)
+                    task = TaskQueueEntity(id=task_entity.id, status=TaskStatus.PENDING.value)
                     await TaskQueueManager.add_task(task)
         except Exception as e:
-            warining = f"[TaskQueueService] 初始化任务失败 {e}"
-            logging.warning(warining)
+            warning = f"[TaskQueueService] 初始化任务失败 {e}"
+            logging.warning(warning)

     @staticmethod
     async def init_task(task_type: str, op_id: uuid.UUID) -> uuid.UUID:
@@ -42,7 +43,7 @@
         try:
             task_id = await BaseWorker.init(task_type, op_id)
             if task_id:
-                await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.PENDING.value))
+                await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.PENDING.value))
             return task_id
         except Exception as e:
             err = f"[TaskQueueService] 初始化任务失败 {e}"
@@ -75,53 +76,53 @@ class TaskQueueService:
     async def handle_successed_tasks():
         handle_successed_task_limit = 1024
         for i in range(handle_successed_task_limit):
-            task = await TaskQueueManager.get_oldest_tasks_by_status(TaskStatus.SUCCESS.value)
+            task = await TaskQueueManager.get_oldest_tasks_by_status(TaskStatus.SUCCESS)
             if task is None:
                 break
             try:
-                await BaseWorker.deinit(task.task_id)
+                await BaseWorker.deinit(task.id)
             except Exception as e:
                 err = f"[TaskQueueService] 处理成功任务失败 {e}"
                 logging.error(err)
-            await TaskQueueManager.delete_task_by_id(task.task_id)
+            await TaskQueueManager.delete_task_by_id(task.id)

     @staticmethod
     async def handle_failed_tasks():
         handle_failed_task_limit = 1024
         for i in range(handle_failed_task_limit):
-            task = await TaskQueueManager.get_oldest_tasks_by_status(TaskStatus.FAILED.value)
+            task = await TaskQueueManager.get_oldest_tasks_by_status(TaskStatus.FAILED)
             if task is None:
                 break
             try:
-                flag = await BaseWorker.reinit(task.task_id)
+                flag = await BaseWorker.reinit(task.id)
             except Exception as e:
                 err = f"[TaskQueueService] 处理失败任务失败 {e}"
                 logging.error(err)
-                await TaskQueueManager.delete_task_by_id(task.task_id)
+                await TaskQueueManager.delete_task_by_id(task.id)
                 continue
             if flag:
-                task = Task(_id=task.task_id, status=TaskStatus.PENDING.value)
-                await TaskQueueManager.update_task_by_id(task.task_id, task)
+                task.status = TaskStatus.PENDING.value
+                await TaskQueueManager.update_task_by_id(task.id, task)
             else:
-                await TaskQueueManager.delete_task_by_id(task.task_id)
+                await TaskQueueManager.delete_task_by_id(task.id)

     @staticmethod
     async def handle_pending_tasks():
         handle_pending_task_limit = 128
         for i in range(handle_pending_task_limit):
-            task = await TaskQueueManager.get_oldest_tasks_by_status(TaskStatus.PENDING.value)
+            task = await TaskQueueManager.get_oldest_tasks_by_status(TaskStatus.PENDING)
             if task is None:
                 break
             try:
-                flag = await BaseWorker.run(task.task_id)
+                flag = await BaseWorker.run(task.id)
             except Exception as e:
                 err = f"[TaskQueueService] 处理待处理任务失败 {e}"
                 logging.error(err)
-                await TaskQueueManager.delete_task_by_id(task.task_id)
+                await TaskQueueManager.delete_task_by_id(task.id)
                 continue
             if not flag:
                 break
-            await TaskQueueManager.delete_task_by_id(task.task_id)
+            await TaskQueueManager.delete_task_by_id(task.id)

     @staticmethod
     async def handle_tasks():
diff --git a/data_chain/manager/task_queue_mamanger.py b/data_chain/manager/task_queue_mamanger.py
index 8f4db40d..6c68db98 100644
--- a/data_chain/manager/task_queue_mamanger.py
+++ b/data_chain/manager/task_queue_mamanger.py
@@ -5,8 +5,8 @@ from sqlalchemy.orm import aliased
 import uuid
 from typing import Dict, List, Optional, Tuple
 from data_chain.logger.logger import logger as logging
-from data_chain.stores.database.database import DataBase, TaskEntity
-from data_chain.stores.mongodb.mongodb import MongoDB, Task
+from data_chain.stores.database.database import DataBase, TaskEntity, TaskQueueEntity
+# from data_chain.stores.mongodb.mongodb import MongoDB, Task
 from data_chain.entities.enum import TaskStatus


@@ -14,60 +14,72 @@
 class TaskQueueManager():
     """任务队列管理类"""

     @staticmethod
-    async def add_task(task: Task):
+    async def add_task(task: TaskQueueEntity):
         try:
-            async with MongoDB.get_session() as session, await session.start_transaction():
-                task_colletion = MongoDB.get_collection('witchiand_task')
-                await task_colletion.insert_one(task.model_dump(by_alias=True), session=session)
+            async with await DataBase.get_session() as session:
+                session.add(task)
+                await session.commit()
         except Exception as e:
             err = "添加任务到队列失败"
             logging.exception("[TaskQueueManager] %s", err)
+            raise e

     @staticmethod
     async def delete_task_by_id(task_id: uuid.UUID):
         """根据任务ID删除任务"""
         try:
-            async with MongoDB.get_session() as session, await session.start_transaction():
-                task_colletion = MongoDB.get_collection('witchiand_task')
-                await task_colletion.delete_one({"_id": task_id}, session=session)
+            async with await DataBase.get_session() as session:
+                stmt = delete(TaskQueueEntity).where(TaskQueueEntity.id == task_id)
+                await session.execute(stmt)
+                await session.commit()
         except Exception as e:
             err = "删除任务失败"
             logging.exception("[TaskQueueManager] %s", err)
             raise e

     @staticmethod
-    async def get_oldest_tasks_by_status(status: TaskStatus) -> Task:
+    async def get_oldest_tasks_by_status(status: TaskStatus) -> Optional[TaskQueueEntity]:
         """根据任务状态获取最早的任务"""
         try:
-            async with MongoDB.get_session() as session:
-                task_colletion = MongoDB.get_collection('witchiand_task')
-                task = await task_colletion.find_one({"status": status}, sort=[("created_time", 1)], session=session)
-                return Task(**task) if task else None
+            async with await DataBase.get_session() as session:
+                stmt = (
+                    select(TaskQueueEntity)
+                    .where(TaskQueueEntity.status == status.value)
+                    .order_by(asc(TaskQueueEntity.created_time))
+                    .limit(1)
+                )
+                result = await session.execute(stmt)
+                return result.scalars().first()
         except Exception as e:
             err = "获取最早的任务失败"
             logging.exception("[TaskQueueManager] %s", err)
             raise e

     @staticmethod
-    async def get_task_by_id(task_id: uuid.UUID) -> Task:
+    async def get_task_by_id(task_id: uuid.UUID) -> Optional[TaskQueueEntity]:
         """根据任务ID获取任务"""
         try:
-            async with MongoDB.get_session() as session:
-                task_colletion = MongoDB.get_collection('witchiand_task')
-                task = await task_colletion.find_one({"_id": task_id}, session=session)
-                return Task(**task) if task else None
+            async with await DataBase.get_session() as session:
+                stmt = select(TaskQueueEntity).where(TaskQueueEntity.id == task_id)
+                result = await session.execute(stmt)
+                return result.scalars().first()
         except Exception as e:
             err = "获取任务失败"
             logging.exception("[TaskQueueManager] %s", err)
             raise e

     @staticmethod
-    async def update_task_by_id(task_id: uuid.UUID, task: Task):
+    async def update_task_by_id(task_id: uuid.UUID, task: TaskQueueEntity):
         """根据任务ID更新任务"""
         try:
-            async with MongoDB.get_session() as session, await session.start_transaction():
-                task_colletion = MongoDB.get_collection('witchiand_task')
-                await task_colletion.update_one({"_id": task_id}, {"$set": task.model_dump(by_alias=True)}, session=session)
+            async with await DataBase.get_session() as session:
+                stmt = (
+                    update(TaskQueueEntity)
+                    .where(TaskQueueEntity.id == task_id)
+                    .values(status=task.status)
+                )
+                await session.execute(stmt)
+                await session.commit()
         except Exception as e:
             err = "更新任务失败"
             logging.exception("[TaskQueueManager] %s", err)
diff --git a/data_chain/stores/database/database.py b/data_chain/stores/database/database.py
index 4e8ae10d..9c459336 100644
--- a/data_chain/stores/database/database.py
+++ b/data_chain/stores/database/database.py
@@ -534,6 +534,23 @@ class TaskReportEntity(Base):
     )


+class TaskQueueEntity(Base):
+    __tablename__ = 'task_queue'
+
+    id = Column(UUID, default=uuid4, primary_key=True)  # 任务ID
+    status = Column(String)  # 任务状态
+    created_time = Column(
+        TIMESTAMP(timezone=True),
+        nullable=True,
+        server_default=func.current_timestamp()
+    )
+    # 添加索引以提高查询性能
+    __table_args__ = (
+        Index('idx_task_queue_status', 'status'),
+        Index('idx_task_queue_created_time', 'created_time'),
+    )
+
+
 class DataBase:
     # 对密码进行 URL 编码
-- 
Gitee

From 09af3af47f3e59170ec7048fe7aba14c46bd21a1 Mon Sep 17 00:00:00 2001
From: "Shine.Wang"
Date: Thu, 17 Jul 2025 16:59:01 +0800
Subject: [PATCH 03/12] Apply changes based on review comments
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 data_chain/manager/task_queue_mamanger.py | 10 +++-------
 data_chain/stores/database/database.py   | 13 +++++++------
 2 files changed, 10 insertions(+), 13 deletions(-)

diff --git a/data_chain/manager/task_queue_mamanger.py b/data_chain/manager/task_queue_mamanger.py
index 6c68db98..95024c55 100644
--- a/data_chain/manager/task_queue_mamanger.py
+++ b/data_chain/manager/task_queue_mamanger.py
@@ -1,12 +1,10 @@
 # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
 from sqlalchemy import select, delete, update, desc, asc, func, exists, or_, and_
-from sqlalchemy.orm import aliased
 import uuid
 from typing import Dict, List, Optional, Tuple
 from data_chain.logger.logger import logger as logging
-from data_chain.stores.database.database import DataBase, TaskEntity, TaskQueueEntity
-# from data_chain.stores.mongodb.mongodb import MongoDB, Task
+from data_chain.stores.database.database import DataBase, TaskQueueEntity
 from data_chain.entities.enum import TaskStatus


@@ -48,8 +46,7 @@ class TaskQueueManager():
                     .order_by(asc(TaskQueueEntity.created_time))
                     .limit(1)
                 )
-                result = await session.execute(stmt)
-                return result.scalars().first()
+                return await session.scalars(stmt).first()
         except Exception as e:
             err = "获取最早的任务失败"
             logging.exception("[TaskQueueManager] %s", err)
@@ -61,8 +58,7 @@ class TaskQueueManager():
         """根据任务ID获取任务"""
         try:
             async with await DataBase.get_session() as session:
                 stmt = select(TaskQueueEntity).where(TaskQueueEntity.id == task_id)
-                result = await session.execute(stmt)
-                return result.scalars().first()
+                return await session.scalars(stmt).first()
         except Exception as e:
             err = "获取任务失败"
             logging.exception("[TaskQueueManager] %s", err)
diff --git a/data_chain/stores/database/database.py b/data_chain/stores/database/database.py
index 9c459336..6ddd2aad 100644
--- a/data_chain/stores/database/database.py
+++ b/data_chain/stores/database/database.py
@@ -1,13 +1,14 @@
 # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
 from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
 from sqlalchemy import Index
-from uuid import uuid4
+from datetime import datetime
+import uuid
 import urllib.parse
 from data_chain.logger.logger import logger as logging
 from pgvector.sqlalchemy import Vector
 from sqlalchemy import Boolean, Column, ForeignKey, Integer, Float, String, func
 from sqlalchemy.types import TIMESTAMP, UUID
-from sqlalchemy.orm import declarative_base
+from sqlalchemy.orm import declarative_base, DeclarativeBase, MappedAsDataclass, Mapped, mapped_column
 from data_chain.config.config import config
 from data_chain.entities.enum import (Tokenizer,
                                       ParseMethod,
@@ -534,12 +535,12 @@ class TaskReportEntity(Base):
     )


-class TaskQueueEntity(Base):
+class TaskQueueEntity(DeclarativeBase, MappedAsDataclass):
     __tablename__ = 'task_queue'

-    id = Column(UUID, default=uuid4, primary_key=True)  # 任务ID
-    status = Column(String)  # 任务状态
-    created_time = Column(
+    id: Mapped[uuid.UUID] = mapped_column(UUID, default_factory=uuid.uuid4, primary_key=True)  # 任务ID
+    status: Mapped[str] = mapped_column(String)  # 任务状态
+    created_time: Mapped[datetime] = mapped_column(
         TIMESTAMP(timezone=True),
         nullable=True,
         server_default=func.current_timestamp()
-- 
Gitee

From e94b62009a378685e7b91e464dc00f37f4befdb9 Mon Sep 17 00:00:00 2001
From: "Shine.Wang"
Date: Fri, 4 Jul 2025 14:45:33 +0800
Subject: [PATCH 04/12] Fix spelling mistakes

---
 .../base/task/worker/acc_testing_worker.py | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/data_chain/apps/base/task/worker/acc_testing_worker.py b/data_chain/apps/base/task/worker/acc_testing_worker.py
index 45ae9274..08516666 100644
--- a/data_chain/apps/base/task/worker/acc_testing_worker.py
+++ b/data_chain/apps/base/task/worker/acc_testing_worker.py
@@ -180,31 +180,31 @@ class TestingWorker(BaseWorker):
                 bac_info=bac_info
             )
             llm_answer = await llm.nostream([], prompt, question)
-            sub_socres = []
+            sub_scores = []
             pre = await TokenTool.cal_precision(question, answer, llm)
             if pre != -1:
-                sub_socres.append(pre)
+                sub_scores.append(pre)
             rec = await TokenTool.cal_recall(answer, llm_answer, llm)
             if rec != -1:
-                sub_socres.append(rec)
+                sub_scores.append(rec)
             fai = await TokenTool.cal_faithfulness(question, llm_answer, bac_info, llm)
             if fai != -1:
-                sub_socres.append(fai)
+                sub_scores.append(fai)
             rel = await TokenTool.cal_relevance(question, llm_answer, llm)
             if rel != -1:
-                sub_socres.append(rel)
+                sub_scores.append(rel)
             lcs = TokenTool.cal_lcs(answer, llm_answer)
             if lcs != -1:
-                sub_socres.append(lcs)
+                sub_scores.append(lcs)
             leve = TokenTool.cal_leve(answer, llm_answer)
             if leve != -1:
-                sub_socres.append(leve)
+                sub_scores.append(leve)
             jac = TokenTool.cal_jac(answer, llm_answer)
             if jac != -1:
-                sub_socres.append(jac)
+                sub_scores.append(jac)
             score = -1
-            if sub_socres:
-                score = sum(sub_socres) / len(sub_socres)
+            if sub_scores:
+                score = sum(sub_scores) / len(sub_scores)
             test_case_entity = TestCaseEntity(
                 testing_id=testing_entity.id,
                 question=question,
-- 
Gitee

From c4463391a3eb77e465fa6acdb347079190b04c64 Mon Sep 17 00:00:00 2001
From: zxstty
Date: Fri, 15 Aug 2025 11:31:56 +0800
Subject: [PATCH 05/12] Return document author and creation time in document retrieval
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 data_chain/apps/service/chunk_service.py | 2 ++
 data_chain/entities/response_data.py     | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/data_chain/apps/service/chunk_service.py b/data_chain/apps/service/chunk_service.py
index d1b70646..8d302c4c 100644
--- a/data_chain/apps/service/chunk_service.py
+++ b/data_chain/apps/service/chunk_service.py
@@ -142,6 +142,8 @@ class ChunkService:
         doc_map = {doc_entity.id: doc_entity for doc_entity in doc_entities}
         for doc_chunk in search_chunk_msg.doc_chunks:
             doc_entity = doc_map.get(doc_chunk.doc_id)
+            doc_chunk.doc_author = doc_entity.author_name if doc_entity else ""
+            doc_chunk.doc_created_at = doc_entity.created_time.strftime('%Y-%m-%d %H:%M') if doc_entity else ""
             doc_chunk.doc_abstract = doc_entity.abstract if doc_entity else ""
             doc_chunk.doc_extension = doc_entity.extension if doc_entity else ""
             doc_chunk.doc_size = doc_entity.size if doc_entity else 0
diff --git a/data_chain/entities/response_data.py b/data_chain/entities/response_data.py
index 5a3bf03b..5cac2bf5 100644
--- a/data_chain/entities/response_data.py
+++ b/data_chain/entities/response_data.py
@@ -326,9 +326,11 @@ class DocChunk(BaseModel):
     """Post /chunk/search 数据结构"""
     doc_id: uuid.UUID = Field(description="文档ID", alias="docId")
     doc_name: str = Field(description="文档名称", alias="docName")
+    doc_author: str = Field(description="文档作者", alias="docAuthor")
     doc_abstract: str = Field(default="", description="文档摘要", alias="docAbstract")
     doc_extension: str = Field(default="", description="文档扩展名", alias="docExtension")
     doc_size: int = Field(default=0, description="文档大小,单位是KB", alias="docSize")
+    doc_created_at: str = Field(default="", description="文档创建时间", alias="docCreatedAt")
     chunks: list[Chunk] = Field(default=[], description="分片列表", alias="chunks")
-- 
Gitee

From d635b25e28a4bff5a7e9b25787900abedfbacf0c Mon Sep 17 00:00:00 2001
From: zxstty
Date: Thu, 21 Aug 2025 20:32:59 +0800
Subject: [PATCH 06/12] Perform OCR via multithreading
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../base/task/worker/parse_document_worker.py | 49 +++++++++++++------
 data_chain/parser/tools/ocr_tool.py           |  2 +
 2 files changed, 35 insertions(+), 16 deletions(-)

diff --git a/data_chain/apps/base/task/worker/parse_document_worker.py b/data_chain/apps/base/task/worker/parse_document_worker.py
index fab00d00..7d61e7b3 100644
--- a/data_chain/apps/base/task/worker/parse_document_worker.py
+++ b/data_chain/apps/base/task/worker/parse_document_worker.py
@@ -8,6 +8,7 @@ import random
 import io
 import numpy as np
 from PIL import Image
+import asyncio
 from data_chain.parser.tools.ocr_tool import OcrTool
 from data_chain.parser.tools.token_tool import TokenTool
 from data_chain.parser.tools.image_tool import ImageTool
@@ -271,24 +272,40 @@ class ParseDocumentWorker(BaseWorker):
             index += 1024

     @staticmethod
-    async def ocr_from_parse_image(parse_result: ParseResult, llm: LLM = None) -> list:
+    async def ocr_from_parse_image(parse_result: ParseResult, llm: LLM = None) -> None:
         '''从解析图片中获取ocr'''
-        for node in parse_result.nodes:
+        async def _ocr(node: ParseNode) -> None:
+            try:
+                img_blob = node.content
+                image = Image.open(io.BytesIO(img_blob))
+                img_np = np.array(image)
+                image_related_text = ''
+                for related_node in node.link_nodes:
+                    if related_node.type != ChunkType.IMAGE:
+                        image_related_text += related_node.content + '\n'
+                ocr_result = (await OcrTool.image_to_text(img_np, image_related_text, llm))
+                node.text_feature = ocr_result
+                node.content = ocr_result
+            except Exception as e:
+                err = f"[OCRTool] OCR识别失败: {e}"
+                logging.exception(err)
+                return None
+
+        image_node_ids = []
+        for i, node in enumerate(parse_result.nodes):
             if node.type == ChunkType.IMAGE:
-                try:
-                    image_blob = node.content
-                    image = Image.open(io.BytesIO(image_blob))
-                    img_np = np.array(image)
-                    image_related_text = ''
-                    for related_node in node.link_nodes:
-                        if related_node.type != ChunkType.IMAGE:
-                            image_related_text += related_node.content
-                    node.content = await OcrTool.image_to_text(img_np, image_related_text, llm)
-                    node.text_feature = node.content
-                except Exception as e:
-                    err = f"[ParseDocumentWorker] OCR失败 error: {e}"
-                    logging.exception(err)
-                    continue
+                image_node_ids.append(i)
+        group_size = 5
+        index = 0
+        while index < len(image_node_ids):
+            sub_image_node_ids = image_node_ids[index:index + group_size]
+            task_list = []
+            for node_id in sub_image_node_ids:
+                # 通过asyncio.create_task来异步执行OCR
+                node = parse_result.nodes[node_id]
+                task_list.append(asyncio.create_task(_ocr(node)))
+            await asyncio.gather(*task_list)
+            index += group_size

     @staticmethod
     async def merge_and_split_text(parse_result: ParseResult, doc_entity: DocumentEntity) -> None:
diff --git a/data_chain/parser/tools/ocr_tool.py b/data_chain/parser/tools/ocr_tool.py
index 858517da..115f6e9f 100644
--- a/data_chain/parser/tools/ocr_tool.py
+++ b/data_chain/parser/tools/ocr_tool.py
@@ -58,6 +58,8 @@ class OcrTool:
     async def merge_text_from_ocr_result(ocr_result: list) -> str:
         text = ''
         try:
+            if ocr_result[0] is None or len(ocr_result[0]) == 0:
+                return ""
             for _ in ocr_result[0]:
                 text += str(_[1][0])
             return text
-- 
Gitee

From 574e1acd522db94622c5b13e4c1b12365c34e467 Mon Sep 17 00:00:00 2001
From: zxstty
Date: Thu, 21 Aug 2025 21:31:54 +0800
Subject: [PATCH 07/12] Fix bug when filtering datasets by status
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../base/task/worker/import_knowledge_base_worker.py | 2 +-
 data_chain/manager/dataset_manager.py                | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/data_chain/apps/base/task/worker/import_knowledge_base_worker.py b/data_chain/apps/base/task/worker/import_knowledge_base_worker.py
index 43f665eb..c85f0c98 100644
--- a/data_chain/apps/base/task/worker/import_knowledge_base_worker.py
+++ b/data_chain/apps/base/task/worker/import_knowledge_base_worker.py
@@ -252,7 +252,7 @@ class ImportKnowledgeBaseWorker(BaseWorker):
             err = f"[ExportKnowledgeBaseWorker] 任务不存在,task_id: {task_id}"
             logging.exception(err)
             return None
-        if task_entity.status == TaskStatus.CANCLED or TaskStatus.FAILED.value:
+        if task_entity.status == TaskStatus.CANCLED.value or task_entity.status == TaskStatus.FAILED.value:
             await KnowledgeBaseManager.update_knowledge_base_by_kb_id(task_entity.op_id, {"status": KnowledgeBaseStatus.DELETED.value})
             await MinIO.delete_object(IMPORT_KB_PATH_IN_MINIO, str(task_entity.op_id))
         return task_id
diff --git a/data_chain/manager/dataset_manager.py b/data_chain/manager/dataset_manager.py
index cb61aa8a..89fcaafd 100644
--- a/data_chain/manager/dataset_manager.py
+++ b/data_chain/manager/dataset_manager.py
@@ -98,7 +98,8 @@ class DatasetManager:
             stmt = stmt.where(DataSetEntity.is_chunk_related == req.is_chunk_related)
         if req.generate_status is not None:
             status_list = [status.value for status in req.generate_status]
-            status_list += [DataSetStatus.DELETED.value]
+            if TaskStatus.SUCCESS in req.generate_status:
+                status_list += [TaskStatus.DELETED.value]
             stmt = stmt.where(subq.c.status.in_(status_list))
         stmt = stmt.order_by(DataSetEntity.created_at.desc(), DataSetEntity.id.desc())
         if req.score_order:
-- 
Gitee

From bbbb620ccfce9b60b015837686a43b2cdac8f23a Mon Sep 17 00:00:00 2001
From: zxstty
Date: Thu, 21 Aug 2025 21:59:36 +0800
Subject: [PATCH 08/12] Fix task queue bug
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 data_chain/stores/database/database.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/data_chain/stores/database/database.py b/data_chain/stores/database/database.py
index 6ddd2aad..e0af0a43 100644
--- a/data_chain/stores/database/database.py
+++ b/data_chain/stores/database/database.py
@@ -3,6 +3,7 @@ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
 from sqlalchemy import Index
 from datetime import datetime
 import uuid
+from uuid import uuid4
 import urllib.parse
 from data_chain.logger.logger import logger as logging
 from pgvector.sqlalchemy import Vector
@@ -535,12 +536,12 @@ class TaskReportEntity(Base):
     )


-class TaskQueueEntity(DeclarativeBase, MappedAsDataclass):
+class TaskQueueEntity(Base):
     __tablename__ = 'task_queue'

-    id: Mapped[uuid.UUID] = mapped_column(UUID, default_factory=uuid.uuid4, primary_key=True)  # 任务ID
-    status: Mapped[str] = mapped_column(String)  # 任务状态
-    created_time: Mapped[datetime] = mapped_column(
+    id = Column(UUID, default=uuid4, primary_key=True)  # 任务ID
+    status = Column(String)  # 任务状态
+    created_time = Column(
         TIMESTAMP(timezone=True),
         nullable=True,
         server_default=func.current_timestamp()
-- 
Gitee

From 117b9de20d764239e74be57646a8c2ca634cada5 Mon Sep 17 00:00:00 2001
From: zxstty
Date: Thu, 21 Aug 2025 22:06:56 +0800
Subject: [PATCH 09/12] Fix task queue bug
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 chat2db/database/postgres.py              | 3 +--
 data_chain/manager/task_queue_mamanger.py | 6 ++++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/chat2db/database/postgres.py b/chat2db/database/postgres.py
index ea4470d4..61ac542a 100644
--- a/chat2db/database/postgres.py
+++ b/chat2db/database/postgres.py
@@ -97,8 +97,6 @@ class PostgresDB:
                                           echo=False,
                                           pool_recycle=300,
                                           pool_pre_ping=True)
-
-        Base.metadata.create_all(cls.engine)
         if config['DATABASE_TYPE'].lower() == 'opengauss':
             from sqlalchemy import event
             from opengauss_sqlalchemy.register_async import register_vector

             @event.listens_for(cls.engine.sync_engine, "connect")
             def connect(dbapi_connection, connection_record):
                 dbapi_connection.run_async(register_vector)
+        Base.metadata.create_all(cls.engine)
         return cls._engine

     @classmethod
diff --git a/data_chain/manager/task_queue_mamanger.py b/data_chain/manager/task_queue_mamanger.py
index 95024c55..b0df886b 100644
--- a/data_chain/manager/task_queue_mamanger.py
+++ b/data_chain/manager/task_queue_mamanger.py
@@ -46,7 +46,8 @@ class TaskQueueManager():
                     .order_by(asc(TaskQueueEntity.created_time))
                     .limit(1)
                 )
-                return await session.scalars(stmt).first()
+                result = await session.execute(stmt)
+                return result.scalars().first()
         except Exception as e:
             err = "获取最早的任务失败"
             logging.exception("[TaskQueueManager] %s", err)
@@ -58,7 +59,8 @@ class TaskQueueManager():
         try:
             async with await DataBase.get_session() as session:
                 stmt = select(TaskQueueEntity).where(TaskQueueEntity.id == task_id)
-                return await session.scalars(stmt).first()
+                result = await session.execute(stmt)
+                return result.scalars().first()
         except Exception as e:
             err = "获取任务失败"
             logging.exception("[TaskQueueManager] %s", err)
-- 
Gitee

From 3c97ff85e9a0110b29b0d55ec4d3d047cffd37bf Mon Sep 17 00:00:00 2001
From: zxstty
Date: Thu, 21 Aug 2025 22:21:00 +0800
Subject: [PATCH 10/12] Parallelize vectorization as well
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../base/task/worker/parse_document_worker.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/data_chain/apps/base/task/worker/parse_document_worker.py b/data_chain/apps/base/task/worker/parse_document_worker.py
index 7d61e7b3..346aa70e 100644
--- a/data_chain/apps/base/task/worker/parse_document_worker.py
+++ b/data_chain/apps/base/task/worker/parse_document_worker.py
@@ -448,8 +448,18 @@ class ParseDocumentWorker(BaseWorker):
     @staticmethod
     async def embedding_chunk(parse_result: ParseResult) -> None:
         '''嵌入chunk'''
-        for node in parse_result.nodes:
-            node.vector = await Embedding.vectorize_embedding(node.text_feature)
+        def _embedding(node: ParseNode) -> None:
+            node.vector = Embedding.vectorize_embedding(node.text_feature)
+        group_size = 32
+        index = 0
+        while index < len(parse_result.nodes):
+            sub_nodes = parse_result.nodes[index:index + group_size]
+            task_list = []
+            for node in sub_nodes:
+                # 通过asyncio.create_task来异步执行嵌入
+                task_list.append(asyncio.create_task(_embedding(node)))
+            asyncio.run(asyncio.gather(*task_list))
+            index += group_size

     @staticmethod
     async def add_parse_result_to_db(parse_result: ParseResult, doc_entity: DocumentEntity) -> None:
-- 
Gitee

From 606fe9431de0270088794c60255ff89693737edd Mon Sep 17 00:00:00 2001
From: zxstty
Date: Thu, 21 Aug 2025 22:37:49 +0800
Subject: [PATCH 11/12] Parallelize vectorization as well
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 data_chain/apps/base/task/worker/parse_document_worker.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/data_chain/apps/base/task/worker/parse_document_worker.py b/data_chain/apps/base/task/worker/parse_document_worker.py
index 346aa70e..7c8f369c 100644
--- a/data_chain/apps/base/task/worker/parse_document_worker.py
+++ b/data_chain/apps/base/task/worker/parse_document_worker.py
@@ -448,8 +448,8 @@ class ParseDocumentWorker(BaseWorker):
     @staticmethod
     async def embedding_chunk(parse_result: ParseResult) -> None:
         '''嵌入chunk'''
-        def _embedding(node: ParseNode) -> None:
-            node.vector = Embedding.vectorize_embedding(node.text_feature)
+        async def _embedding(node: ParseNode) -> None:
+            node.vector = await Embedding.vectorize_embedding(node.text_feature)
         group_size = 32
         index = 0
         while index < len(parse_result.nodes):
             sub_nodes = parse_result.nodes[index:index + group_size]
             task_list = []
             for node in sub_nodes:
-- 
Gitee

From 1373945b70dd1cc055f705d9b1fd710dcbafbd4e Mon Sep 17 00:00:00 2001
From: zxstty
Date: Thu, 21 Aug 2025 22:48:37 +0800
Subject: [PATCH 12/12] Parallelize vectorization as well
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 data_chain/apps/base/task/worker/parse_document_worker.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/data_chain/apps/base/task/worker/parse_document_worker.py b/data_chain/apps/base/task/worker/parse_document_worker.py
index 7c8f369c..67d9e3d8 100644
--- a/data_chain/apps/base/task/worker/parse_document_worker.py
+++ b/data_chain/apps/base/task/worker/parse_document_worker.py
@@ -450,15 +450,17 @@ class ParseDocumentWorker(BaseWorker):
         '''嵌入chunk'''
         async def _embedding(node: ParseNode) -> None:
             node.vector = await Embedding.vectorize_embedding(node.text_feature)
+
         group_size = 32
         index = 0
         while index < len(parse_result.nodes):
             sub_nodes = parse_result.nodes[index:index + group_size]
             task_list = []
             for node in sub_nodes:
-                # 通过asyncio.create_task来异步执行嵌入
+                # 与OCR代码风格保持一致
                 task_list.append(asyncio.create_task(_embedding(node)))
-            asyncio.run(asyncio.gather(*task_list))
+            # 直接await任务集合
+            await asyncio.gather(*task_list)
             index += group_size

     @staticmethod
-- 
Gitee