From 2ecdbd111158cc804568e26267757f804f8a3b54 Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 4 Nov 2025 17:00:53 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9C=8D=E5=8A=A1=E9=87=8D=E5=90=AF=E5=88=9D?= =?UTF-8?q?=E5=A7=8B=E5=8C=96=E4=BB=BB=E5=8A=A1=E6=94=B9=E6=88=90=E5=B9=B6?= =?UTF-8?q?=E8=A1=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_chain/apps/base/task/process_handler.py | 2 +- data_chain/apps/service/document_service.py | 12 ++++ data_chain/apps/service/task_queue_service.py | 71 ++++++++++++------- data_chain/manager/task_queue_mamanger.py | 39 ++++++++++ data_chain/stores/database/database.py | 1 - 5 files changed, 98 insertions(+), 27 deletions(-) diff --git a/data_chain/apps/base/task/process_handler.py b/data_chain/apps/base/task/process_handler.py index 3492bd41..afdedac1 100644 --- a/data_chain/apps/base/task/process_handler.py +++ b/data_chain/apps/base/task/process_handler.py @@ -16,7 +16,7 @@ class ProcessHandler: lock = multiprocessing.Lock() # 创建一个锁对象 max_processes = min( max((os.cpu_count() or 1) // 2, 1), - config['USE_CPU_LIMIT']) # 获取CPU核心数作为最大进程数,默认为1 + config['USE_CPU_LIMIT']) time_out = 10 @staticmethod diff --git a/data_chain/apps/service/document_service.py b/data_chain/apps/service/document_service.py index 8f95f190..b9ec1d10 100644 --- a/data_chain/apps/service/document_service.py +++ b/data_chain/apps/service/document_service.py @@ -327,16 +327,28 @@ class DocumentService: shutil.rmtree(tmp_path) os.makedirs(tmp_path) document_file_path = os.path.join(tmp_path, doc.filename) + st = time.time() async with aiofiles.open(document_file_path, "wb") as f: content = await doc.read() await f.write(content) + end_time = time.time() + logging.error( + "[DocumentService] 写入文档到本地耗时: %.2f 秒", end_time - st) + st = time.time() await MinIO.put_object( bucket_name=DOC_PATH_IN_MINIO, file_index=str(id), file_path=document_file_path ) + en = time.time() + logging.error( + "[DocumentService] 上传文档到MinIO耗时: %.2f 秒", en - st) if os.path.exists(tmp_path): + st = time.time() shutil.rmtree(tmp_path) + en = time.time() + logging.error( + "[DocumentService] 删除临时文件夹耗时: %.2f 秒", en - st) doc_entity = DocumentEntity( id=id, team_id=kb_entity.team_id, diff --git a/data_chain/apps/service/task_queue_service.py b/data_chain/apps/service/task_queue_service.py index 259cbac6..d9905abd 100644 --- a/data_chain/apps/service/task_queue_service.py +++ b/data_chain/apps/service/task_queue_service.py @@ -16,48 +16,69 @@ class TaskQueueService: @staticmethod async def init_task_queue(): + task_need_pending_ids = [] + task_need_delete_ids = [] + task_entities_need_add = [] import time st = time.time() - task_entities = await TaskManager.list_task_by_task_status(TaskStatus.PENDING.value) + pending_task_entities = await TaskManager.list_task_by_task_status(TaskStatus.PENDING.value) en = time.time() logging.info(f"[TaskQueueService] 获取待处理任务耗时 {en-st} 秒") + pending_task_ids = [ + task_entity.id for task_entity in pending_task_entities] + pending_task_entities_in_db = await TaskQueueManager.get_tasks_by_ids(pending_task_ids) + pending_task_ids_in_db = [ + task_entity.id for task_entity in pending_task_entities_in_db] + pending_task_ids_not_in_db = list( + set(pending_task_ids) - set(pending_task_ids_in_db)) + for task_id in pending_task_ids_not_in_db: + task_entities_need_add.append(TaskQueueEntity( + id=task_id, status=TaskStatus.PENDING.value)) + st = time.time() - task_entities += await TaskManager.list_task_by_task_status(TaskStatus.RUNNING.value) + running_task_entities = await TaskManager.list_task_by_task_status(TaskStatus.RUNNING.value) en = time.time() logging.info(f"[TaskQueueService] 获取运行中任务耗时 {en-st} 秒") - for task_entity in task_entities: + for task_entity in running_task_entities: # 将所有任务取消 - try: - if task_entity.status == TaskStatus.RUNNING.value: + st = time.time() + flag = await BaseWorker.reinit(task_entity.id) + en = time.time() + logging.info(f"[TaskQueueService] 重新初始化任务耗时 {en-st} 秒") + if flag: st = time.time() - flag = await BaseWorker.reinit(task_entity.id) + task_need_pending_ids.append(task_entity.id) en = time.time() - logging.info(f"[TaskQueueService] 重新初始化任务耗时 {en-st} 秒") - if flag: - st = time.time() - await TaskQueueManager.update_task_by_id(task_entity.id, TaskStatus.PENDING) - en = time.time() - else: - st = time.time() - await BaseWorker.stop(task_entity.id) - await TaskQueueManager.delete_task_by_id(task_entity.id) - en = time.time() else: st = time.time() - task = await TaskQueueManager.get_task_by_id(task_entity.id) + task_need_delete_ids.append(task_entity.id) en = time.time() - logging.info(f"[TaskQueueService] 获取任务耗时 {en-st} 秒") - if task is None: - st = time.time() - task = TaskQueueEntity( - id=task_entity.id, status=TaskStatus.PENDING.value) - await TaskQueueManager.add_task(task) - en = time.time() - logging.info(f"[TaskQueueService] 添加任务耗时 {en-st} 秒") except Exception as e: warning = f"[TaskQueueService] 初始化任务失败 {e}" logging.warning(warning) + batch_size = 1024 + if len(task_need_pending_ids) > 0: + st = time.time() + for i in range(0, len(task_need_pending_ids), batch_size): + await TaskQueueManager.update_task_by_ids( + task_need_pending_ids[i:i+batch_size], TaskStatus.PENDING) + en = time.time() + logging.info(f"[TaskQueueService] 更新待处理任务状态耗时 {en-st} 秒") + if len(task_need_delete_ids) > 0: + st = time.time() + for i in range(0, len(task_need_delete_ids), batch_size): + await TaskQueueManager.delete_tasks_by_ids( + task_need_delete_ids[i:i+batch_size]) + en = time.time() + logging.info(f"[TaskQueueService] 删除任务耗时 {en-st} 秒") + if len(task_entities_need_add) > 0: + st = time.time() + for i in range(0, len(task_entities_need_add), batch_size): + await TaskQueueManager.add_tasks( + task_entities_need_add[i:i+batch_size]) + en = time.time() + logging.info(f"[TaskQueueService] 批量添加任务耗时 {en-st} 秒") @staticmethod async def init_task(task_type: str, op_id: uuid.UUID) -> uuid.UUID: diff --git a/data_chain/manager/task_queue_mamanger.py b/data_chain/manager/task_queue_mamanger.py index 2d1c33f0..95b1271e 100644 --- a/data_chain/manager/task_queue_mamanger.py +++ b/data_chain/manager/task_queue_mamanger.py @@ -22,6 +22,17 @@ class TaskQueueManager(): logging.exception("[TaskQueueManager] %s", err) raise e + @staticmethod + async def add_tasks(tasks: List[TaskQueueEntity]): + try: + async with await DataBase.get_session() as session: + session.add_all(tasks) + await session.commit() + except Exception as e: + err = "批量添加任务到队列失败" + logging.exception("[TaskQueueManager] %s", err) + raise e + @staticmethod async def delete_task_by_id(task_id: uuid.UUID): """根据任务ID删除任务""" @@ -36,6 +47,20 @@ class TaskQueueManager(): logging.exception("[TaskQueueManager] %s", err) raise e + @staticmethod + async def delete_tasks_by_ids(task_ids: List[uuid.UUID]): + """根据任务ID列表批量删除任务""" + try: + async with await DataBase.get_session() as session: + stmt = delete(TaskQueueEntity).where( + TaskQueueEntity.id.in_(task_ids)) + await session.execute(stmt) + await session.commit() + except Exception as e: + err = "批量删除任务失败" + logging.exception("[TaskQueueManager] %s", err) + raise e + @staticmethod async def get_oldest_tasks_by_status(status: TaskStatus) -> Optional[TaskQueueEntity]: """根据任务状态获取最早的任务""" @@ -68,6 +93,20 @@ class TaskQueueManager(): logging.exception("[TaskQueueManager] %s", err) raise e + @staticmethod + async def get_tasks_by_ids(task_ids: List[uuid.UUID]) -> List[TaskQueueEntity]: + """根据任务ID列表批量获取任务""" + try: + async with await DataBase.get_session() as session: + stmt = select(TaskQueueEntity).where( + TaskQueueEntity.id.in_(task_ids)) + result = await session.execute(stmt) + return result.scalars().all() + except Exception as e: + err = "批量获取任务失败" + logging.exception("[TaskQueueManager] %s", err) + raise e + @staticmethod async def update_task_by_id(task_id: uuid.UUID, status: TaskStatus): """根据任务ID更新任务""" diff --git a/data_chain/stores/database/database.py b/data_chain/stores/database/database.py index 429a862e..96efb585 100644 --- a/data_chain/stores/database/database.py +++ b/data_chain/stores/database/database.py @@ -735,7 +735,6 @@ class DataBase: pool_size = os.cpu_count() if pool_size is None: pool_size = 5 - logging.error(f"Database pool size set to: {pool_size}") engine = create_async_engine( database_url, echo=False, -- Gitee