From 52ddab28c0d8f2e5b0b61c958d0f5b2593636ac6 Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 31 Oct 2025 11:55:12 +0800 Subject: [PATCH 1/4] =?UTF-8?q?=E5=B9=B6=E8=A1=8C=E6=A3=80=E7=B4=A2?= =?UTF-8?q?=E5=A4=9A=E4=B8=AA=E7=9F=A5=E8=AF=86=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_chain/apps/service/chunk_service.py | 61 ++++++++++++++---------- 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/data_chain/apps/service/chunk_service.py b/data_chain/apps/service/chunk_service.py index f6c51d46..1d0fee5a 100644 --- a/data_chain/apps/service/chunk_service.py +++ b/data_chain/apps/service/chunk_service.py @@ -1,5 +1,5 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -import aiofiles +import asyncio from fastapi import APIRouter, Depends, Query, Body, File, UploadFile import uuid import traceback @@ -28,7 +28,7 @@ from data_chain.manager.chunk_manager import ChunkManager from data_chain.manager.role_manager import RoleManager from data_chain.manager.task_manager import TaskManager from data_chain.manager.task_report_manager import TaskReportManager -from data_chain.stores.database.database import DocumentEntity +from data_chain.stores.database.database import ChunkEntity from data_chain.stores.minio.minio import MinIO from data_chain.entities.enum import ParseMethod, DataSetStatus, DocumentStatus, TaskType from data_chain.entities.common import DOC_PATH_IN_OS, DOC_PATH_IN_MINIO, DEFAULT_KNOWLEDGE_BASE_ID, DEFAULT_DOC_TYPE_ID @@ -75,35 +75,44 @@ class ChunkService: logging.exception("[ChunkService] %s", err) raise e + async def search_chunks_from_kb(user_sub: str, action: str, search_method: str, kb_id: uuid.UUID, query: str, top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = [], is_rerank: bool = False) -> list[ChunkEntity]: + """从知识库搜索分片""" + kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id) + if kb_entity is None: + err = f"知识库不存在,知识库ID: {kb_id}" + logging.error("[ChunkService] %s", err) + return [] + if kb_id != DEFAULT_KNOWLEDGE_BASE_ID and not await KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, kb_id, action): + err = f"用户没有权限访问该知识库,知识库ID: {kb_id}" + logging.error("[ChunkService] %s", err) + return [] + top_k_search = top_k + if is_rerank: + top_k_search = top_k * 3 + try: + chunk_entities = await BaseSearcher.search(search_method, kb_id, query, top_k_search, doc_ids, banned_ids) + except Exception as e: + err = f"搜索分片失败,error: {e}" + logging.exception("[ChunkService] %s", err) + return [] + if is_rerank: + chunk_indexs = await BaseSearcher.rerank(chunk_entities, kb_entity.rerank_method, query) + chunk_entities = [chunk_entities[i] for i in chunk_indexs] + chunk_entities = chunk_entities[:top_k] + return chunk_entities + async def search_chunks(user_sub: str, action: str, req: SearchChunkRequest) -> SearchChunkMsg: """根据查询条件搜索分片""" logging.error("[ChunkService] 搜索分片,查询条件: %s", req) chunk_entities = [] + search_tasks = [] for kb_id in req.kb_ids: - try: - kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id) - if kb_entity is None: - err = f"知识库不存在,知识库ID: {kb_id}" - logging.warning("[ChunkService] %s", err) - continue - if kb_id != DEFAULT_KNOWLEDGE_BASE_ID and not await KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, kb_id, action): - err = f"用户没有权限访问该知识库,知识库ID: {kb_id}" - logging.warning("[ChunkService] %s", err) - continue - top_k = req.top_k - if req.is_rerank: - 
top_k = req.top_k*3 - sub_chunk_entities = await BaseSearcher.search(req.search_method.value, kb_id, req.query, top_k, req.doc_ids, req.banned_ids) - if req.is_rerank: - sub_chunk_indexs = await BaseSearcher.rerank(sub_chunk_entities, kb_entity.rerank_method, req.query) - sub_chunk_entities = [sub_chunk_entities[i] - for i in sub_chunk_indexs] - sub_chunk_entities = sub_chunk_entities[:req.top_k] - chunk_entities += sub_chunk_entities - except Exception as e: - err = f"[ChunkService] 搜索分片失败,error: {e}" - logging.exception(err) - return SearchChunkMsg(docChunks=[]) + search_task = ChunkService.search_chunks_from_kb( + user_sub, action, req.search_method, kb_id, req.query, req.top_k, req.doc_ids, req.banned_ids, req.is_rerank) + search_tasks.append(search_task) + search_results = await asyncio.gather(*search_tasks) + for search_result in search_results: + chunk_entities += search_result if len(chunk_entities) == 0: return SearchChunkMsg(docChunks=[]) if req.is_rerank: -- Gitee From dd37ae1cfd6c6c21e5e04cdb3f375f1fb2aa19cb Mon Sep 17 00:00:00 2001 From: zxstty Date: Sun, 2 Nov 2025 22:59:01 +0800 Subject: [PATCH 2/4] =?UTF-8?q?=E4=BD=BF=E7=94=A8bm25=E4=BB=A3=E6=9B=BFcha?= =?UTF-8?q?rper=E5=88=86=E8=AF=8D=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_chain/manager/chunk_manager.py | 75 +++---- data_chain/manager/document_manager.py | 80 ++++--- ...ic_weighted_keyword_and_vector_searcher.py | 19 +- data_chain/stores/database/database.py | 198 +++++++++--------- 4 files changed, 185 insertions(+), 187 deletions(-) diff --git a/data_chain/manager/chunk_manager.py b/data_chain/manager/chunk_manager.py index d0c9ac67..8cbb4f4a 100644 --- a/data_chain/manager/chunk_manager.py +++ b/data_chain/manager/chunk_manager.py @@ -1,5 +1,5 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -from sqlalchemy import select, update, func, text, or_, and_, Float, literal_column +from sqlalchemy import select, update, func, text, or_, and_, Float, literal_column, true from typing import List, Tuple, Dict, Optional import uuid from datetime import datetime @@ -197,7 +197,7 @@ class ChunkManager(): else: # PostgreSQL需要将向量转换为字符串格式 vector_param = str(vector) - + params = { "vector": vector_param, "kb_id": kb_id, @@ -395,22 +395,11 @@ class ChunkManager(): kb_id: uuid.UUID, keywords: List[str], weights: List[float], top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = [], chunk_to_type: str = None, pre_ids: list[uuid.UUID] = None) -> List[ChunkEntity]: - """根据知识库ID和关键词权重查询文档解析结果(修复NoneType报错+强制索引)""" + """根据知识库ID和关键词权重查询文档解析结果(使用BM25打分,修复CTE定义顺序)""" try: st = datetime.now() async with await DataBase.get_session() as session: - # 1. 分词器选择(保留原逻辑) - kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id) - if kb_entity.tokenizer == Tokenizer.ZH.value: - tokenizer = 'chparser' if config['DATABASE_TYPE'].lower( - ) == 'opengauss' else 'zhparser' - elif kb_entity.tokenizer == Tokenizer.EN.value: - tokenizer = 'english' - else: - tokenizer = 'chparser' if config['DATABASE_TYPE'].lower( - ) == 'opengauss' else 'zhparser' - - # 2. 构建加权关键词CTE(保留原逻辑) + # 1. 构建加权关键词CTE(提前定义,确保后续能引用) params = {} values_clause = [] for idx, (term, weight) in enumerate(zip(keywords, weights)): @@ -425,25 +414,25 @@ class ChunkManager(): literal_column("t.weight").cast(Float).label("weight") ) .select_from(text(values_text)) - .cte("weighted_terms") + .cte("weighted_terms") # 定义为CTE,供后续查询使用 ) - # 3. 
初始化查询(确保stmt始终是Select对象,不直接赋值None) + # 2. 初始化查询(此时weighted_terms已定义,可正常引用) stmt = ( select( ChunkEntity, func.sum( - func.ts_rank_cd(ChunkEntity.text_ts_vector, func.to_tsquery( - tokenizer, weighted_terms.c.term)) + # 使用BM25运算符<&>计算相似度,乘以权重后求和 + ChunkEntity.text.op('<&>')(weighted_terms.c.term) * weighted_terms.c.weight ).label("similarity_score") ) # 关联文档表 .join(DocumentEntity, DocumentEntity.id == ChunkEntity.doc_id) - .join( # 关联CTE+强制触发GIN索引(核心优化) + # 关联加权关键词CTE(通过BM25匹配有分数的结果) + .join( weighted_terms, - ChunkEntity.text_ts_vector.op( - '@@')(func.to_tsquery(tokenizer, weighted_terms.c.term)), + ChunkEntity.text.op('<&>')(weighted_terms.c.term) > 0, isouter=False ) # 基础过滤条件 @@ -454,55 +443,57 @@ class ChunkManager(): .where(ChunkEntity.status != ChunkStatus.DELETED.value) ) - # 4. 动态条件:禁用ID(修复关键:用if-else确保stmt不被赋值为None) + # 3. 动态条件:禁用ID if banned_ids: stmt = stmt.where(ChunkEntity.id.notin_(banned_ids)) - # 5. 其他动态条件(同样用if-else确保链式调用不中断) + # 4. 其他动态条件 if doc_ids is not None: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) if chunk_to_type is not None: - stmt = stmt.where( - ChunkEntity.parse_topology_type == chunk_to_type) + stmt = stmt.where(ChunkEntity.parse_topology_type == chunk_to_type) if pre_ids is not None: - stmt = stmt.where( - ChunkEntity.pre_id_in_parse_topology.in_(pre_ids)) + stmt = stmt.where(ChunkEntity.pre_id_in_parse_topology.in_(pre_ids)) - # 6. 分组、过滤分数、排序、限制行数(链式调用安全) + # 5. 分组、过滤、排序、限制(使用BM25加权总分) stmt = (stmt - .group_by(ChunkEntity.id) # 按chunk分组计算总权重 - .having( # 过滤总分数>0的结果 + .group_by(ChunkEntity.id) + .having( func.sum( - func.ts_rank_cd(ChunkEntity.text_ts_vector, func.to_tsquery( - tokenizer, weighted_terms.c.term)) + ChunkEntity.text.op('<&>')(weighted_terms.c.term) * weighted_terms.c.weight ) > 0 ) - .order_by( # 按总分数降序 + .order_by( func.sum( - func.ts_rank_cd(ChunkEntity.text_ts_vector, func.to_tsquery( - tokenizer, weighted_terms.c.term)) + ChunkEntity.text.op('<&>')(weighted_terms.c.term) * weighted_terms.c.weight ).desc() ) - .limit(top_k) # 限制返回数量 + .limit(top_k) ) - # 7. 执行查询与结果处理(保留原逻辑) + # 6. 强制使用BM25索引(关键修正:匹配实际索引名+正确Hint语法) + if config['DATABASE_TYPE'].lower() == 'opengauss': + # 实际BM25索引名:text_bm25_index;表名:chunk(从\d+结果确认) + table_name = ChunkEntity.__tablename__ # 若映射正确,应为'chunk' + bm25_index_name = "text_bm25_index" # 与\d+结果中的索引名完全一致 + stmt = stmt.prefix_with(f"/*+ INDEX({table_name} {bm25_index_name}) */") + + # 7. 执行查询与结果处理 result = await session.execute(stmt, params=params) chunk_entities = result.scalars().all() # 8. 日志输出 cost = (datetime.now() - st).total_seconds() logging.warning( - f"[ChunkManager] get_top_k_chunk_by_kb_id_dynamic_weighted_keyword cost: {cost}s " - f"| kb_id: {kb_id} | keywords: {keywords[:2]}... | match_count: {len(chunk_entities)}" + f"[ChunkManager] BM25查询耗时: {cost}s " + f"| kb_id: {kb_id} | keywords: {keywords[:2]}... 
| 匹配数量: {len(chunk_entities)}" ) return chunk_entities except Exception as e: - # 异常日志补充关键上下文 - err = f"根据知识库ID和关键词权重查询失败: kb_id={kb_id}, keywords={keywords[:2]}..., error={str(e)[:150]}" + err = f"BM25查询失败: kb_id={kb_id}, keywords={keywords[:2]}..., error={str(e)[:150]}" logging.exception("[ChunkManager] %s", err) return [] diff --git a/data_chain/manager/document_manager.py b/data_chain/manager/document_manager.py index f03433b6..1ab38966 100644 --- a/data_chain/manager/document_manager.py +++ b/data_chain/manager/document_manager.py @@ -208,26 +208,15 @@ class DocumentManager(): err = "获取前K个文档失败" logging.exception("[DocumentManager] %s", err) raise e - + @staticmethod async def get_top_k_document_by_kb_id_dynamic_weighted_keyword( - kb_id: uuid.UUID, keywords: List[str], weights: List[float], - top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = []) -> List[DocumentEntity]: - """根据知识库ID和关键词权重查询文档(修复NoneType报错+强制索引)""" + kb_id: uuid.UUID, keywords: List[str], weights: List[float], + top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = []) -> List[DocumentEntity]: + """根据知识库ID和关键词权重查询文档(BM25检索版,匹配abstract_bm25_index索引)""" try: - st = datetime.now() # 新增计时日志 + st = datetime.now() async with await DataBase.get_session() as session: - # 1. 分词器选择(与第一个方法保持一致) - kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id) - if kb_entity.tokenizer == Tokenizer.ZH.value: - tokenizer = 'chparser' if config['DATABASE_TYPE'].lower( - ) == 'opengauss' else 'zhparser' - elif kb_entity.tokenizer == Tokenizer.EN.value: - tokenizer = 'english' - else: - tokenizer = 'chparser' if config['DATABASE_TYPE'].lower( - ) == 'opengauss' else 'zhparser' - - # 2. 构建加权关键词CTE(保留原逻辑) + # 1. 构建加权关键词CTE(保留原权重逻辑,移除分词器依赖) params = {} values_clause = [] for idx, (term, weight) in enumerate(zip(keywords, weights)): @@ -245,72 +234,81 @@ class DocumentManager(): .cte("weighted_terms") ) - # 3. 初始化查询(确保stmt始终是Select对象) + # 2. 初始化查询(替换为BM25评分,关联BM25索引字段abstract) stmt = ( select( DocumentEntity, func.sum( - func.ts_rank_cd(DocumentEntity.abstract_ts_vector, func.to_tsquery( - tokenizer, weighted_terms.c.term)) + # BM25相似度计算:abstract <&> 关键词,再乘以关键词权重 + DocumentEntity.abstract.op('<&>')(weighted_terms.c.term) * weighted_terms.c.weight ).label("similarity_score") ) - # 关联CTE+强制触发GIN索引(核心优化) + # 关联加权关键词CTE:仅保留BM25有匹配的文档(避免无效计算) .join( weighted_terms, - DocumentEntity.abstract_ts_vector.op( - '@@')(func.to_tsquery(tokenizer, weighted_terms.c.term)), + # 条件:BM25相似度结果非空(确保触发BM25索引) + DocumentEntity.abstract.op('<&>')(weighted_terms.c.term).isnot(None), isouter=False ) - # 基础过滤条件 + # 基础过滤条件(与原逻辑一致) .where(DocumentEntity.enabled == True) .where(DocumentEntity.status != DocumentStatus.DELETED.value) .where(DocumentEntity.kb_id == kb_id) ) - # 4. 动态条件:禁用ID(确保stmt链式调用不中断) + # 3. 动态条件:禁用ID(保持原逻辑) if banned_ids: stmt = stmt.where(DocumentEntity.id.notin_(banned_ids)) - # 5. 其他动态条件 + # 4. 其他动态条件:指定文档ID过滤(保持原逻辑) if doc_ids is not None: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) - # 6. 分组、过滤分数、排序、限制行数(链式调用安全) + # 5. 
分组、过滤、排序、限制(基于BM25加权总分) stmt = (stmt - .group_by(DocumentEntity.id) # 按文档ID分组计算总权重 - .having( # 过滤总分数>0的结果 + .group_by(DocumentEntity.id) # 按文档ID分组,计算单文档总相似度 + .having( + # 过滤总相似度>0的文档(排除无匹配的结果) func.sum( - func.ts_rank_cd(DocumentEntity.abstract_ts_vector, func.to_tsquery( - tokenizer, weighted_terms.c.term)) + DocumentEntity.abstract.op('<&>')(weighted_terms.c.term) * weighted_terms.c.weight ) > 0 ) - .order_by( # 按总分数降序 + .order_by( + # 按BM25总相似度降序,取Top-K func.sum( - func.ts_rank_cd(DocumentEntity.abstract_ts_vector, func.to_tsquery( - tokenizer, weighted_terms.c.term)) + DocumentEntity.abstract.op('<&>')(weighted_terms.c.term) * weighted_terms.c.weight ).desc() ) - .limit(top_k) # 限制返回数量 + .limit(top_k) ) - # 7. 执行查询与结果处理 + # 6. 强制使用BM25索引(关键:匹配实际索引名abstract_bm25_index) + if config['DATABASE_TYPE'].lower() == 'opengauss': + # 表名:DocumentEntity对应的表名(通常为'document',需确保ORM映射正确) + table_name = DocumentEntity.__tablename__ + # 实际BM25索引名:从__table_args__确认是abstract_bm25_index + bm25_index_name = "abstract_bm25_index" + # 使用openGauss推荐的Hint语法,强制索引扫描 + stmt = stmt.prefix_with(f"/*+ INDEX({table_name} {bm25_index_name}) */") + + # 7. 执行查询与结果处理(保持原逻辑) result = await session.execute(stmt, params=params) doc_entities = result.scalars().all() - # 8. 新增执行时间日志 + # 8. 日志输出(补充BM25标识,便于排查) cost = (datetime.now() - st).total_seconds() logging.warning( - f"[DocumentManager] get_top_k_document_by_kb_id_dynamic_weighted_keyword cost: {cost}s " - f"| kb_id: {kb_id} | keywords: {keywords[:2]}... | match_count: {len(doc_entities)}" + f"[DocumentManager] BM25检索文档耗时: {cost}s " + f"| kb_id: {kb_id} | keywords: {keywords[:2]}... | 匹配数量: {len(doc_entities)}" ) return doc_entities except Exception as e: - # 异常日志补充关键上下文 - err = f"根据知识库ID和关键词权重查询文档失败: kb_id={kb_id}, keywords={keywords[:2]}..., error={str(e)[:150]}" + # 异常日志补充BM25检索上下文 + err = f"BM25检索文档失败: kb_id={kb_id}, keywords={keywords[:2]}..., error={str(e)[:150]}" logging.exception("[DocumentManager] %s", err) return [] diff --git a/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py b/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py index a5467d41..500288ef 100644 --- a/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py +++ b/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py @@ -29,35 +29,42 @@ class DynamicKeywordVectorSearcher(BaseSearcher): :return: 检索结果 """ vector = await Embedding.vectorize_embedding(query) + logging.error(f"[DynamicKeywordVectorSearcher] vector {vector}") try: chunk_entities_get_by_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_keyword( kb_id, query, max(top_k//3, 1), doc_ids, banned_ids) banned_ids += [chunk_entity.id for chunk_entity in chunk_entities_get_by_keyword] keywords, weights = TokenTool.get_top_k_keywords_and_weights(query) - logging.error(f"[DynamicKeywordVectorSearcher] keywords: {keywords}, weights: {weights}") + logging.error( + f"[DynamicKeywordVectorSearcher] keywords: {keywords}, weights: {weights}") import time start_time = time.time() chunk_entities_get_by_dynamic_weighted_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, keywords, weights, top_k//2, doc_ids, banned_ids) end_time = time.time() - logging.info(f"[DynamicKeywordVectorSearcher] 动态关键字检索成功完成,耗时: {end_time - start_time:.2f}秒") + logging.info( + f"[DynamicKeywordVectorSearcher] 动态关键字检索成功完成,耗时: {end_time - start_time:.2f}秒") banned_ids += [chunk_entity.id for chunk_entity in chunk_entities_get_by_dynamic_weighted_keyword] chunk_entities_get_by_vector = [] for _ in 
range(3): try: import time start_time = time.time() - logging.error(f"[DynamicKeywordVectorSearcher] 开始进行向量检索,top_k: {top_k-len(chunk_entities_get_by_keyword)-len(chunk_entities_get_by_dynamic_weighted_keyword)}") + logging.error( + f"[DynamicKeywordVectorSearcher] 开始进行向量检索,top_k: {top_k-len(chunk_entities_get_by_keyword)-len(chunk_entities_get_by_dynamic_weighted_keyword)}") chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword)-len(chunk_entities_get_by_dynamic_weighted_keyword), doc_ids, banned_ids), timeout=20) end_time = time.time() - logging.info(f"[DynamicKeywordVectorSearcher] 向量检索成功完成,耗时: {end_time - start_time:.2f}秒") + logging.info( + f"[DynamicKeywordVectorSearcher] 向量检索成功完成,耗时: {end_time - start_time:.2f}秒") break except Exception as e: import traceback err = f"[DynamicKeywordVectorSearcher] 向量检索失败,error: {e}, traceback: {traceback.format_exc()}" logging.error(err) continue - logging.error(f"[DynamicKeywordVectorSearcher] chunk_entities_get_by_keyword: {len(chunk_entities_get_by_keyword)}, chunk_entities_get_by_dynamic_weighted_keyword: {len(chunk_entities_get_by_dynamic_weighted_keyword)}, chunk_entities_get_by_vector: {len(chunk_entities_get_by_vector)}") - chunk_entities = chunk_entities_get_by_keyword + chunk_entities_get_by_dynamic_weighted_keyword + chunk_entities_get_by_vector + logging.error( + f"[DynamicKeywordVectorSearcher] chunk_entities_get_by_keyword: {len(chunk_entities_get_by_keyword)}, chunk_entities_get_by_dynamic_weighted_keyword: {len(chunk_entities_get_by_dynamic_weighted_keyword)}, chunk_entities_get_by_vector: {len(chunk_entities_get_by_vector)}") + chunk_entities = chunk_entities_get_by_keyword + \ + chunk_entities_get_by_dynamic_weighted_keyword + chunk_entities_get_by_vector except Exception as e: err = f"[DynamicKeywordVectorSearcher] 关键词向量检索失败,error: {e}" logging.exception(err) diff --git a/data_chain/stores/database/database.py b/data_chain/stores/database/database.py index 7d159913..daae2620 100644 --- a/data_chain/stores/database/database.py +++ b/data_chain/stores/database/database.py @@ -7,7 +7,7 @@ from uuid import uuid4 import urllib.parse from data_chain.logger.logger import logger as logging from pgvector.sqlalchemy import Vector -from sqlalchemy import Boolean, Column, ForeignKey, BigInteger, Float, String, func +from sqlalchemy import Boolean, Column, ForeignKey, BigInteger, Float, Text, func from sqlalchemy.types import TIMESTAMP, UUID from sqlalchemy.dialects.postgresql import TSVECTOR from sqlalchemy.orm import declarative_base, DeclarativeBase, MappedAsDataclass, Mapped, mapped_column @@ -45,13 +45,13 @@ class TeamEntity(Base): __tablename__ = 'team' id = Column(UUID, default=uuid4, primary_key=True) - author_id = Column(String) - author_name = Column(String) - name = Column(String) - description = Column(String) + author_id = Column(Text) + author_name = Column(Text) + name = Column(Text) + description = Column(Text) member_cnt = Column(BigInteger, default=0) is_public = Column(Boolean, default=True) - status = Column(String, default=TeamStatus.EXISTED.value) + status = Column(Text, default=TeamStatus.EXISTED.value) created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -69,12 +69,12 @@ class TeamMessageEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) team_id = Column(UUID, ForeignKey('team.id')) - author_id = Column(String) - author_name = Column(String) - message_level = Column(String) - zh_message = 
Column(String, default='') - en_message = Column(String, default='') - status = Column(String, default=TeamMessageStatus.EXISTED.value) + author_id = Column(Text) + author_name = Column(Text) + message_level = Column(Text) + zh_message = Column(Text, default='') + en_message = Column(Text, default='') + status = Column(Text, default=TeamMessageStatus.EXISTED.value) created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -92,10 +92,10 @@ class RoleEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) team_id = Column(UUID, ForeignKey('team.id')) - name = Column(String) + name = Column(Text) is_unique = Column(Boolean, default=False) editable = Column(Boolean, default=True) - status = Column(String, default=RoleStatus.EXISTED.value) # 角色状态 + status = Column(Text, default=RoleStatus.EXISTED.value) # 角色状态 created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -111,9 +111,9 @@ class RoleEntity(Base): class ActionEntity(Base): __tablename__ = 'action' - action = Column(String, primary_key=True) - name = Column(String) - type = Column(String) + action = Column(Text, primary_key=True) + name = Column(Text) + type = Column(Text) created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -131,8 +131,8 @@ class RoleActionEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) role_id = Column(UUID, ForeignKey('role.id', ondelete="CASCADE")) - action = Column(String) - status = Column(String, default=RoleActionStatus.EXISTED.value) + action = Column(Text) + status = Column(Text, default=RoleActionStatus.EXISTED.value) created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -148,9 +148,9 @@ class RoleActionEntity(Base): class UserEntity(Base): __tablename__ = 'users' - id = Column(String, primary_key=True) - name = Column(String) - status = Column(String, default=UserStatus.ACTIVE.value) + id = Column(Text, primary_key=True) + name = Column(Text) + status = Column(Text, default=UserStatus.ACTIVE.value) created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -168,17 +168,17 @@ class UserMessageEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) team_id = Column(UUID) - team_name = Column(String) + team_name = Column(Text) role_id = Column(UUID) - sender_id = Column(String) - sender_name = Column(String) - status_to_sender = Column(String, default=UserMessageStatus.UNREAD.value) - receiver_id = Column(String) - receiver_name = Column(String) + sender_id = Column(Text) + sender_name = Column(Text) + status_to_sender = Column(Text, default=UserMessageStatus.UNREAD.value) + receiver_id = Column(Text) + receiver_name = Column(Text) is_to_all = Column(Boolean, default=False) - status_to_receiver = Column(String, default=UserMessageStatus.UNREAD.value) - message = Column(String) - type = Column(String) + status_to_receiver = Column(Text, default=UserMessageStatus.UNREAD.value) + message = Column(Text) + type = Column(Text) created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -191,8 +191,8 @@ class TeamUserEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) team_id = Column(UUID, ForeignKey('team.id', ondelete="CASCADE")) # 团队id - user_id = Column(String) # 用户id - status = Column(String, default=TeamUserStaus.EXISTED.value) # 用户在团队中的状态 + user_id = Column(Text) # 用户id + status = Column(Text, default=TeamUserStaus.EXISTED.value) # 用户在团队中的状态 created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -209,9 +209,9 @@ class UserRoleEntity(Base): __tablename__ = 'user_role' id = Column(UUID, 
default=uuid4, primary_key=True) team_id = Column(UUID, ForeignKey('team.id', ondelete="CASCADE")) # 团队id - user_id = Column(String) # 用户id + user_id = Column(Text) # 用户id role_id = Column(UUID) # 角色id - status = Column(String, default=UserRoleStatus.EXISTED.value) # 用户角色状态 + status = Column(Text, default=UserRoleStatus.EXISTED.value) # 用户角色状态 created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -230,23 +230,23 @@ class KnowledgeBaseEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) team_id = Column(UUID, ForeignKey( 'team.id', ondelete="CASCADE"), nullable=True) # 团队id - author_id = Column(String) # 作者id - author_name = Column(String) # 作者名称 - name = Column(String, default='') # 知识库名资产名 - tokenizer = Column(String, default=Tokenizer.ZH.value) # 分词器 - description = Column(String, default='') # 资产描述 - embedding_model = Column(String) # 资产向量化模型 - rerank_method = Column(String) - rerank_name = Column(String) - spearating_characters = Column(String) # 资产分块的分隔符 + author_id = Column(Text) # 作者id + author_name = Column(Text) # 作者名称 + name = Column(Text, default='') # 知识库名资产名 + tokenizer = Column(Text, default=Tokenizer.ZH.value) # 分词器 + description = Column(Text, default='') # 资产描述 + embedding_model = Column(Text) # 资产向量化模型 + rerank_method = Column(Text) + rerank_name = Column(Text) + spearating_characters = Column(Text) # 资产分块的分隔符 doc_cnt = Column(BigInteger, default=0) # 资产文档个数 doc_size = Column(BigInteger, default=0) # 资产下所有文档大小(TODO: 单位kb或者字节) upload_count_limit = Column(BigInteger, default=128) # 更新次数限制 upload_size_limit = Column(BigInteger, default=512) # 更新大小限制 default_parse_method = Column( - String, default=ParseMethod.GENERAL.value) # 默认解析方法 + Text, default=ParseMethod.GENERAL.value) # 默认解析方法 default_chunk_size = Column(BigInteger, default=1024) # 默认分块大小 - status = Column(String, default=KnowledgeBaseStatus.IDLE.value) + status = Column(Text, default=KnowledgeBaseStatus.IDLE.value) created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -265,7 +265,7 @@ class DocumentTypeEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) kb_id = Column(UUID, ForeignKey('knowledge_base.id', ondelete="CASCADE"), nullable=True) - name = Column(String) + name = Column(Text) created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -285,20 +285,20 @@ class DocumentEntity(Base): team_id = Column(UUID) # 文档所属团队id kb_id = Column(UUID, ForeignKey( 'knowledge_base.id', ondelete="CASCADE")) # 文档所属资产id - author_id = Column(String) # 文档作者id - author_name = Column(String) # 文档作者名称 - name = Column(String) # 文档名 - extension = Column(String) # 文件后缀 + author_id = Column(Text) # 文档作者id + author_name = Column(Text) # 文档作者名称 + name = Column(Text) # 文档名 + extension = Column(Text) # 文件后缀 size = Column(BigInteger) # 文档大小 - parse_method = Column(String, default=ParseMethod.GENERAL.value) # 文档解析方法 + parse_method = Column(Text, default=ParseMethod.GENERAL.value) # 文档解析方法 parse_relut_topology = Column( - String, default=DocParseRelutTopology.LIST.value) # 文档解析结果拓扑结构 + Text, default=DocParseRelutTopology.LIST.value) # 文档解析结果拓扑结构 chunk_size = Column(BigInteger) # 文档分块大小 type_id = Column(UUID) # 文档类别 enabled = Column(Boolean) # 文档是否启用 - status = Column(String, default=DocumentStatus.IDLE.value) # 文档状态 - full_text = Column(String) # 文档全文 - abstract = Column(String) # 文档摘要 + status = Column(Text, default=DocumentStatus.IDLE.value) # 文档状态 + full_text = Column(Text) # 文档全文 + abstract = Column(Text) # 文档摘要 abstract_ts_vector = Column(TSVECTOR) # 文档摘要词向量 abstract_vector = 
Column(Vector(1024)) # 文档摘要向量 created_time = Column( @@ -318,9 +318,10 @@ class DocumentEntity(Base): 'abstract_vector_index', abstract_vector, postgresql_using='hnsw', - postgresql_with={'m': 16, 'ef_construction': 200}, + postgresql_with={'m': 32, 'ef_construction': 200}, postgresql_ops={'abstract_vector': 'vector_cosine_ops'} ), + Index('abstract_bm25_index', abstract, postgresql_using='bm25') ) @@ -332,21 +333,21 @@ class ChunkEntity(Base): kb_id = Column(UUID) # 知识库id doc_id = Column(UUID, ForeignKey( 'document.id', ondelete="CASCADE")) # 片段所属文档id - doc_name = Column(String) # 片段所属文档名称 - text = Column(String) # 片段文本内容 + doc_name = Column(Text) # 片段所属文档名称 + text = Column(Text) # 片段文本内容 text_ts_vector = Column(TSVECTOR) # 片段文本词向量 text_vector = Column(Vector(1024)) # 文本向量 tokens = Column(BigInteger) # 片段文本token数 - type = Column(String, default=ChunkType.TEXT.value) # 片段类型 + type = Column(Text, default=ChunkType.TEXT.value) # 片段类型 # 前一个chunk的id(假如解析结果为链表,那么这里是前一个节点的id,如果文档解析结果为树,那么这里是父节点的id) pre_id_in_parse_topology = Column(UUID) # chunk的在解析结果中的拓扑类型(假如解析结果为链表,那么这里为链表头、中间和尾;假如解析结果为树,那么这里为树根、树的中间节点和叶子节点) parse_topology_type = Column( - String, default=ChunkParseTopology.LISTHEAD.value) + Text, default=ChunkParseTopology.LISTHEAD.value) global_offset = Column(BigInteger) # chunk在文档中的相对偏移 local_offset = Column(BigInteger) # chunk在块中的相对偏移 enabled = Column(Boolean) # chunk是否启用 - status = Column(String, default=ChunkStatus.EXISTED.value) # chunk状态 + status = Column(Text, default=ChunkStatus.EXISTED.value) # chunk状态 created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -362,9 +363,10 @@ class ChunkEntity(Base): 'text_vector_index', text_vector, postgresql_using='hnsw', - postgresql_with={'m': 16, 'ef_construction': 200}, + postgresql_with={'m': 32, 'ef_construction': 200}, postgresql_ops={'text_vector': 'vector_cosine_ops'} ), + Index('text_bm25_index', text, postgresql_using='bm25') ) @@ -374,8 +376,8 @@ class ImageEntity(Base): team_id = Column(UUID) # 团队id doc_id = Column(UUID) # 图片所属文档id chunk_id = Column(UUID) # 图片所属chunk的id - extension = Column(String) # 图片后缀 - status = Column(String, default=ImageStatus.EXISTED.value) # 图片状态 + extension = Column(Text) # 图片后缀 + status = Column(Text, default=ImageStatus.EXISTED.value) # 图片状态 created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -395,16 +397,16 @@ class DataSetEntity(Base): team_id = Column(UUID) # 数据集所属团队id kb_id = Column(UUID, ForeignKey('knowledge_base.id', ondelete="CASCADE")) # 数据集所属资产id - author_id = Column(String) # 数据的创建者id - author_name = Column(String) # 数据的创建者名称 - llm_id = Column(String) # 数据的生成使用的大模型的id - name = Column(String, nullable=False) # 数据集名称 - description = Column(String) # 数据集描述 + author_id = Column(Text) # 数据的创建者id + author_name = Column(Text) # 数据的创建者名称 + llm_id = Column(Text) # 数据的生成使用的大模型的id + name = Column(Text, nullable=False) # 数据集名称 + description = Column(Text) # 数据集描述 data_cnt = Column(BigInteger) # 数据集数据量 is_data_cleared = Column(Boolean, default=False) # 数据集是否清洗 is_chunk_related = Column(Boolean, default=False) # 数据集是否关联上下文 is_imported = Column(Boolean, default=False) # 数据集是否导入 - status = Column(String, default=DataSetStatus.IDLE) # 数据集状态 + status = Column(Text, default=DataSetStatus.IDLE) # 数据集状态 score = Column(Float, default=-1) # 数据集得分 created_at = Column( TIMESTAMP(timezone=True), @@ -444,12 +446,12 @@ class QAEntity(Base): dataset_id = Column(UUID, ForeignKey( 'dataset.id', ondelete="CASCADE")) # 数据所属数据集id doc_id = Column(UUID) # 数据关联的文档id - doc_name = Column(String, 
default="未知文档") # 数据关联的文档名称 - question = Column(String) # 数据的问题 - answer = Column(String) # 数据的答案 - chunk = Column(String) # 数据的片段 - chunk_type = Column(String, default="未知片段类型") # 数据的片段类型 - status = Column(String, default=QAStatus.EXISTED.value) # 数据的状态 + doc_name = Column(Text, default="未知文档") # 数据关联的文档名称 + question = Column(Text) # 数据的问题 + answer = Column(Text) # 数据的答案 + chunk = Column(Text) # 数据的片段 + chunk_type = Column(Text, default="未知片段类型") # 数据的片段类型 + status = Column(Text, default=QAStatus.EXISTED.value) # 数据的状态 created_at = Column( TIMESTAMP(timezone=True), nullable=True, @@ -470,15 +472,15 @@ class TestingEntity(Base): kb_id = Column(UUID) # 测试任务所属资产id dataset_id = Column(UUID, ForeignKey( 'dataset.id', ondelete="CASCADE")) # 测试任务使用数据集的id - author_id = Column(String) # 测试任务的创建者id - author_name = Column(String) # 测试任务的创建者名称 - name = Column(String) # 测试任务的名称 - description = Column(String) # 测试任务的描述 - llm_id = Column(String) # 测试任务的使用的大模型 + author_id = Column(Text) # 测试任务的创建者id + author_name = Column(Text) # 测试任务的创建者名称 + name = Column(Text) # 测试任务的名称 + description = Column(Text) # 测试任务的描述 + llm_id = Column(Text) # 测试任务的使用的大模型 search_method = Column( - String, default=SearchMethod.KEYWORD_AND_VECTOR.value) # 测试任务的使用的检索增强模式类型 + Text, default=SearchMethod.KEYWORD_AND_VECTOR.value) # 测试任务的使用的检索增强模式类型 top_k = Column(BigInteger, default=5) # 测试任务的检索增强模式的top_k - status = Column(String, default=TestingStatus.IDLE.value) # 测试任务的状态 + status = Column(Text, default=TestingStatus.IDLE.value) # 测试任务的状态 ave_score = Column(Float, default=-1) # 测试任务的综合得分 ave_pre = Column(Float, default=-1) # 测试任务的平均召回率 ave_rec = Column(Float, default=-1) # 测试任务的平均精确率 @@ -505,12 +507,12 @@ class TestCaseEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) # 测试case的id testing_id = Column(UUID, ForeignKey( 'testing.id', ondelete="CASCADE")) # 测试 - question = Column(String) # 数据的问题 - answer = Column(String) # 数据的答案 - chunk = Column(String) # 数据的片段 - llm_answer = Column(String) # 测试答案 - related_chunk = Column(String) # 测试关联到的chunk - doc_name = Column(String) # 测试关联的文档名称 + question = Column(Text) # 数据的问题 + answer = Column(Text) # 数据的答案 + chunk = Column(Text) # 数据的片段 + llm_answer = Column(Text) # 测试答案 + related_chunk = Column(Text) # 测试关联到的chunk + doc_name = Column(Text) # 测试关联的文档名称 score = Column(Float) # 测试得分 pre = Column(Float) # 召回率 rec = Column(Float) # 精确率 @@ -519,7 +521,7 @@ class TestCaseEntity(Base): lcs = Column(Float) # 最长公共子序列得分 leve = Column(Float) # 编辑距离得分 jac = Column(Float) # 杰卡德相似系数 - status = Column(String, default=TestCaseStatus.EXISTED.value) # 测试状态 + status = Column(Text, default=TestCaseStatus.EXISTED.value) # 测试状态 created_at = Column( TIMESTAMP(timezone=True), nullable=True, @@ -537,13 +539,13 @@ class TaskEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) team_id = Column(UUID) # 团队id - user_id = Column(String, ForeignKey( + user_id = Column(Text, ForeignKey( 'users.id', ondelete="CASCADE")) # 创建者id op_id = Column(UUID) # 任务关联的实体id, 资产或者文档id - op_name = Column(String) # 任务关联的实体名称 - type = Column(String) # 任务类型 + op_name = Column(Text) # 任务关联的实体名称 + type = Column(Text) # 任务类型 retry = Column(BigInteger) # 重试次数 - status = Column(String) # 任务状态 + status = Column(Text) # 任务状态 created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -561,7 +563,7 @@ class TaskReportEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) # 任务报告的id task_id = Column(UUID, ForeignKey('task.id', ondelete="CASCADE")) # 任务id - message = Column(String) # 任务报告信息 + message = Column(Text) 
# 任务报告信息 current_stage = Column(BigInteger) # 任务当前阶段 stage_cnt = Column(BigInteger) # 任务总的阶段 created_time = Column( @@ -580,7 +582,7 @@ class TaskQueueEntity(Base): __tablename__ = 'task_queue' id = Column(UUID, default=uuid4, primary_key=True) # 任务ID - status = Column(String) # 任务状态 + status = Column(Text) # 任务状态 created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -609,7 +611,7 @@ class DataBase: pool_size = 5 engine = create_async_engine( database_url, - echo=False, + echo=True, pool_recycle=300, pool_pre_ping=True, pool_size=pool_size -- Gitee From 93074e4f905f66c887106b3398d2d40a6a26fe30 Mon Sep 17 00:00:00 2001 From: zxstty Date: Sun, 2 Nov 2025 23:34:34 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E5=AE=8C=E5=96=84=20post=20/chunk/search?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3=EF=BC=8C=E9=80=82=E9=85=8D=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_chain/apps/service/chunk_service.py | 175 +++++++++++++++-------- data_chain/entities/request_data.py | 2 + data_chain/entities/response_data.py | 10 ++ data_chain/parser/tools/token_tool.py | 19 +++ 4 files changed, 144 insertions(+), 62 deletions(-) diff --git a/data_chain/apps/service/chunk_service.py b/data_chain/apps/service/chunk_service.py index 1d0fee5a..1a64a2eb 100644 --- a/data_chain/apps/service/chunk_service.py +++ b/data_chain/apps/service/chunk_service.py @@ -1,4 +1,5 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +import time import asyncio from fastapi import APIRouter, Depends, Query, Body, File, UploadFile import uuid @@ -39,7 +40,9 @@ from data_chain.embedding.embedding import Embedding class ChunkService: + """Chunk Service""" + @staticmethod async def validate_user_action_to_chunk(user_sub: str, chunk_id: uuid.UUID, action: str) -> bool: """验证用户对分片的操作权限""" try: @@ -58,6 +61,7 @@ class ChunkService: logging.exception("[ChunkService] %s", err) raise e + @staticmethod async def list_chunks_by_document_id(req: ListChunkRequest) -> ListChunkMsg: """根据文档ID列出分片""" try: @@ -75,17 +79,9 @@ class ChunkService: logging.exception("[ChunkService] %s", err) raise e + @staticmethod async def search_chunks_from_kb(user_sub: str, action: str, search_method: str, kb_id: uuid.UUID, query: str, top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = [], is_rerank: bool = False) -> list[ChunkEntity]: """从知识库搜索分片""" - kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id) - if kb_entity is None: - err = f"知识库不存在,知识库ID: {kb_id}" - logging.error("[ChunkService] %s", err) - return [] - if kb_id != DEFAULT_KNOWLEDGE_BASE_ID and not await KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, kb_id, action): - err = f"用户没有权限访问该知识库,知识库ID: {kb_id}" - logging.error("[ChunkService] %s", err) - return [] top_k_search = top_k if is_rerank: top_k_search = top_k * 3 @@ -95,88 +91,143 @@ class ChunkService: err = f"搜索分片失败,error: {e}" logging.exception("[ChunkService] %s", err) return [] - if is_rerank: - chunk_indexs = await BaseSearcher.rerank(chunk_entities, kb_entity.rerank_method, query) - chunk_entities = [chunk_entities[i] for i in chunk_indexs] - chunk_entities = chunk_entities[:top_k] return chunk_entities + @staticmethod + async def rerank_chunks(chunk_entities: list[ChunkEntity], rerank_method: str, query: str, top_k: int) -> list[ChunkEntity]: + """对分片进行重排序""" + chunk_indexs = await BaseSearcher.rerank(chunk_entities, rerank_method, query) + chunk_entities 
= [chunk_entities[i] for i in chunk_indexs] + chunk_entities = chunk_entities[:top_k] + return chunk_entities + # 关联上下文 + + @staticmethod + async def relate_surrounding_chunks(chunk_entities: list[ChunkEntity], tokens_limit: int) -> list[ChunkEntity]: + """关联上下文到搜索分片结果中""" + chunk_ids = [chunk_entity.id for chunk_entity in chunk_entities] + tokens_limit_every_chunk = tokens_limit // len( + chunk_entities) if len(chunk_entities) > 0 else tokens_limit + leave_tokens = 0 + related_chunk_entities = [] + token_sum = 0 + for chunk_entity in chunk_entities: + token_sum += chunk_entity.tokens + for chunk_entity in chunk_entities: + leave_tokens = tokens_limit_every_chunk+leave_tokens + try: + sub_related_chunk_entities = await BaseSearcher.related_surround_chunk(chunk_entity, leave_tokens-chunk_entity.tokens, chunk_ids) + except Exception as e: + leave_tokens += tokens_limit_every_chunk + err = f"[ChunkService] 关联上下文失败,error: {e}" + logging.exception(err) + continue + for related_chunk_entity in sub_related_chunk_entities: + token_sum += related_chunk_entity.tokens + leave_tokens -= related_chunk_entity.tokens + if leave_tokens < 0: + leave_tokens = 0 + chunk_ids += [chunk_entity.id for chunk_entity in sub_related_chunk_entities] + related_chunk_entities += sub_related_chunk_entities + if token_sum >= tokens_limit: + break + return related_chunk_entities + # 补全文档信息 + + @staticmethod + async def enrich_doc_info_to_search_chunks(search_chunk_msg: SearchChunkMsg) -> None: + """补全文档信息到搜索分片结果中""" + doc_entities = await DocumentManager.list_document_by_doc_ids( + [doc_chunk.doc_id for doc_chunk in search_chunk_msg.doc_chunks]) + doc_map = {doc_entity.id: doc_entity for doc_entity in doc_entities} + for doc_chunk in search_chunk_msg.doc_chunks: + doc_entity = doc_map.get(doc_chunk.doc_id) + doc_chunk.doc_author = doc_entity.author_name if doc_entity else "" + doc_chunk.doc_created_at = doc_entity.created_time.strftime( + '%Y-%m-%d %H:%M') if doc_entity else "" + doc_chunk.doc_abstract = doc_entity.abstract if doc_entity else "" + doc_chunk.doc_extension = doc_entity.extension if doc_entity else "" + doc_chunk.doc_size = doc_entity.size if doc_entity else 0 + + @staticmethod async def search_chunks(user_sub: str, action: str, req: SearchChunkRequest) -> SearchChunkMsg: """根据查询条件搜索分片""" + search_chunk_msg = SearchChunkMsg(docChunks=[]) + kb_ids_after_validate = [] + for kb_id in req.kb_ids: + if kb_id == DEFAULT_KNOWLEDGE_BASE_ID or await KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, kb_id, action): + kb_ids_after_validate.append(kb_id) + else: + logging.error( + "[ChunkService] 用户没有权限访问该知识库,知识库ID: %s", str(kb_id)) + req.kb_ids = kb_ids_after_validate logging.error("[ChunkService] 搜索分片,查询条件: %s", req) chunk_entities = [] search_tasks = [] + st = time.time() for kb_id in req.kb_ids: search_task = ChunkService.search_chunks_from_kb( user_sub, action, req.search_method, kb_id, req.query, req.top_k, req.doc_ids, req.banned_ids, req.is_rerank) search_tasks.append(search_task) search_results = await asyncio.gather(*search_tasks) - for search_result in search_results: - chunk_entities += search_result + en = time.time() + if req.is_testing_scene: + search_chunk_msg.t_used_in_search = round(en - st, 3) + if req.is_rerank: + st = time.time() + for i in range(len(req.kb_ids)): + kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(req.kb_ids[i]) + search_results[i] = await ChunkService.rerank_chunks( + search_results[i], kb_entity.rerank_method, req.query, req.top_k) + 
chunk_entities += search_results[i] + en = time.time() + if req.is_testing_scene: + search_chunk_msg.t_used_in_rerank = round(en - st, 3) if len(chunk_entities) == 0: return SearchChunkMsg(docChunks=[]) if req.is_rerank: + st = time.time() chunk_indexs = await BaseSearcher.rerank(chunk_entities, None, req.query) chunk_entities = [chunk_entities[i] for i in chunk_indexs] + en = time.time() + if req.is_testing_scene: + search_chunk_msg.t_used_in_rerank = round( + (search_chunk_msg.t_used_in_rerank or 0) + (en - st), 3) chunk_entities = chunk_entities[:req.top_k] - chunk_ids = [chunk_entity.id for chunk_entity in chunk_entities] logging.error("[ChunkService] 搜索分片,查询结果数量: %s", len(chunk_entities)) if req.is_related_surrounding: # 关联上下文 - tokens_limit = req.tokens_limit - tokens_limit_every_chunk = tokens_limit // len( - chunk_entities) if len(chunk_entities) > 0 else tokens_limit - leave_tokens = 0 - related_chunk_entities = [] - token_sum = 0 - for chunk_entity in chunk_entities: - token_sum += chunk_entity.tokens - for chunk_entity in chunk_entities: - leave_tokens = tokens_limit_every_chunk+leave_tokens - try: - sub_related_chunk_entities = await BaseSearcher.related_surround_chunk(chunk_entity, leave_tokens-chunk_entity.tokens, chunk_ids) - except Exception as e: - leave_tokens += tokens_limit_every_chunk - err = f"[ChunkService] 关联上下文失败,error: {e}" - logging.exception(err) - continue - for related_chunk_entity in sub_related_chunk_entities: - token_sum += related_chunk_entity.tokens - leave_tokens -= related_chunk_entity.tokens - if leave_tokens < 0: - leave_tokens = 0 - chunk_ids += [chunk_entity.id for chunk_entity in sub_related_chunk_entities] - related_chunk_entities += sub_related_chunk_entities - if token_sum >= tokens_limit: - break - chunk_entities += related_chunk_entities - search_chunk_msg = SearchChunkMsg(docChunks=[]) + st = time.time() + chunk_entities_related = await ChunkService.relate_surrounding_chunks( + chunk_entities, req.tokens_limit) + chunk_entities += chunk_entities_related + en = time.time() + if req.is_testing_scene: + search_chunk_msg.t_used_in_surrounding_text_relation = round( + en - st, 3) if req.is_classify_by_doc: doc_chunks = await BaseSearcher.classify_by_doc_id(chunk_entities) - for doc_chunk in doc_chunks: - for chunk in doc_chunk.chunks: - if req.is_compress: - chunk.text = TokenTool.compress_tokens(chunk.text) - search_chunk_msg.doc_chunks.append(doc_chunk) + search_chunk_msg.doc_chunks = doc_chunks else: for chunk_entity in chunk_entities: chunk = await Convertor.convert_chunk_entity_to_chunk(chunk_entity) - if req.is_compress: - chunk.text = TokenTool.compress_tokens(chunk.text) dc = DocChunk(docId=chunk_entity.doc_id, docName=chunk_entity.doc_name, chunks=[chunk]) search_chunk_msg.doc_chunks.append(dc) - doc_entities = await DocumentManager.list_document_by_doc_ids( - [doc_chunk.doc_id for doc_chunk in search_chunk_msg.doc_chunks]) - doc_map = {doc_entity.id: doc_entity for doc_entity in doc_entities} - for doc_chunk in search_chunk_msg.doc_chunks: - doc_entity = doc_map.get(doc_chunk.doc_id) - doc_chunk.doc_author = doc_entity.author_name if doc_entity else "" - doc_chunk.doc_created_at = doc_entity.created_time.strftime( - '%Y-%m-%d %H:%M') if doc_entity else "" - doc_chunk.doc_abstract = doc_entity.abstract if doc_entity else "" - doc_chunk.doc_extension = doc_entity.extension if doc_entity else "" - doc_chunk.doc_size = doc_entity.size if doc_entity else 0 + if req.is_compress: + st = time.time() + for doc_chunk in 
search_chunk_msg.doc_chunks: + for chunk in doc_chunk.chunks: + chunk.text = await TokenTool.compress_tokens(chunk.text) + en = time.time() + if req.is_testing_scene: + search_chunk_msg.t_used_in_text_compression = round( + en - st, 3) + if req.is_testing_scene: + for doc_chunk in search_chunk_msg.doc_chunks: + for chunk in doc_chunk.chunks: + chunk.score = await TokenTool.cal_jac(req.query, chunk.text) + await ChunkService.enrich_doc_info_to_search_chunks(search_chunk_msg) return search_chunk_msg async def update_chunk_by_id(chunk_id: uuid.UUID, req: UpdateChunkRequest) -> uuid.UUID: diff --git a/data_chain/entities/request_data.py b/data_chain/entities/request_data.py index 3f1c1299..67a40ee8 100644 --- a/data_chain/entities/request_data.py +++ b/data_chain/entities/request_data.py @@ -263,6 +263,8 @@ class SearchChunkRequest(BaseModel): default=False, description="是否压缩", alias="isCompress") tokens_limit: int = Field( default=8192, description="token限制", alias="tokensLimit") + is_testing_scene: Optional[bool] = Field( + default=False, description="是否是评测场景", alias="isTestingScene") class ListDatasetRequest(BaseModel): diff --git a/data_chain/entities/response_data.py b/data_chain/entities/response_data.py index cb636eed..f8cdffe5 100644 --- a/data_chain/entities/response_data.py +++ b/data_chain/entities/response_data.py @@ -349,6 +349,8 @@ class Chunk(BaseModel): chunk_type: ChunkType = Field(description="分片类型", alias="chunkType") text: str = Field(description="分片文本") enabled: bool = Field(description="分片是否启用", alias="enabled") + score: Optional[float] = Field( + default=None, description="分片得分", alias="score") class ListChunkMsg(BaseModel): @@ -392,6 +394,14 @@ class SearchChunkMsg(BaseModel): """Post /chunk/search 数据结构""" doc_chunks: list[DocChunk] = Field( default=[], description="文档分片列表", alias="docChunks") + t_used_in_search: Optional[float] = Field( + default=None, description="搜索中使用的时间", alias="tUsedInSearch") + t_used_in_rerank: Optional[float] = Field( + default=None, description="重排序中使用的时间", alias="tUsedInRerank") + t_used_in_surrounding_text_relation: Optional[float] = Field( + default=None, description="上下文关系中使用的时间", alias="tUsedInSurroundingTextRelation") + t_used_in_text_compression: Optional[float] = Field( + default=None, description="文本压缩中使用的时间", alias="tUsedInTextCompression") class SearchChunkResponse(ResponseData): diff --git a/data_chain/parser/tools/token_tool.py b/data_chain/parser/tools/token_tool.py index 1d4765c5..b71ef981 100644 --- a/data_chain/parser/tools/token_tool.py +++ b/data_chain/parser/tools/token_tool.py @@ -447,6 +447,25 @@ class TokenTool: # 计算余弦距离 cosine_dist = 1 - cosine_similarity return cosine_dist + # 通过向量计算语义相似度 + + @staticmethod + async def cal_semantic_similarity(content1: str, content2: str) -> float: + """ + 计算语义相似度 + 参数: + content1:内容1 + content2:内容2 + """ + try: + vector1 = await Embedding.vectorize_embedding(content1) + vector2 = await Embedding.vectorize_embedding(content2) + similarity = TokenTool.cosine_distance_numpy(vector1, vector2) + return (1 - similarity) * 100 + except Exception as e: + err = f"[TokenTool] 计算语义相似度失败 {e}" + logging.exception("[TokenTool] %s", err) + return -1 @staticmethod async def cal_relevance(question: str, answer: str, llm: LLM, language: str) -> float: -- Gitee From 9b109fc44e6410f2b45b5611aad86bf579c2b197 Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 4 Nov 2025 11:26:44 +0800 Subject: [PATCH 4/4] fix bug --- data_chain/apps/app.py | 20 +- data_chain/apps/service/chunk_service.py | 4 + 
data_chain/apps/service/document_service.py | 22 ++- data_chain/apps/service/task_queue_service.py | 31 ++- data_chain/manager/chunk_manager.py | 90 +++------ data_chain/manager/document_manager.py | 97 +++------- data_chain/manager/task_queue_mamanger.py | 27 ++- data_chain/rag/doc2chunk_searcher.py | 4 +- ...ic_weighted_keyword_and_vector_searcher.py | 5 +- .../rag/dynamic_weighted_keyword_searcher.py | 3 +- data_chain/rag/enhanced_by_llm_searcher.py | 3 +- data_chain/stores/database/database.py | 181 +++++++++++++++--- 12 files changed, 305 insertions(+), 182 deletions(-) diff --git a/data_chain/apps/app.py b/data_chain/apps/app.py index d072df7d..f8cd6fa6 100644 --- a/data_chain/apps/app.py +++ b/data_chain/apps/app.py @@ -103,15 +103,33 @@ app.add_exception_handler(HTTPException, http_exception_handler) app.add_exception_handler(StarletteHTTPException, starlette_http_exception_handler) app.add_exception_handler(Exception, general_exception_handler) - +import time @app.on_event("startup") async def startup_event(): + st= time.time() await configure() + en = time.time() + logging.info(f"[App] configure completed, time used: {en - st:.2f} seconds") + st= time.time() await add_acitons() + en = time.time() + logging.info(f"[App] add_actions completed, time used: {en - st:.2f} seconds") + st= time.time() await TaskQueueService.init_task_queue() + en = time.time() + logging.info(f"[App] init_task_queue completed, time used: {en - st:.2f} seconds") + st= time.time() await add_knowledge_base() + en = time.time() + logging.info(f"[App] add_knowledge_base completed, time used: {en - st:.2f} seconds") + st= time.time() await add_document_type() + en = time.time() + logging.info(f"[App] add_document_type completed, time used: {en - st:.2f} seconds") + st= time.time() await init_path() + en = time.time() + logging.info(f"[App] init_path completed, time used: {en - st:.2f} seconds") scheduler.add_job(TaskQueueService.handle_tasks, 'interval', seconds=5) scheduler.start() diff --git a/data_chain/apps/service/chunk_service.py b/data_chain/apps/service/chunk_service.py index 1a64a2eb..300843b4 100644 --- a/data_chain/apps/service/chunk_service.py +++ b/data_chain/apps/service/chunk_service.py @@ -183,6 +183,9 @@ class ChunkService: en = time.time() if req.is_testing_scene: search_chunk_msg.t_used_in_rerank = round(en - st, 3) + else: + for result in search_results: + chunk_entities += result if len(chunk_entities) == 0: return SearchChunkMsg(docChunks=[]) if req.is_rerank: @@ -228,6 +231,7 @@ class ChunkService: for chunk in doc_chunk.chunks: chunk.score = await TokenTool.cal_jac(req.query, chunk.text) await ChunkService.enrich_doc_info_to_search_chunks(search_chunk_msg) + logging.error("f{search_chunk_msg}") return search_chunk_msg async def update_chunk_by_id(chunk_id: uuid.UUID, req: UpdateChunkRequest) -> uuid.UUID: diff --git a/data_chain/apps/service/document_service.py b/data_chain/apps/service/document_service.py index 7ec05556..8f95f190 100644 --- a/data_chain/apps/service/document_service.py +++ b/data_chain/apps/service/document_service.py @@ -1,4 +1,5 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+import time import aiofiles from fastapi import APIRouter, Depends, Query, Body, File, UploadFile import uuid @@ -67,7 +68,8 @@ class DocumentService: doc_ids = [doc_entity.id for doc_entity in doc_entities] task_entities = await TaskManager.list_current_tasks_by_op_ids(doc_ids) task_ids = [task_entity.id for task_entity in task_entities] - task_dict = {task_entity.op_id: task_entity for task_entity in task_entities} + task_dict = { + task_entity.op_id: task_entity for task_entity in task_entities} task_report_entities = await TaskReportManager.list_current_task_report_by_task_ids(task_ids) task_report_dict = {task_report_entity.task_id: task_report_entity for task_report_entity in task_report_entities} @@ -82,7 +84,8 @@ class DocumentService: task = await Convertor.convert_task_entity_to_task(task_entity, task_report) document.parse_task = task documents.append(document) - list_document_msg = ListDocumentMsg(total=total, documents=documents) + list_document_msg = ListDocumentMsg( + total=total, documents=documents) return list_document_msg except Exception as e: err = "列出文档失败" @@ -166,7 +169,8 @@ class DocumentService: continue doc_ids.append(doc_entity.id) task_entities = await TaskManager.list_current_tasks_by_op_ids(doc_ids) - task_dict = {task_entity.op_id: task_entity for task_entity in task_entities} + task_dict = { + task_entity.op_id: task_entity for task_entity in task_entities} doc_status_list = [] for doc_id in doc_ids: task_entity = task_dict.get(doc_id, None) @@ -313,6 +317,7 @@ class DocumentService: err = f"上传文档数量或大小超过限制, 知识库ID: {kb_id}, 上传文档数量: {doc_cnt}, 上传文档大小: {doc_sz}" logging.error("[DocumentService] %s", err) raise ValueError(err) + st = time.time() doc_entities = [] for doc in docs: try: @@ -356,7 +361,9 @@ class DocumentService: err = f"上传文档失败, 文档名: {doc.filename}, 错误信息: {e}" logging.error("[DocumentService] %s", err) continue + logging.error("[DocumentService] 上传文档总耗时: %.2f 秒", time.time() - st) index = 0 + st = time.time() while index < len(doc_entities): try: await DocumentManager.add_documents(doc_entities[index:index+1024]) @@ -365,10 +372,16 @@ class DocumentService: err = f"上传文档失败, 文档名: {doc_entity.name}, 错误信息: {e}" logging.error("[DocumentService] %s", err) continue + logging.error("[DocumentService] 入库文档总耗时: %.2f 秒", time.time() - st) + st = time.time() for doc_entity in doc_entities: await TaskQueueService.init_task(TaskType.DOC_PARSE.value, doc_entity.id) + logging.error("[DocumentService] 初始化任务总耗时: %.2f 秒", time.time() - st) doc_ids = [doc_entity.id for doc_entity in doc_entities] + st = time.time() await KnowledgeBaseManager.update_doc_cnt_and_doc_size(kb_id=kb_entity.id) + logging.error( + "[DocumentService] 更新知识库文档数量和大小总耗时: %.2f 秒", time.time() - st) return doc_ids @staticmethod @@ -455,7 +468,8 @@ class DocumentService: doc_entities = await DocumentManager.update_document_by_doc_ids( doc_ids, {"status": DocumentStatus.DELETED.value}) doc_ids = [doc_entity.id for doc_entity in doc_entities] - kb_ids = [doc_entity.kb_id for doc_entity in doc_entities if doc_entity.kb_id is not None] + kb_ids = [ + doc_entity.kb_id for doc_entity in doc_entities if doc_entity.kb_id is not None] kb_ids = list(set(kb_ids)) for kb_id in kb_ids: await KnowledgeBaseManager.update_doc_cnt_and_doc_size(kb_id=kb_id) diff --git a/data_chain/apps/service/task_queue_service.py b/data_chain/apps/service/task_queue_service.py index 042d0c3d..259cbac6 100644 --- a/data_chain/apps/service/task_queue_service.py +++ b/data_chain/apps/service/task_queue_service.py @@ -16,23 +16,45 @@ 
class TaskQueueService: @staticmethod async def init_task_queue(): + import time + st = time.time() task_entities = await TaskManager.list_task_by_task_status(TaskStatus.PENDING.value) + en = time.time() + logging.info(f"[TaskQueueService] 获取待处理任务耗时 {en-st} 秒") + st = time.time() task_entities += await TaskManager.list_task_by_task_status(TaskStatus.RUNNING.value) + en = time.time() + logging.info(f"[TaskQueueService] 获取运行中任务耗时 {en-st} 秒") for task_entity in task_entities: + # 将所有任务取消 + try: if task_entity.status == TaskStatus.RUNNING.value: + st = time.time() flag = await BaseWorker.reinit(task_entity.id) + en = time.time() + logging.info(f"[TaskQueueService] 重新初始化任务耗时 {en-st} 秒") if flag: - task = TaskQueueEntity(id=task_entity.id, status=TaskStatus.PENDING.value) - await TaskQueueManager.update_task_by_id(task_entity.id, task) + st = time.time() + await TaskQueueManager.update_task_by_id(task_entity.id, TaskStatus.PENDING) + en = time.time() else: + st = time.time() await BaseWorker.stop(task_entity.id) await TaskQueueManager.delete_task_by_id(task_entity.id) + en = time.time() else: + st = time.time() task = await TaskQueueManager.get_task_by_id(task_entity.id) + en = time.time() + logging.info(f"[TaskQueueService] 获取任务耗时 {en-st} 秒") if task is None: - task = TaskQueueEntity(id=task_entity.id, status=TaskStatus.PENDING.value) + st = time.time() + task = TaskQueueEntity( + id=task_entity.id, status=TaskStatus.PENDING.value) await TaskQueueManager.add_task(task) + en = time.time() + logging.info(f"[TaskQueueService] 添加任务耗时 {en-st} 秒") except Exception as e: warning = f"[TaskQueueService] 初始化任务失败 {e}" logging.warning(warning) @@ -103,8 +125,7 @@ class TaskQueueService: await TaskQueueManager.delete_task_by_id(task.id) continue if flag: - task.status = TaskStatus.PENDING.value - await TaskQueueManager.update_task_by_id(task.id, task) + await TaskQueueManager.update_task_by_id(task.id, TaskStatus.PENDING) else: await TaskQueueManager.delete_task_by_id(task.id) diff --git a/data_chain/manager/chunk_manager.py b/data_chain/manager/chunk_manager.py index 8cbb4f4a..7f83ee4d 100644 --- a/data_chain/manager/chunk_manager.py +++ b/data_chain/manager/chunk_manager.py @@ -1,5 +1,5 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
-from sqlalchemy import select, update, func, text, or_, and_, Float, literal_column, true +from sqlalchemy import select, update, func, text, or_, and_, bindparam from typing import List, Tuple, Dict, Optional import uuid from datetime import datetime @@ -246,7 +246,8 @@ class ChunkManager(): # 方案1:临时关闭全表扫描(优先推荐,简单有效,会话级生效,不影响其他查询) # 执行SET命令:关闭全表扫描后,数据库会优先选择可用索引(text_vector_index) await session.execute(text("SET enable_seqscan = off;")) - + # 增加searcher数量,提升向量检索性能(可选) + await session.execute(text("SET hnsw_ef_search = 1000;")) # 方案2(备选):使用OpenGauss查询计划hints(需确保数据库开启hints支持) # 在SELECT后添加 /*+ IndexScan(chunk text_vector_index) */ 强制索引扫描 # 若用方案2,需将下面SELECT行改为:SELECT /*+ IndexScan(chunk text_vector_index) */ @@ -259,12 +260,12 @@ class ChunkManager(): chunk.pre_id_in_parse_topology, chunk.parse_topology_type, chunk.global_offset, chunk.local_offset, chunk.enabled, chunk.status, chunk.created_time, chunk.updated_time, - chunk.text_vector <#> :vector AS similarity_score + chunk.text_vector <=> :vector AS similarity_score FROM chunk JOIN document ON document.id = chunk.doc_id WHERE {where_clause} - AND (chunk.text_vector <#> :vector) IS NOT NULL - AND (chunk.text_vector <#> :vector) = (chunk.text_vector <#> :vector) + AND (chunk.text_vector <=> :vector) IS NOT NULL + AND (chunk.text_vector <=> :vector) = (chunk.text_vector <=> :vector) ORDER BY similarity_score ASC NULLS LAST LIMIT :limit """ @@ -392,55 +393,35 @@ class ChunkManager(): @staticmethod async def get_top_k_chunk_by_kb_id_dynamic_weighted_keyword( - kb_id: uuid.UUID, keywords: List[str], weights: List[float], + kb_id: uuid.UUID, query: str, # 关键词列表改为单查询文本 top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = [], chunk_to_type: str = None, pre_ids: list[uuid.UUID] = None) -> List[ChunkEntity]: - """根据知识库ID和关键词权重查询文档解析结果(使用BM25打分,修复CTE定义顺序)""" + """根据知识库ID和查询文本查询文档解析结果(使用BM25直接打分)""" try: st = datetime.now() async with await DataBase.get_session() as session: - # 1. 构建加权关键词CTE(提前定义,确保后续能引用) - params = {} - values_clause = [] - for idx, (term, weight) in enumerate(zip(keywords, weights)): - params[f"term_{idx}"] = term - params[f"weight_{idx}"] = weight - values_clause.append( - f"(CAST(:term_{idx} AS TEXT), CAST(:weight_{idx} AS FLOAT8))") - values_text = f"(VALUES {', '.join(values_clause)}) AS t(term, weight)" - weighted_terms = ( - select( - literal_column("t.term").label("term"), - literal_column("t.weight").cast(Float).label("weight") - ) - .select_from(text(values_text)) - .cte("weighted_terms") # 定义为CTE,供后续查询使用 - ) + # 1. 构建查询文本参数(单文本,无需CTE列表) + params = {"query": query} - # 2. 初始化查询(此时weighted_terms已定义,可正常引用) + # 2. 初始化查询(直接使用查询文本计算BM25分数) + # 使用bindparam定义参数,避免混合使用占位符和美元符号引用 + query_param = bindparam("query") stmt = ( select( ChunkEntity, - func.sum( - # 使用BM25运算符<&>计算相似度,乘以权重后求和 - ChunkEntity.text.op('<&>')(weighted_terms.c.term) - * weighted_terms.c.weight - ).label("similarity_score") + # 计算查询文本与chunk的BM25分数 + ChunkEntity.text.op('<&>')(query_param).label("similarity_score") ) # 关联文档表 .join(DocumentEntity, DocumentEntity.id == ChunkEntity.doc_id) - # 关联加权关键词CTE(通过BM25匹配有分数的结果) - .join( - weighted_terms, - ChunkEntity.text.op('<&>')(weighted_terms.c.term) > 0, - isouter=False - ) # 基础过滤条件 .where(DocumentEntity.enabled == True) .where(DocumentEntity.status != DocumentStatus.DELETED.value) .where(ChunkEntity.kb_id == kb_id) .where(ChunkEntity.enabled == True) .where(ChunkEntity.status != ChunkStatus.DELETED.value) + # 过滤BM25分数大于0的结果(确保有相关性) + .where(ChunkEntity.text.op('<&>')(query_param) > 0) ) # 3. 
动态条件:禁用ID @@ -451,52 +432,37 @@ class ChunkManager(): if doc_ids is not None: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) if chunk_to_type is not None: - stmt = stmt.where(ChunkEntity.parse_topology_type == chunk_to_type) + stmt = stmt.where( + ChunkEntity.parse_topology_type == chunk_to_type) if pre_ids is not None: - stmt = stmt.where(ChunkEntity.pre_id_in_parse_topology.in_(pre_ids)) + stmt = stmt.where( + ChunkEntity.pre_id_in_parse_topology.in_(pre_ids)) - # 5. 分组、过滤、排序、限制(使用BM25加权总分) + # 5. 排序、限制(直接使用BM25分数排序) stmt = (stmt - .group_by(ChunkEntity.id) - .having( - func.sum( - ChunkEntity.text.op('<&>')(weighted_terms.c.term) - * weighted_terms.c.weight - ) > 0 - ) .order_by( - func.sum( - ChunkEntity.text.op('<&>')(weighted_terms.c.term) - * weighted_terms.c.weight - ).desc() + ChunkEntity.text.op('<&>')(query_param).desc() ) .limit(top_k) ) - # 6. 强制使用BM25索引(关键修正:匹配实际索引名+正确Hint语法) - if config['DATABASE_TYPE'].lower() == 'opengauss': - # 实际BM25索引名:text_bm25_index;表名:chunk(从\d+结果确认) - table_name = ChunkEntity.__tablename__ # 若映射正确,应为'chunk' - bm25_index_name = "text_bm25_index" # 与\d+结果中的索引名完全一致 - stmt = stmt.prefix_with(f"/*+ INDEX({table_name} {bm25_index_name}) */") - - # 7. 执行查询与结果处理 + # 6. 执行查询与结果处理 result = await session.execute(stmt, params=params) chunk_entities = result.scalars().all() - # 8. 日志输出 + # 7. 日志输出 cost = (datetime.now() - st).total_seconds() logging.warning( f"[ChunkManager] BM25查询耗时: {cost}s " - f"| kb_id: {kb_id} | keywords: {keywords[:2]}... | 匹配数量: {len(chunk_entities)}" + f"| kb_id: {kb_id} | query: {query[:50]}... | 匹配数量: {len(chunk_entities)}" ) return chunk_entities except Exception as e: - err = f"BM25查询失败: kb_id={kb_id}, keywords={keywords[:2]}..., error={str(e)[:150]}" + err = f"BM25查询失败: kb_id={kb_id}, query={query[:50]}..., error={str(e)[:150]}" logging.exception("[ChunkManager] %s", err) return [] - + @staticmethod async def fetch_surrounding_chunk_by_doc_id_and_global_offset( doc_id: uuid.UUID, global_offset: int, diff --git a/data_chain/manager/document_manager.py b/data_chain/manager/document_manager.py index 1ab38966..205ca937 100644 --- a/data_chain/manager/document_manager.py +++ b/data_chain/manager/document_manager.py @@ -1,5 +1,5 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
-from sqlalchemy import select, delete, update, func, between, asc, desc, and_, Float, literal_column, text +from sqlalchemy import select, delete, update, func, between, asc, desc, and_, Float, literal_column, text, bindparam from datetime import datetime, timezone import uuid from typing import Dict, List, Tuple @@ -105,15 +105,17 @@ class DocumentManager(): document.chunk_size, document.type_id, document.enabled, document.status, document.full_text, document.abstract, document.abstract_vector, document.created_time, document.updated_time, - document.abstract_vector <#> :vector AS similarity_score + document.abstract_vector <=> :vector AS similarity_score FROM document WHERE {where_clause} - AND (document.abstract_vector <#> :vector) IS NOT NULL - AND (document.abstract_vector <#> :vector) = (document.abstract_vector <#> :vector) + AND (document.abstract_vector <=> :vector) IS NOT NULL + AND (document.abstract_vector <=> :vector) = (document.abstract_vector <=> :vector) ORDER BY similarity_score ASC NULLS LAST LIMIT :limit """ + # 增加searcher数量,提升向量检索性能(可选) async with await DataBase.get_session() as session: + await session.execute(text("SET hnsw_ef_search = 1000;")) params["limit"] = top_k result = await session.execute(text(base_sql), params) rows = result.fetchall() @@ -208,107 +210,66 @@ class DocumentManager(): err = "获取前K个文档失败" logging.exception("[DocumentManager] %s", err) raise e + @staticmethod async def get_top_k_document_by_kb_id_dynamic_weighted_keyword( - kb_id: uuid.UUID, keywords: List[str], weights: List[float], - top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = []) -> List[DocumentEntity]: - """根据知识库ID和关键词权重查询文档(BM25检索版,匹配abstract_bm25_index索引)""" + kb_id: uuid.UUID, query: str, # 关键词列表改为单查询文本,移除weights参数 + top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = []) -> List[DocumentEntity]: + """根据知识库ID和查询文本查询文档(BM25检索版,匹配abstract_bm25_index索引)""" try: st = datetime.now() async with await DataBase.get_session() as session: - # 1. 构建加权关键词CTE(保留原权重逻辑,移除分词器依赖) - params = {} - values_clause = [] - for idx, (term, weight) in enumerate(zip(keywords, weights)): - params[f"term_{idx}"] = term - params[f"weight_{idx}"] = weight - values_clause.append( - f"(CAST(:term_{idx} AS TEXT), CAST(:weight_{idx} AS FLOAT8))") - values_text = f"(VALUES {', '.join(values_clause)}) AS t(term, weight)" - weighted_terms = ( - select( - literal_column("t.term").label("term"), - literal_column("t.weight").cast(Float).label("weight") - ) - .select_from(text(values_text)) - .cte("weighted_terms") - ) + # 1. 构建查询文本参数(单文本,无需CTE列表) + params = {"query": query} - # 2. 初始化查询(替换为BM25评分,关联BM25索引字段abstract) + # 2. 初始化查询(直接使用查询文本计算BM25分数) + query_param = bindparam("query") stmt = ( select( DocumentEntity, - func.sum( - # BM25相似度计算:abstract <&> 关键词,再乘以关键词权重 - DocumentEntity.abstract.op('<&>')(weighted_terms.c.term) - * weighted_terms.c.weight - ).label("similarity_score") + # 计算查询文本与文档abstract的BM25分数 + DocumentEntity.abstract.op('<&>')( + query_param).label("similarity_score") ) - # 关联加权关键词CTE:仅保留BM25有匹配的文档(避免无效计算) - .join( - weighted_terms, - # 条件:BM25相似度结果非空(确保触发BM25索引) - DocumentEntity.abstract.op('<&>')(weighted_terms.c.term).isnot(None), - isouter=False - ) - # 基础过滤条件(与原逻辑一致) + # 基础过滤条件 .where(DocumentEntity.enabled == True) .where(DocumentEntity.status != DocumentStatus.DELETED.value) .where(DocumentEntity.kb_id == kb_id) + # 过滤BM25分数大于0的结果(确保有相关性) + .where(DocumentEntity.abstract.op('<&>')(query_param) > 0) ) - # 3. 动态条件:禁用ID(保持原逻辑) + # 3. 
动态条件:禁用ID if banned_ids: stmt = stmt.where(DocumentEntity.id.notin_(banned_ids)) - # 4. 其他动态条件:指定文档ID过滤(保持原逻辑) + # 4. 其他动态条件:指定文档ID过滤 if doc_ids is not None: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) - # 5. 分组、过滤、排序、限制(基于BM25加权总分) + # 5. 排序、限制(直接使用BM25分数排序) stmt = (stmt - .group_by(DocumentEntity.id) # 按文档ID分组,计算单文档总相似度 - .having( - # 过滤总相似度>0的文档(排除无匹配的结果) - func.sum( - DocumentEntity.abstract.op('<&>')(weighted_terms.c.term) - * weighted_terms.c.weight - ) > 0 - ) .order_by( - # 按BM25总相似度降序,取Top-K - func.sum( - DocumentEntity.abstract.op('<&>')(weighted_terms.c.term) - * weighted_terms.c.weight - ).desc() + DocumentEntity.abstract.op( + '<&>')(query_param).desc() ) .limit(top_k) ) - # 6. 强制使用BM25索引(关键:匹配实际索引名abstract_bm25_index) - if config['DATABASE_TYPE'].lower() == 'opengauss': - # 表名:DocumentEntity对应的表名(通常为'document',需确保ORM映射正确) - table_name = DocumentEntity.__tablename__ - # 实际BM25索引名:从__table_args__确认是abstract_bm25_index - bm25_index_name = "abstract_bm25_index" - # 使用openGauss推荐的Hint语法,强制索引扫描 - stmt = stmt.prefix_with(f"/*+ INDEX({table_name} {bm25_index_name}) */") - - # 7. 执行查询与结果处理(保持原逻辑) + # 6. 执行查询与结果处理 result = await session.execute(stmt, params=params) doc_entities = result.scalars().all() - # 8. 日志输出(补充BM25标识,便于排查) + # 7. 日志输出 cost = (datetime.now() - st).total_seconds() logging.warning( f"[DocumentManager] BM25检索文档耗时: {cost}s " - f"| kb_id: {kb_id} | keywords: {keywords[:2]}... | 匹配数量: {len(doc_entities)}" + f"| kb_id: {kb_id} | query: {query[:50]}... | 匹配数量: {len(doc_entities)}" ) return doc_entities except Exception as e: - # 异常日志补充BM25检索上下文 - err = f"BM25检索文档失败: kb_id={kb_id}, keywords={keywords[:2]}..., error={str(e)[:150]}" + err = f"BM25检索文档失败: kb_id={kb_id}, query={query[:50]}..., error={str(e)[:150]}" logging.exception("[DocumentManager] %s", err) return [] diff --git a/data_chain/manager/task_queue_mamanger.py b/data_chain/manager/task_queue_mamanger.py index b0df886b..2d1c33f0 100644 --- a/data_chain/manager/task_queue_mamanger.py +++ b/data_chain/manager/task_queue_mamanger.py @@ -27,7 +27,8 @@ class TaskQueueManager(): """根据任务ID删除任务""" try: async with await DataBase.get_session() as session: - stmt = delete(TaskQueueEntity).where(TaskQueueEntity.id == task_id) + stmt = delete(TaskQueueEntity).where( + TaskQueueEntity.id == task_id) await session.execute(stmt) await session.commit() except Exception as e: @@ -58,7 +59,8 @@ class TaskQueueManager(): """根据任务ID获取任务""" try: async with await DataBase.get_session() as session: - stmt = select(TaskQueueEntity).where(TaskQueueEntity.id == task_id) + stmt = select(TaskQueueEntity).where( + TaskQueueEntity.id == task_id) result = await session.execute(stmt) return result.scalars().first() except Exception as e: @@ -67,14 +69,14 @@ class TaskQueueManager(): raise e @staticmethod - async def update_task_by_id(task_id: uuid.UUID, task: TaskQueueEntity): + async def update_task_by_id(task_id: uuid.UUID, status: TaskStatus): """根据任务ID更新任务""" try: async with await DataBase.get_session() as session: stmt = ( update(TaskQueueEntity) .where(TaskQueueEntity.id == task_id) - .values(status=task.status) + .values(status=status.value) ) await session.execute(stmt) await session.commit() @@ -82,3 +84,20 @@ class TaskQueueManager(): err = "更新任务失败" logging.exception("[TaskQueueManager] %s", err) raise e + + @staticmethod + async def update_task_by_ids(task_ids: List[uuid.UUID], status: TaskStatus): + """根据任务ID列表批量更新任务状态""" + try: + async with await DataBase.get_session() as session: + stmt = ( + update(TaskQueueEntity) + 
.where(TaskQueueEntity.id.in_(task_ids)) + .values(status=status.value) + ) + await session.execute(stmt) + await session.commit() + except Exception as e: + err = "批量更新任务状态失败" + logging.exception("[TaskQueueManager] %s", err) + raise e diff --git a/data_chain/rag/doc2chunk_searcher.py b/data_chain/rag/doc2chunk_searcher.py index 6b3aca75..774048d5 100644 --- a/data_chain/rag/doc2chunk_searcher.py +++ b/data_chain/rag/doc2chunk_searcher.py @@ -46,9 +46,7 @@ class Doc2ChunkSearcher(BaseSearcher): use_doc_ids += [doc_entity.id for doc_entity in doc_entities_vector] chunk_entities_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_keyword(kb_id, query, top_k//3, use_doc_ids, banned_ids) banned_ids += [chunk_entity.id for chunk_entity in chunk_entities_keyword] - keywords, weights = TokenTool.get_top_k_keywords_and_weights(query) - logging.error(f"[KeywordVectorSearcher] keywords: {keywords}, weights: {weights}") - chunk_entities_get_by_dynamic_weighted_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, keywords, weights, top_k//2, use_doc_ids, banned_ids) + chunk_entities_get_by_dynamic_weighted_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, query, top_k//2, use_doc_ids, banned_ids) banned_ids += [chunk_entity.id for chunk_entity in chunk_entities_get_by_dynamic_weighted_keyword] chunk_entities_vector = [] for _ in range(3): diff --git a/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py b/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py index 500288ef..aed0ffcc 100644 --- a/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py +++ b/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py @@ -34,12 +34,9 @@ class DynamicKeywordVectorSearcher(BaseSearcher): chunk_entities_get_by_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_keyword( kb_id, query, max(top_k//3, 1), doc_ids, banned_ids) banned_ids += [chunk_entity.id for chunk_entity in chunk_entities_get_by_keyword] - keywords, weights = TokenTool.get_top_k_keywords_and_weights(query) - logging.error( - f"[DynamicKeywordVectorSearcher] keywords: {keywords}, weights: {weights}") import time start_time = time.time() - chunk_entities_get_by_dynamic_weighted_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, keywords, weights, top_k//2, doc_ids, banned_ids) + chunk_entities_get_by_dynamic_weighted_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, query, top_k//2, doc_ids, banned_ids) end_time = time.time() logging.info( f"[DynamicKeywordVectorSearcher] 动态关键字检索成功完成,耗时: {end_time - start_time:.2f}秒") diff --git a/data_chain/rag/dynamic_weighted_keyword_searcher.py b/data_chain/rag/dynamic_weighted_keyword_searcher.py index 53cbb962..7860ae36 100644 --- a/data_chain/rag/dynamic_weighted_keyword_searcher.py +++ b/data_chain/rag/dynamic_weighted_keyword_searcher.py @@ -28,8 +28,7 @@ class DynamicWeightKeyWordSearcher(BaseSearcher): :return: 检索结果 """ try: - keywords, weights = TokenTool.get_top_k_keywords_and_weights(query) - chunk_entities = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, keywords, weights, top_k, doc_ids, banned_ids) + chunk_entities = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, query, top_k, doc_ids, banned_ids) except Exception as e: err = f"[KeywordVectorSearcher] 关键词向量检索失败,error: {e}" logging.exception(err) diff --git a/data_chain/rag/enhanced_by_llm_searcher.py 
b/data_chain/rag/enhanced_by_llm_searcher.py index 738eaac9..b27cf729 100644 --- a/data_chain/rag/enhanced_by_llm_searcher.py +++ b/data_chain/rag/enhanced_by_llm_searcher.py @@ -49,11 +49,10 @@ class EnhancedByLLMSearcher(BaseSearcher): model_name=config['MODEL_NAME'], max_tokens=config['MAX_TOKENS'], ) - keywords, weights = TokenTool.get_top_k_keywords_and_weights(query) while len(chunk_entities) < top_k and rd < max_retry: rd += 1 sub_chunk_entities_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword( - kb_id, keywords, weights, top_k, doc_ids, banned_ids) + kb_id, query, top_k, doc_ids, banned_ids) chunk_ids = [chunk_entity.id for chunk_entity in sub_chunk_entities_keyword] banned_ids += chunk_ids sub_chunk_entities_vector = [] diff --git a/data_chain/stores/database/database.py b/data_chain/stores/database/database.py index daae2620..429a862e 100644 --- a/data_chain/stores/database/database.py +++ b/data_chain/stores/database/database.py @@ -62,13 +62,19 @@ class TeamEntity(Base): server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('team_id_index', id), + Index('team_name_index', name), + Index('team_author_id_index', author_id) + ) class TeamMessageEntity(Base): __tablename__ = 'team_message' id = Column(UUID, default=uuid4, primary_key=True) - team_id = Column(UUID, ForeignKey('team.id')) + team_id = Column(UUID) author_id = Column(Text) author_name = Column(Text) message_level = Column(Text) @@ -86,12 +92,19 @@ class TeamMessageEntity(Base): onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('team_message_id_index', id), + Index('team_message_team_id_index', team_id), + Index('team_message_author_id_index', author_id) + ) + class RoleEntity(Base): __tablename__ = 'role' id = Column(UUID, default=uuid4, primary_key=True) - team_id = Column(UUID, ForeignKey('team.id')) + team_id = Column(UUID) name = Column(Text) is_unique = Column(Boolean, default=False) editable = Column(Boolean, default=True) @@ -106,6 +119,12 @@ class RoleEntity(Base): server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('role_id_index', id), + Index('role_team_id_index', team_id), + Index('role_name_index', name) + ) class ActionEntity(Base): @@ -130,7 +149,7 @@ class RoleActionEntity(Base): __tablename__ = 'role_action' id = Column(UUID, default=uuid4, primary_key=True) - role_id = Column(UUID, ForeignKey('role.id', ondelete="CASCADE")) + role_id = Column(UUID) action = Column(Text) status = Column(Text, default=RoleActionStatus.EXISTED.value) created_time = Column( @@ -144,6 +163,13 @@ class RoleActionEntity(Base): onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('role_action_id_index', id), + Index('role_action_role_id_index', role_id), + Index('role_action_action_index', action) + ) + class UserEntity(Base): __tablename__ = 'users' @@ -162,6 +188,12 @@ class UserEntity(Base): onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('user_id_index', id), + Index('user_name_index', name) + ) + class UserMessageEntity(Base): __tablename__ = 'user_message' @@ -185,12 +217,19 @@ class UserMessageEntity(Base): server_default=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('user_message_id_index', id), + Index('user_message_sender_id_index', sender_id), + Index('user_message_receiver_id_index', receiver_id) + ) + class TeamUserEntity(Base): __tablename__ = 'team_user' id = Column(UUID, 
default=uuid4, primary_key=True) - team_id = Column(UUID, ForeignKey('team.id', ondelete="CASCADE")) # 团队id + team_id = Column(UUID) # 团队id user_id = Column(Text) # 用户id status = Column(Text, default=TeamUserStaus.EXISTED.value) # 用户在团队中的状态 created_time = Column( @@ -204,11 +243,18 @@ class TeamUserEntity(Base): onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('team_user_id_index', id), + Index('team_user_team_id_index', team_id), + Index('team_user_user_id_index', user_id) + ) + class UserRoleEntity(Base): __tablename__ = 'user_role' id = Column(UUID, default=uuid4, primary_key=True) - team_id = Column(UUID, ForeignKey('team.id', ondelete="CASCADE")) # 团队id + team_id = Column(UUID) # 团队id user_id = Column(Text) # 用户id role_id = Column(UUID) # 角色id status = Column(Text, default=UserRoleStatus.EXISTED.value) # 用户角色状态 @@ -223,13 +269,19 @@ class UserRoleEntity(Base): onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('user_role_id_index', id), + Index('user_role_team_id_index', team_id), + Index('user_role_user_id_index', user_id) + ) + class KnowledgeBaseEntity(Base): __tablename__ = 'knowledge_base' id = Column(UUID, default=uuid4, primary_key=True) - team_id = Column(UUID, ForeignKey( - 'team.id', ondelete="CASCADE"), nullable=True) # 团队id + team_id = Column(UUID, nullable=True) # 团队id author_id = Column(Text) # 作者id author_name = Column(Text) # 作者名称 name = Column(Text, default='') # 知识库名资产名 @@ -258,13 +310,19 @@ class KnowledgeBaseEntity(Base): onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('knowledge_base_id_index', id), + Index('knowledge_base_team_id_index', team_id), + Index('knowledge_base_name_index', name) + ) + class DocumentTypeEntity(Base): __tablename__ = 'document_type' id = Column(UUID, default=uuid4, primary_key=True) - kb_id = Column(UUID, ForeignKey('knowledge_base.id', - ondelete="CASCADE"), nullable=True) + kb_id = Column(UUID, nullable=True) name = Column(Text) created_time = Column( TIMESTAMP(timezone=True), @@ -277,14 +335,20 @@ class DocumentTypeEntity(Base): onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('document_type_id_index', id), + Index('document_type_kb_id_index', kb_id), + Index('document_type_name_index', name) + ) + class DocumentEntity(Base): __tablename__ = 'document' id = Column(UUID, default=uuid4, primary_key=True) team_id = Column(UUID) # 文档所属团队id - kb_id = Column(UUID, ForeignKey( - 'knowledge_base.id', ondelete="CASCADE")) # 文档所属资产id + kb_id = Column(UUID) # 文档所属资产id author_id = Column(Text) # 文档作者id author_name = Column(Text) # 文档作者名称 name = Column(Text) # 文档名 @@ -312,6 +376,12 @@ class DocumentEntity(Base): onupdate=func.current_timestamp() ) __table_args__ = ( + Index("document_id_index", id), + Index("document_team_id_index", team_id), + Index("document_kb_id_index", kb_id), + Index("document_author_id_index", author_id), + Index("document_author_name_index", author_name), + Index("document_name_index", name), Index('abstract_ts_vector_index', abstract_ts_vector, postgresql_using='gin'), Index( @@ -331,8 +401,7 @@ class ChunkEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) # chunk id team_id = Column(UUID) # 团队id kb_id = Column(UUID) # 知识库id - doc_id = Column(UUID, ForeignKey( - 'document.id', ondelete="CASCADE")) # 片段所属文档id + doc_id = Column(UUID) # 片段所属文档id doc_name = Column(Text) # 片段所属文档名称 text = Column(Text) # 片段文本内容 text_ts_vector = Column(TSVECTOR) # 片段文本词向量 @@ -358,6 +427,10 @@ class ChunkEntity(Base): 
server_default=func.current_timestamp(), onupdate=func.current_timestamp()) __table_args__ = ( + Index("chunk_id_index", id), + Index("chunk_team_id_index", team_id), + Index("chunk_kb_id_index", kb_id), + Index("chunk_doc_id_index", doc_id), Index('text_ts_vector_index', text_ts_vector, postgresql_using='gin'), Index( 'text_vector_index', @@ -389,14 +462,21 @@ class ImageEntity(Base): onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('image_id_index', id), + Index('image_team_id_index', team_id), + Index('image_doc_id_index', doc_id), + Index('image_chunk_id_index', chunk_id) + ) + class DataSetEntity(Base): __tablename__ = 'dataset' id = Column(UUID, default=uuid4, primary_key=True) # 数据集id team_id = Column(UUID) # 数据集所属团队id - kb_id = Column(UUID, ForeignKey('knowledge_base.id', - ondelete="CASCADE")) # 数据集所属资产id + kb_id = Column(UUID) # 数据集所属资产id author_id = Column(Text) # 数据的创建者id author_name = Column(Text) # 数据的创建者名称 llm_id = Column(Text) # 数据的生成使用的大模型的id @@ -419,13 +499,20 @@ class DataSetEntity(Base): onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('dataset_id_index', id), + Index('dataset_team_id_index', team_id), + Index('dataset_kb_id_index', kb_id), + Index('dataset_name_index', name) + ) + class DataSetDocEntity(Base): __tablename__ = 'dataset_doc' id = Column(UUID, default=uuid4, primary_key=True) # 数据集文档id - dataset_id = Column(UUID, ForeignKey( - 'dataset.id', ondelete="CASCADE")) # 数据集id + dataset_id = Column(UUID) # 数据集id doc_id = Column(UUID) # 文档id created_at = Column( TIMESTAMP(timezone=True), @@ -438,13 +525,19 @@ class DataSetDocEntity(Base): onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('dataset_doc_id_index', id), + Index('dataset_doc_dataset_id_index', dataset_id), + Index('dataset_doc_doc_id_index', doc_id) + ) + class QAEntity(Base): __tablename__ = 'qa' id = Column(UUID, default=uuid4, primary_key=True) # 数据id - dataset_id = Column(UUID, ForeignKey( - 'dataset.id', ondelete="CASCADE")) # 数据所属数据集id + dataset_id = Column(UUID) # 数据所属数据集id doc_id = Column(UUID) # 数据关联的文档id doc_name = Column(Text, default="未知文档") # 数据关联的文档名称 question = Column(Text) # 数据的问题 @@ -462,6 +555,12 @@ class QAEntity(Base): server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('qa_id_index', id), + Index('qa_dataset_id_index', dataset_id), + Index('qa_doc_id_index', doc_id) + ) class TestingEntity(Base): @@ -470,8 +569,7 @@ class TestingEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) # 测试任务的id team_id = Column(UUID) # 测试任务所属团队id kb_id = Column(UUID) # 测试任务所属资产id - dataset_id = Column(UUID, ForeignKey( - 'dataset.id', ondelete="CASCADE")) # 测试任务使用数据集的id + dataset_id = Column(UUID) # 测试任务使用数据集的id author_id = Column(Text) # 测试任务的创建者id author_name = Column(Text) # 测试任务的创建者名称 name = Column(Text) # 测试任务的名称 @@ -500,13 +598,20 @@ class TestingEntity(Base): onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('testing_id_index', id), + Index('testing_team_id_index', team_id), + Index('testing_kb_id_index', kb_id), + Index('testing_dataset_id_index', dataset_id) + ) + class TestCaseEntity(Base): __tablename__ = 'testcase' id = Column(UUID, default=uuid4, primary_key=True) # 测试case的id - testing_id = Column(UUID, ForeignKey( - 'testing.id', ondelete="CASCADE")) # 测试 + testing_id = Column(UUID) # 测试 question = Column(Text) # 数据的问题 answer = Column(Text) # 数据的答案 chunk = Column(Text) # 数据的片段 @@ -533,14 +638,19 @@ class 
TestCaseEntity(Base): onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('testcase_id_index', id), + Index('testcase_testing_id_index', testing_id) + ) + class TaskEntity(Base): __tablename__ = 'task' id = Column(UUID, default=uuid4, primary_key=True) team_id = Column(UUID) # 团队id - user_id = Column(Text, ForeignKey( - 'users.id', ondelete="CASCADE")) # 创建者id + user_id = Column(Text) # 创建者id op_id = Column(UUID) # 任务关联的实体id, 资产或者文档id op_name = Column(Text) # 任务关联的实体名称 type = Column(Text) # 任务类型 @@ -557,12 +667,22 @@ class TaskEntity(Base): onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('task_id_index', id), + Index('task_team_id_index', team_id), + Index('task_user_id_index', user_id), + Index('task_op_id_index', op_id), + Index('task_type_index', type), + Index('task_status_index', status) + ) + class TaskReportEntity(Base): __tablename__ = 'task_report' id = Column(UUID, default=uuid4, primary_key=True) # 任务报告的id - task_id = Column(UUID, ForeignKey('task.id', ondelete="CASCADE")) # 任务id + task_id = Column(UUID) # 任务id message = Column(Text) # 任务报告信息 current_stage = Column(BigInteger) # 任务当前阶段 stage_cnt = Column(BigInteger) # 任务总的阶段 @@ -577,6 +697,12 @@ class TaskReportEntity(Base): onupdate=func.current_timestamp() ) + # 添加索引 + __table_args__ = ( + Index('task_report_id_index', id), + Index('task_report_task_id_index', task_id) + ) + class TaskQueueEntity(Base): __tablename__ = 'task_queue' @@ -609,9 +735,10 @@ class DataBase: pool_size = os.cpu_count() if pool_size is None: pool_size = 5 + logging.error(f"Database pool size set to: {pool_size}") engine = create_async_engine( database_url, - echo=True, + echo=False, pool_recycle=300, pool_pre_ping=True, pool_size=pool_size -- Gitee
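
Note on the BM25 retrieval change above: the patch drops the weighted-keyword CTE and scores chunks by applying the `<&>` BM25 operator directly between the stored text and the raw query. Below is a minimal, self-contained sketch of that query shape, assuming an openGauss/PostgreSQL build that provides `<&>` and a BM25 index such as text_bm25_index; the Chunk model here is a cut-down stand-in, not the real ChunkEntity, and the sketch only builds and prints the statement rather than executing it.

    # Sketch only: the <&>-based ranking pattern used in the patch, rebuilt on a toy model.
    import uuid
    from sqlalchemy import Column, Text, Boolean, bindparam, select
    from sqlalchemy.dialects.postgresql import UUID
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class Chunk(Base):
        __tablename__ = "chunk"
        id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
        kb_id = Column(UUID(as_uuid=True))
        text = Column(Text)
        enabled = Column(Boolean, default=True)

    def build_bm25_stmt(kb_id: uuid.UUID, top_k: int):
        query_param = bindparam("query")
        # BM25 score of the stored text against the full query string.
        score = Chunk.text.op("<&>")(query_param).label("similarity_score")
        return (
            select(Chunk, score)
            .where(Chunk.kb_id == kb_id)
            .where(Chunk.enabled == True)  # noqa: E712
            .where(Chunk.text.op("<&>")(query_param) > 0)  # keep only matching rows
            .order_by(score.desc())
            .limit(top_k)
        )

    if __name__ == "__main__":
        # Printing shows the generated SQL with the <&> operator; executing it
        # requires a database that actually implements BM25 scoring.
        print(build_bm25_stmt(uuid.uuid4(), 10))

At call time the patch passes the session `params={"query": query}` so the single bind parameter replaces the former per-keyword term/weight parameters.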
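
Note on the vector retrieval change above: the distance operator moves from `<#>` (negative inner product) to `<=>` (cosine distance), and the session now widens the HNSW search beam before querying. A rough sketch of that call pattern is shown below, assuming an AsyncSession bound to a vector-enabled database; note that the underscore form `hnsw_ef_search` follows the patch's openGauss path, while stock pgvector spells the setting `hnsw.ef_search`, so the SET statement may need adjusting per backend.

    # Sketch only: cosine-distance top-k with a session-level ef_search bump.
    from sqlalchemy import text

    SET_EF_SEARCH = text("SET hnsw_ef_search = 1000;")  # openGauss form, as in the patch

    TOP_K_BY_COSINE = text(
        """
        SELECT chunk.id,
               chunk.text,
               chunk.text_vector <=> :vector AS similarity_score
        FROM chunk
        WHERE chunk.kb_id = :kb_id
          AND (chunk.text_vector <=> :vector) IS NOT NULL
        ORDER BY similarity_score ASC NULLS LAST
        LIMIT :limit
        """
    )

    async def top_k_chunks(session, kb_id, vector, limit=10):
        # session: an AsyncSession; vector is serialized to its string form,
        # mirroring the str(vector) handling already present in the manager code.
        await session.execute(SET_EF_SEARCH)
        result = await session.execute(
            TOP_K_BY_COSINE, {"kb_id": kb_id, "vector": str(vector), "limit": limit}
        )
        return result.fetchall()

Because `<=>` returns a distance, smaller is better, which is why the ORDER BY stays ascending with NULLS LAST.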
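
Note on the schema change above: the patch replaces ForeignKey-constrained columns (team_id, kb_id, doc_id, ...) with plain UUID columns and adds explicit per-column indexes in __table_args__, so the database no longer enforces referential integrity and lookups rely on the new indexes instead. A small sketch for checking what those index declarations emit as DDL is below; `Team` is a stand-in for the patch's TeamEntity with the same index names, and the id index is flagged because primary keys already carry an implicit unique index.

    # Sketch only: inspect the DDL produced by __table_args__ index declarations.
    from sqlalchemy import Column, Index, Text
    from sqlalchemy.dialects.postgresql import UUID
    from sqlalchemy.orm import declarative_base
    from sqlalchemy.schema import CreateIndex

    Base = declarative_base()

    class Team(Base):
        __tablename__ = "team"
        id = Column(UUID(as_uuid=True), primary_key=True)
        name = Column(Text)
        author_id = Column(Text)
        __table_args__ = (
            Index("team_id_index", id),  # likely redundant: the PK already has a unique index
            Index("team_name_index", name),
            Index("team_author_id_index", author_id),
        )

    if __name__ == "__main__":
        for idx in Team.__table__.indexes:
            print(CreateIndex(idx))

The same pattern applies to the other entities in the patch; on an existing deployment these indexes would still need a migration (or metadata.create_all on a fresh database) to take effect.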