diff --git a/Dockerfile b/Dockerfile
index 3e3de77938d02721e2cc28084dd76b37698b2aca..8f1bbc4f0c9135d2f8bd3c07fee315126e13a4f8 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,18 +1,14 @@
-FROM hub.oepkgs.net/neocopilot/data_chain_back_end_base:0.9.6-x86
+FROM hub.oepkgs.net/neocopilot/data_chain_back_end_base:0.10.0-x86
+
+COPY --chmod=750 ./data_chain /rag-service/data_chain
+COPY --chmod=750 ./chat2db /rag-service/chat2db
+COPY --chmod=750 ./run.sh /rag-service/
-COPY --chown=1001:1001 --chmod=750 ./ /rag-service/
 
 WORKDIR /rag-service
 
 ENV PYTHONPATH /rag-service
 
 USER root
 
-RUN sed -i 's/umask 002/umask 027/g' /etc/bashrc && \
-    sed -i 's/umask 022/umask 027/g' /etc/bashrc && \
-    # yum remove -y python3-pip gdb-gdbserver && \
-    sh -c "find /usr /etc \( -name *yum* -o -name *dnf* -o -name *vi* \) -exec rm -rf {} + || true" && \
-    sh -c "find /usr /etc \( -name ps -o -name top \) -exec rm -rf {} + || true" && \
-    sh -c "rm -f /usr/bin/find /usr/bin/oldfind || true"
-USER eulercopilot
 
 CMD ["/bin/bash", "run.sh"]
diff --git a/Dockerfile-base b/Dockerfile-base
index 8222781d900b063371cb8413c696ce050c444213..58fbae5fee55a7cc9d9a718d310bab589a9f1536 100644
--- a/Dockerfile-base
+++ b/Dockerfile-base
@@ -8,19 +8,15 @@ RUN sed -i 's|http://repo.openeuler.org/|https://mirrors.huaweicloud.com/openeul
     yum makecache &&\
     yum update -y &&\
     yum install -y mesa-libGL java python3 python3-pip shadow-utils &&\
-    yum clean all && \
-    groupadd -g 1001 eulercopilot && useradd -u 1001 -g eulercopilot eulercopilot
+    yum clean all
 
-# 创建 /rag-service 目录并设置权限
-RUN mkdir -p /rag-service && chown -R 1001:1001 /rag-service
-
-# 切换到 eulercopilot 用户
-USER eulercopilot
+# 创建 /rag-service
+RUN mkdir -p /rag-service
 
 # 复制 requirements.txt 文件到 /rag-service 目录
-COPY --chown=1001:1001 requirements.txt /rag-service/
-COPY --chown=1001:1001 tika-server-standard-2.9.2.jar /rag-service/
-COPY --chown=1001:1001 download_model.py /rag-service/
+COPY requirements.txt /rag-service/
+COPY tika-server-standard-2.9.2.jar /rag-service/
+COPY download_model.py /rag-service/
 
 # 安装 Python 依赖
 RUN pip3 install --no-cache-dir -r /rag-service/requirements.txt --index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
diff --git a/chat2db/.gitignore b/chat2db/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..3040d646d234ca4d2ec099a0378f110f2e7f3775
--- /dev/null
+++ b/chat2db/.gitignore
@@ -0,0 +1,2 @@
+__pycache__/
+.vscode/
\ No newline at end of file
diff --git a/chat2db/app/__init__.py b/chat2db/app/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/chat2db/app/app.py b/chat2db/app/app.py
deleted file mode 100644
index 71be6ed2b862a04647bfb08845a7424e8adf6041..0000000000000000000000000000000000000000
--- a/chat2db/app/app.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import uvicorn
-from fastapi import FastAPI
-import sys
-from chat2db.app.router import sql_example
-from chat2db.app.router import sql_generate
-from chat2db.app.router import database
-from chat2db.app.router import table
-from chat2db.config.config import config
-import logging
-
-
-logging.basicConfig(stream=sys.stdout, level=logging.INFO,
-                    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
-
-app = FastAPI()
-
-app.include_router(sql_example.router)
-app.include_router(sql_generate.router)
-app.include_router(database.router)
-app.include_router(table.router)
-
-if __name__ == '__main__':
-    try:
-        ssl_enable = config["SSL_ENABLE"]
-        if ssl_enable:
-            uvicorn.run(app, host=config["UVICORN_IP"], port=int(config["UVICORN_PORT"]),
-                        proxy_headers=True, forwarded_allow_ips='*',
-                        ssl_certfile=config["SSL_CERTFILE"],
-                        ssl_keyfile=config["SSL_KEYFILE"],
-                        )
-        else:
-            uvicorn.run(app, host=config["UVICORN_IP"], port=int(config["UVICORN_PORT"]),
-                        proxy_headers=True, forwarded_allow_ips='*'
-                        )
-    except Exception as e:
-        exit(1)
diff --git a/chat2db/app/base/ac_automation.py b/chat2db/app/base/ac_automation.py
deleted file mode 100644
index 3012f2bb73d599771f63ef0cd2617e3f43d73dbb..0000000000000000000000000000000000000000
--- a/chat2db/app/base/ac_automation.py
+++ /dev/null
@@ -1,87 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
-import copy
-import logging
-import sys
-
-class Node:
-    def __init__(self, dep, pre_id):
-        self.dep = dep
-        self.pre_id = pre_id
-        self.pre_nearest_children_id = {}
-        self.children_id = {}
-        self.data_frame = None
-
-
-logging.basicConfig(stream=sys.stdout, level=logging.INFO,
-                    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
-
-
-class DictTree:
-    def __init__(self):
-        self.root = 0
-        self.node_list = [Node(0, -1)]
-
-    def load_data(self, data_dict):
-        for key in data_dict:
-            self.insert_data(key, data_dict[key])
-        self.init_pre()
-
-    def insert_data(self, keyword, data_frame):
-        if not isinstance(keyword,str):
-            return
-        if len(keyword) == 0:
-            return
-        node_index = self.root
-        try:
-            for i in range(len(keyword)):
-                if keyword[i] not in self.node_list[node_index].children_id.keys():
-                    self.node_list.append(Node(self.node_list[node_index].dep+1, 0))
-                    self.node_list[node_index].children_id[keyword[i]] = len(self.node_list)-1
-                node_index = self.node_list[node_index].children_id[keyword[i]]
-        except Exception as e:
-            logging.error(f'关键字插入失败由于:{e}')
-            return
-        self.node_list[node_index].data_frame = data_frame
-
-    def init_pre(self):
-        q = [self.root]
-        l = 0
-        r = 1
-        try:
-            while l < r:
-                node_index = q[l]
-                self.node_list[node_index].pre_nearest_children_id = self.node_list[self.node_list[node_index].pre_id].children_id.copy()
-                l += 1
-                for key, val in self.node_list[node_index].children_id.items():
-                    q.append(val)
-                    r += 1
-                    if key in self.node_list[node_index].pre_nearest_children_id.keys():
-                        pre_id = self.node_list[node_index].pre_nearest_children_id[key]
-                        self.node_list[val].pre_id = pre_id
-                    self.node_list[node_index].pre_nearest_children_id[key] = val
-        except Exception as e:
-            logging.error(f'字典树前缀构建失败由于:{e}')
-            return
-
-    def get_results(self, content: str):
-        content = content.lower()
-        pre_node_index = self.root
-        nex_node_index = None
-        results = []
-        logging.info(f'当前问题{content}')
-        try:
-            for i in range(len(content)):
-                if content[i] in self.node_list[pre_node_index].pre_nearest_children_id.keys():
-                    nex_node_index = self.node_list[pre_node_index].pre_nearest_children_id[content[i]]
-                else:
-                    nex_node_index = 0
-                if self.node_list[pre_node_index].dep >= self.node_list[nex_node_index].dep:
-                    if self.node_list[pre_node_index].data_frame is not None:
-                        results.extend(copy.deepcopy(self.node_list[pre_node_index].data_frame))
-                pre_node_index = nex_node_index
-                logging.info(f'当前深度{self.node_list[pre_node_index].dep}')
-            if self.node_list[pre_node_index].data_frame is not None:
-                results.extend(copy.deepcopy(self.node_list[pre_node_index].data_frame))
-        except Exception as e:
-            logging.error(f'结果获取失败由于:{e}')
-        return results
diff --git a/chat2db/app/base/mysql.py b/chat2db/app/base/mysql.py
deleted file mode 100644
index b47322bc4dec4b254e86af04ea95d9f47a63457c..0000000000000000000000000000000000000000 --- a/chat2db/app/base/mysql.py +++ /dev/null @@ -1,217 +0,0 @@ - -import asyncio -import aiomysql -import concurrent.futures -import logging -from sqlalchemy.orm import sessionmaker -from sqlalchemy import create_engine, text -import sys -from concurrent.futures import ThreadPoolExecutor -from urllib.parse import urlparse -from chat2db.app.base.meta_databbase import MetaDatabase -logging.basicConfig(stream=sys.stdout, level=logging.INFO, - format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') - - -class Mysql(MetaDatabase): - executor = ThreadPoolExecutor(max_workers=10) - - async def test_database_connection(database_url): - try: - with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(Mysql._connect_and_query, database_url) - result = future.result(timeout=5) - return result - except concurrent.futures.TimeoutError: - logging.error('mysql数据库连接超时') - return False - except Exception as e: - logging.error(f'mysql数据库连接失败由于{e}') - return False - - @staticmethod - def _connect_and_query(database_url): - try: - engine = create_engine( - database_url, - pool_size=20, - max_overflow=80, - pool_recycle=300, - pool_pre_ping=True - ) - session = sessionmaker(bind=engine)() - session.execute(text("SELECT 1")) - session.close() - return True - except Exception as e: - raise e - - @staticmethod - async def drop_table(database_url, table_name): - engine = create_engine( - database_url, - pool_size=20, - max_overflow=80, - pool_recycle=300, - pool_pre_ping=True - ) - with sessionmaker(engine)() as session: - sql_str = f"DROP TABLE IF EXISTS {table_name};" - session.execute(text(sql_str)) - - @staticmethod - async def select_primary_key_and_keyword_from_table(database_url, table_name, keyword): - try: - url = urlparse(database_url) - db_config = { - 'host': url.hostname or 'localhost', - 'port': int(url.port or 3306), - 'user': url.username or 'root', - 'password': url.password or '', - 'db': url.path.strip('/') - } - - async with aiomysql.create_pool(**db_config) as pool: - async with pool.acquire() as conn: - async with conn.cursor() as cur: - primary_key_query = """ - SELECT - COLUMNS.column_name - FROM - information_schema.tables AS TABLES - INNER JOIN information_schema.columns AS COLUMNS ON TABLES.table_name = COLUMNS.table_name - WHERE - TABLES.table_schema = %s AND TABLES.table_name = %s AND COLUMNS.column_key = 'PRI'; - """ - - # 尝试执行查询 - await cur.execute(primary_key_query, (db_config['db'], table_name)) - primary_key_list = await cur.fetchall() - if not primary_key_list: - return [] - primary_key_names = ', '.join([record[0] for record in primary_key_list]) - columns = f'{primary_key_names}, {keyword}' - query = f'SELECT {columns} FROM {table_name};' - await cur.execute(query) - results = await cur.fetchall() - - def _process_results(results, primary_key_list): - tmp_dict = {} - for row in results: - key = str(row[-1]) - if key not in tmp_dict: - tmp_dict[key] = [] - pk_values = [str(row[i]) for i in range(len(primary_key_list))] - tmp_dict[key].append(pk_values) - - return { - 'primary_key_list': [record[0] for record in primary_key_list], - 'keyword_value_dict': tmp_dict - } - result = await asyncio.get_event_loop().run_in_executor( - Mysql.executor, - _process_results, - results, - primary_key_list - ) - return result - - except Exception as e: - logging.error(f'mysql数据检索失败由于 {e}') - - @staticmethod - async def 
assemble_sql_query_base_on_primary_key(table_name, primary_key_list, primary_key_value_list): - sql_str = f'SELECT * FROM {table_name} where ' - for i in range(len(primary_key_list)): - sql_str += primary_key_list[i]+'= \''+primary_key_value_list[i]+'\'' - if i != len(primary_key_list)-1: - sql_str += ' and ' - sql_str += ';' - return sql_str - - @staticmethod - async def get_table_info(database_url, table_name): - engine = create_engine( - database_url, - pool_size=20, - max_overflow=80, - pool_recycle=300, - pool_pre_ping=True - ) - with sessionmaker(engine)() as session: - sql_str = f"""SELECT TABLE_COMMENT FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '{table_name}';""" - table_note = session.execute(text(sql_str)).one()[0] - if table_note == '': - table_note = table_name - table_note = { - 'table_name': table_name, - 'table_note': table_note - } - return table_note - - @staticmethod - async def get_column_info(database_url, table_name): - engine = create_engine( - database_url, - pool_size=20, - max_overflow=80, - pool_recycle=300, - pool_pre_ping=True - ) - with engine.connect() as conn: - sql_str = f""" - SELECT column_name, column_type, column_comment FROM information_schema.columns where TABLE_NAME='{table_name}'; - """ - results = conn.execute(text(sql_str), {'table_name': table_name}).all() - column_info_list = [] - for result in results: - column_info_list.append({'column_name': result[0], 'column_type': result[1], 'column_note': result[2]}) - return column_info_list - - @staticmethod - async def get_all_table_name_from_database_url(database_url): - engine = create_engine( - database_url, - pool_size=20, - max_overflow=80, - pool_recycle=300, - pool_pre_ping=True - ) - with engine.connect() as connection: - result = connection.execute(text("SHOW TABLES")) - table_name_list = [row[0] for row in result] - return table_name_list - - @staticmethod - async def get_rand_data(database_url, table_name, cnt=10): - engine = create_engine( - database_url, - pool_size=20, - max_overflow=80, - pool_recycle=300, - pool_pre_ping=True - ) - try: - with sessionmaker(engine)() as session: - sql_str = f'''SELECT * - FROM {table_name} - ORDER BY RAND() - LIMIT {cnt};''' - dataframe = str(session.execute(text(sql_str)).all()) - except Exception as e: - dataframe = '' - logging.error(f'随机从数据库中获取数据失败由于{e}') - return dataframe - - @staticmethod - async def try_excute(database_url, sql_str): - engine = create_engine( - database_url, - pool_size=20, - max_overflow=80, - pool_recycle=300, - pool_pre_ping=True - ) - with sessionmaker(engine)() as session: - result = session.execute(text(sql_str)).all() - return Mysql.result_to_json(result) diff --git a/chat2db/app/base/postgres.py b/chat2db/app/base/postgres.py deleted file mode 100644 index a29a4427f34e689f632cbae79fd2337740ede824..0000000000000000000000000000000000000000 --- a/chat2db/app/base/postgres.py +++ /dev/null @@ -1,236 +0,0 @@ -import asyncio -import asyncpg -import concurrent.futures -import logging -from sqlalchemy.orm import sessionmaker -from sqlalchemy import create_engine, text -import sys -from concurrent.futures import ThreadPoolExecutor -from chat2db.app.base.meta_databbase import MetaDatabase -logging.basicConfig(stream=sys.stdout, level=logging.INFO, - format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') - - -def handler(signum, frame): - raise TimeoutError("超时") - - -class Postgres(MetaDatabase): - executor = ThreadPoolExecutor(max_workers=10) - - async def test_database_connection(database_url): - 
try: - with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(Postgres._connect_and_query, database_url) - result = future.result(timeout=5) - return result - except concurrent.futures.TimeoutError: - logging.error('postgres数据库连接超时') - return False - except Exception as e: - logging.error(f'postgres数据库连接失败由于{e}') - return False - - @staticmethod - def _connect_and_query(database_url): - try: - engine = create_engine( - database_url, - pool_size=20, - max_overflow=80, - pool_recycle=300, - pool_pre_ping=True - ) - session = sessionmaker(bind=engine)() - session.execute(text("SELECT 1")) - session.close() - return True - except Exception as e: - raise e - - @staticmethod - async def drop_table(database_url, table_name): - engine = create_engine( - database_url, - pool_size=20, - max_overflow=80, - pool_recycle=300, - pool_pre_ping=True - ) - with sessionmaker(engine)() as session: - sql_str = f"DROP TABLE IF EXISTS {table_name};" - session.execute(text(sql_str)) - - @staticmethod - async def select_primary_key_and_keyword_from_table(database_url, table_name, keyword): - try: - dsn = database_url.replace('+psycopg2', '') - conn = await asyncpg.connect(dsn=dsn) - primary_key_query = """ - SELECT - kcu.column_name - FROM - information_schema.table_constraints AS tc - JOIN information_schema.key_column_usage AS kcu - ON tc.constraint_name = kcu.constraint_name - WHERE - tc.constraint_type = 'PRIMARY KEY' - AND tc.table_name = $1; - """ - primary_key_list = await conn.fetch(primary_key_query, table_name) - if not primary_key_list: - return [] - columns = ', '.join([record['column_name'] for record in primary_key_list]) + f', {keyword}' - query = f'SELECT {columns} FROM {table_name};' - results = await conn.fetch(query) - - def _process_results(results, primary_key_list): - tmp_dict = {} - for row in results: - key = str(row[-1]) - if key not in tmp_dict: - tmp_dict[key] = [] - pk_values = [str(row[i]) for i in range(len(primary_key_list))] - tmp_dict[key].append(pk_values) - - return { - 'primary_key_list': [record['column_name'] for record in primary_key_list], - 'keyword_value_dict': tmp_dict - } - result = await asyncio.get_event_loop().run_in_executor( - Postgres.executor, - _process_results, - results, - primary_key_list - ) - await conn.close() - - return result - except Exception as e: - logging.error(f'postgres数据检索失败由于 {e}') - return None - - @staticmethod - async def assemble_sql_query_base_on_primary_key(table_name, primary_key_list, primary_key_value_list): - sql_str = f'SELECT * FROM {table_name} where ' - for i in range(len(primary_key_list)): - sql_str += primary_key_list[i]+'='+'\''+primary_key_value_list[i]+'\'' - if i != len(primary_key_list)-1: - sql_str += ' and ' - sql_str += ';' - return sql_str - - @staticmethod - async def get_table_info(database_url, table_name): - engine = create_engine( - database_url, - pool_size=20, - max_overflow=80, - pool_recycle=300, - pool_pre_ping=True - ) - with engine.connect() as conn: - sql_str = """ - SELECT - d.description AS table_description - FROM - pg_class t - JOIN - pg_description d ON t.oid = d.objoid - WHERE - t.relkind = 'r' AND - d.objsubid = 0 AND - t.relname = :table_name; """ - result = conn.execute(text(sql_str), {'table_name': table_name}).one_or_none() - if result is None: - table_note = table_name - else: - table_note = result[0] - table_note = { - 'table_name': table_name, - 'table_note': table_note - } - return table_note - - @staticmethod - async def 
get_column_info(database_url, table_name): - engine = create_engine( - database_url, - pool_size=20, - max_overflow=80, - pool_recycle=300, - pool_pre_ping=True - ) - with engine.connect() as conn: - sql_str = """ - SELECT - a.attname as 字段名, - format_type(a.atttypid,a.atttypmod) as 类型, - col_description(a.attrelid,a.attnum) as 注释 - FROM - pg_class as c,pg_attribute as a - where - a.attrelid = c.oid - and - a.attnum>0 - and - c.relname = :table_name; - """ - results = conn.execute(text(sql_str), {'table_name': table_name}).all() - column_info_list = [] - for result in results: - column_info_list.append({'column_name': result[0], 'column_type': result[1], 'column_note': result[2]}) - return column_info_list - - @staticmethod - async def get_all_table_name_from_database_url(database_url): - engine = create_engine( - database_url, - pool_size=20, - max_overflow=80, - pool_recycle=300, - pool_pre_ping=True - ) - with engine.connect() as connection: - sql_str = ''' - SELECT table_name - FROM information_schema.tables - WHERE table_schema = 'public'; - ''' - result = connection.execute(text(sql_str)) - table_name_list = [row[0] for row in result] - return table_name_list - - @staticmethod - async def get_rand_data(database_url, table_name, cnt=10): - engine = create_engine( - database_url, - pool_size=20, - max_overflow=80, - pool_recycle=300, - pool_pre_ping=True - ) - try: - with sessionmaker(engine)() as session: - sql_str = f'''SELECT * - FROM {table_name} - ORDER BY RANDOM() - LIMIT {cnt};''' - dataframe = str(session.execute(text(sql_str)).all()) - except Exception as e: - dataframe = '' - logging.error(f'随机从数据库中获取数据失败由于{e}') - return dataframe - - @staticmethod - async def try_excute(database_url, sql_str): - engine = create_engine( - database_url, - pool_size=20, - max_overflow=80, - pool_recycle=300, - pool_pre_ping=True - ) - with sessionmaker(engine)() as session: - result=session.execute(text(sql_str)).all() - return Postgres.result_to_json(result) diff --git a/chat2db/app/router/database.py b/chat2db/app/router/database.py deleted file mode 100644 index 37aacca406d2de39aaa56b983bf7d65b3d29f2e3..0000000000000000000000000000000000000000 --- a/chat2db/app/router/database.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
- -import logging -import uuid -from fastapi import APIRouter, status -from typing import Optional -import sys -from chat2db.model.request import DatabaseAddRequest, DatabaseDelRequest, DatabaseSqlGenerateRequest -from chat2db.model.response import ResponseData -from chat2db.manager.database_info_manager import DatabaseInfoManager -from chat2db.manager.table_info_manager import TableInfoManager -from chat2db.manager.column_info_manager import ColumnInfoManager -from chat2db.app.service.diff_database_service import DiffDatabaseService -from chat2db.app.service.sql_generate_service import SqlGenerateService -from chat2db.app.service.keyword_service import keyword_service -from chat2db.app.base.vectorize import Vectorize - -logging.basicConfig(stream=sys.stdout, level=logging.INFO, - format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') - -router = APIRouter( - prefix="/database" -) - - -@router.post("/add", response_model=ResponseData) -async def add_database_info(request: DatabaseAddRequest): - database_url = request.database_url - database_type = DiffDatabaseService.get_database_type_from_url(database_url) - if not DiffDatabaseService.is_database_type_allow(database_type): - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="不支持当前数据库", - result={} - ) - flag = await DiffDatabaseService.get_database_service(database_type).test_database_connection(database_url) - if not flag: - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="无法连接当前数据库", - result={} - ) - database_id = await DatabaseInfoManager.add_database(database_url) - if database_id is None: - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="数据库连接添加失败,当前存在重复数据库配置", - result={'database_id': database_id} - ) - return ResponseData( - code=status.HTTP_200_OK, - message="success", - result={'database_id': database_id} - ) - - -@router.post("/del", response_model=ResponseData) -async def del_database_info(request: DatabaseDelRequest): - database_id = request.database_id - database_url = request.database_url - if database_id: - flag = await DatabaseInfoManager.del_database_by_id(database_id) - else: - flag = await DatabaseInfoManager.del_database_by_url(database_url) - if not flag: - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="删除数据库配置失败,数据库配置不存在", - result={} - ) - return ResponseData( - code=status.HTTP_200_OK, - message="删除数据库配置成功", - result={} - ) - - -@router.get("/query", response_model=ResponseData) -async def query_database_info(): - database_info_list = await DatabaseInfoManager.get_all_database_info() - return ResponseData( - code=status.HTTP_200_OK, - message="查询数据库配置成功", - result={'database_info_list': database_info_list} - ) - - -@router.get("/list", response_model=ResponseData) -async def list_table_in_database(database_id: uuid.UUID, table_filter: str = ''): - database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) - database_type = DiffDatabaseService.get_database_type_from_url(database_url) - if database_url is None: - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="查询数据库内表格配置失败,数据库配置不存在", - result={} - ) - if not DiffDatabaseService.is_database_type_allow(database_type): - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="不支持当前数据库", - result={} - ) - flag = await DiffDatabaseService.get_database_service(database_type).test_database_connection(database_url) - if not flag: - return ResponseData( - 
code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="无法连接当前数据库", - result={} - ) - table_name_list = await DiffDatabaseService.get_database_service(database_type).get_all_table_name_from_database_url(database_url) - results = [] - for table_name in table_name_list: - if table_filter in table_name: - results.append(table_name) - return ResponseData( - code=status.HTTP_200_OK, - message="查询数据库配置成功", - result={'table_name_list': results} - ) - - -@router.post("/sql", response_model=ResponseData) -async def generate_sql_from_database(request: DatabaseSqlGenerateRequest): - database_url = request.database_url - table_name_list = request.table_name_list - question = request.question - use_llm_enhancements = request.use_llm_enhancements - database_type = DiffDatabaseService.get_database_type_from_url(database_url) - if not DiffDatabaseService.is_database_type_allow(database_type): - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="不支持当前数据库", - result={} - ) - flag = await DiffDatabaseService.get_database_service(database_type).test_database_connection(database_url) - if not flag: - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="无法连接当前数据库", - result={} - ) - tmp_table_name_list = await DiffDatabaseService.get_database_service(database_type).get_all_table_name_from_database_url(database_url) - database_id = await DatabaseInfoManager.get_database_id_by_url(database_url) - if database_id is None: - database_id = await DatabaseInfoManager.add_database(database_url) - for table_name in tmp_table_name_list: - try: - tmp_dict = await DiffDatabaseService.get_database_service(database_type).get_table_info(database_url, table_name) - table_note = tmp_dict['table_note'] - table_note_vector = await Vectorize.vectorize_embedding(table_note) - table_id = await TableInfoManager.add_table_info(database_id, table_name, table_note, table_note_vector) - column_info_list = await DiffDatabaseService.get_database_service(database_type).get_column_info(database_url, table_name) - for column_info in column_info_list: - await ColumnInfoManager.add_column_info_with_table_id( - table_id, column_info['column_name'], - column_info['column_type'], - column_info['column_note']) - except Exception as e: - import traceback - logging.error(f'{table_name}') - logging.error(f'表格信息获取失败由于:{traceback.format_exc()}') - continue - if table_name_list: - table_id_list = [] - for table_name in table_name_list: - table_id = await TableInfoManager.get_table_id_by_database_id_and_table_name(database_id, table_name) - if table_id is None: - continue - table_id_list.append(table_id) - else: - table_id_list = None - results = {} - sql_list = await SqlGenerateService.generate_sql_base_on_example( - database_id=database_id, question=question, table_id_list=table_id_list, - use_llm_enhancements=use_llm_enhancements) - try: - sql_list += await keyword_service.generate_sql(question, database_id, table_id_list) - results['sql_list'] = sql_list[:request.topk] - results['database_url'] = database_url - except Exception as e: - logging.error(f'sql生成失败由于{e}') - return ResponseData( - code=status.HTTP_400_BAD_REQUEST, - message="sql生成失败", - result={} - ) - return ResponseData( - code=status.HTTP_200_OK, message="success", - result=results - ) diff --git a/chat2db/app/router/sql_example.py b/chat2db/app/router/sql_example.py deleted file mode 100644 index 08f913912211a71646ebeb33bed571a46f95d1dc..0000000000000000000000000000000000000000 --- a/chat2db/app/router/sql_example.py +++ /dev/null @@ -1,137 
+0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. - -import logging -import uuid -from fastapi import APIRouter, status -import sys - -from chat2db.model.request import SqlExampleAddRequest, SqlExampleDelRequest, SqlExampleUpdateRequest, SqlExampleGenerateRequest -from chat2db.model.response import ResponseData -from chat2db.manager.database_info_manager import DatabaseInfoManager -from chat2db.manager.table_info_manager import TableInfoManager -from chat2db.manager.sql_example_manager import SqlExampleManager -from chat2db.app.service.sql_generate_service import SqlGenerateService -from chat2db.app.base.vectorize import Vectorize -logging.basicConfig(stream=sys.stdout, level=logging.INFO, - format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') - -router = APIRouter( - prefix="/sql/example" -) - - -@router.post("/add", response_model=ResponseData) -async def add_sql_example(request: SqlExampleAddRequest): - table_id = request.table_id - table_info = await TableInfoManager.get_table_info_by_table_id(table_id) - if table_info is None: - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="表格不存在", - result={} - ) - database_id = table_info['database_id'] - question = request.question - question_vector = await Vectorize.vectorize_embedding(question) - sql = request.sql - try: - sql_example_id = await SqlExampleManager.add_sql_example(question, sql, table_id, question_vector) - except Exception as e: - logging.error(f'sql案例添加失败由于{e}') - return ResponseData( - code=status.HTTP_400_BAD_REQUEST, - message="sql案例添加失败", - result={} - ) - return ResponseData( - code=status.HTTP_200_OK, - message="success", - result={'sql_example_id': sql_example_id} - ) - - -@router.post("/del", response_model=ResponseData) -async def del_sql_example(request: SqlExampleDelRequest): - sql_example_id = request.sql_example_id - flag = await SqlExampleManager.del_sql_example_by_id(sql_example_id) - if not flag: - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="sql案例不存在", - result={} - ) - return ResponseData( - code=status.HTTP_200_OK, - message="sql案例删除成功", - result={} - ) - - -@router.get("/query", response_model=ResponseData) -async def query_sql_example(table_id: uuid.UUID): - sql_example_list = await SqlExampleManager.query_sql_example_by_table_id(table_id) - return ResponseData( - code=status.HTTP_200_OK, - message="查询sql案例成功", - result={'sql_example_list': sql_example_list} - ) - - -@router.post("/update", response_model=ResponseData) -async def update_sql_example(request: SqlExampleUpdateRequest): - sql_example_id = request.sql_example_id - question = request.question - question_vector = await Vectorize.vectorize_embedding(question) - sql = request.sql - flag = await SqlExampleManager.update_sql_example_by_id(sql_example_id, question, sql, question_vector) - if not flag: - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="sql案例不存在", - result={} - ) - return ResponseData( - code=status.HTTP_200_OK, - message="sql案例更新成功", - result={} - ) - - -@router.post("/generate", response_model=ResponseData) -async def generate_sql_example(request: SqlExampleGenerateRequest): - table_id = request.table_id - generate_cnt = request.generate_cnt - table_info = await TableInfoManager.get_table_info_by_table_id(table_id) - if table_info is None: - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="表格不存在", - result={} - ) - table_name = table_info['table_name'] - 
database_id = table_info['database_id'] - database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) - sql_var = request.sql_var - sql_example_list = [] - for i in range(generate_cnt): - try: - tmp_dict = await SqlGenerateService.generate_sql_base_on_data(database_url, table_name, sql_var) - except Exception as e: - logging.error(f'sql案例生成失败由于{e}') - continue - if tmp_dict is None: - continue - question = tmp_dict['question'] - question_vector = await Vectorize.vectorize_embedding(question) - sql = tmp_dict['sql'] - await SqlExampleManager.add_sql_example(question, sql, table_id, question_vector) - tmp_dict['database_id'] = database_id - tmp_dict['table_id'] = table_id - sql_example_list.append(tmp_dict) - return ResponseData( - code=status.HTTP_200_OK, - message="sql案例生成成功", - result={ - 'sql_example_list': sql_example_list - } - ) diff --git a/chat2db/app/router/sql_generate.py b/chat2db/app/router/sql_generate.py deleted file mode 100644 index 69ff0d2bb0e7d114151d29c56994dac4b5754d1a..0000000000000000000000000000000000000000 --- a/chat2db/app/router/sql_generate.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. - -import logging -from fastapi import APIRouter, status -import sys - -from chat2db.manager.database_info_manager import DatabaseInfoManager -from chat2db.manager.table_info_manager import TableInfoManager -from chat2db.manager.column_info_manager import ColumnInfoManager -from chat2db.model.request import SqlGenerateRequest, SqlRepairRequest, SqlExcuteRequest -from chat2db.model.response import ResponseData -from chat2db.app.service.sql_generate_service import SqlGenerateService -from chat2db.app.service.keyword_service import keyword_service -from chat2db.app.service.diff_database_service import DiffDatabaseService -logging.basicConfig(stream=sys.stdout, level=logging.INFO, - format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') - -router = APIRouter( - prefix="/sql" -) - - -@router.post("/generate", response_model=ResponseData) -async def generate_sql(request: SqlGenerateRequest): - database_id = request.database_id - database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) - table_id_list = request.table_id_list - question = request.question - use_llm_enhancements = request.use_llm_enhancements - results = {} - sql_list = await SqlGenerateService.generate_sql_base_on_example( - database_id=database_id, question=question, table_id_list=table_id_list, - use_llm_enhancements=use_llm_enhancements) - try: - sql_list += await keyword_service.generate_sql(question, database_id, table_id_list) - results['sql_list'] = sql_list[:request.topk] - results['database_url'] = database_url - except Exception as e: - logging.error(f'sql生成失败由于{e}') - return ResponseData( - code=status.HTTP_400_BAD_REQUEST, - message="sql生成失败", - result={} - ) - return ResponseData( - code=status.HTTP_200_OK, message="success", - result=results - ) - - -@router.post("/repair", response_model=ResponseData) -async def repair_sql(request: SqlRepairRequest): - database_id = request.database_id - table_id = request.table_id - database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) - database_type = DiffDatabaseService.get_database_type_from_url(database_url) - if database_url is None: - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="当前数据库配置不存在", - result={} - ) - table_info = await TableInfoManager.get_table_info_by_table_id(table_id) - 
if table_info is None: - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="表格不存在", - result={} - ) - if table_info['database_id'] != database_id: - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="表格不属于当前数据库", - result={} - ) - column_info_list = await ColumnInfoManager.get_column_info_by_table_id(table_id) - sql = request.sql - message = request.message - question = request.question - try: - sql = await SqlGenerateService.repair_sql(database_type, table_info, column_info_list, sql, message, question) - except Exception as e: - logging.error(f'sql修复失败由于{e}') - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="sql修复失败", - result={} - ) - return ResponseData( - code=status.HTTP_200_OK, - message="sql修复成功", - result={'database_id': database_id, - 'table_id': table_id, - 'sql': sql} - ) - - -@router.post("/execute", response_model=ResponseData) -async def execute_sql(request: SqlExcuteRequest): - database_id = request.database_id - sql = request.sql - database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) - if database_url is None: - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="当前数据库配置不存在", - result={} - ) - database_type = DiffDatabaseService.get_database_type_from_url(database_url) - try: - results = await DiffDatabaseService.database_map[database_type].try_excute(database_url, sql) - except Exception as e: - import traceback - logging.error(f'sql执行失败由于{traceback.format_exc()}') - return ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="sql执行失败", - result={'Error': str(e)} - ) - return ResponseData( - code=status.HTTP_200_OK, - message="sql执行成功", - result=results - ) diff --git a/chat2db/app/router/table.py b/chat2db/app/router/table.py deleted file mode 100644 index 33ca4f9940bea6d3e60ee2adaf96be30a95d69ff..0000000000000000000000000000000000000000 --- a/chat2db/app/router/table.py +++ /dev/null @@ -1,147 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
- -import logging -import uuid -from fastapi import APIRouter, status -import sys - -from chat2db.model.request import TableAddRequest, TableDelRequest, EnableColumnRequest -from chat2db.model.response import ResponseData -from chat2db.manager.database_info_manager import DatabaseInfoManager -from chat2db.manager.table_info_manager import TableInfoManager -from chat2db.manager.column_info_manager import ColumnInfoManager -from chat2db.app.service.diff_database_service import DiffDatabaseService -from chat2db.app.base.vectorize import Vectorize -from chat2db.app.service.keyword_service import keyword_service -logging.basicConfig(stream=sys.stdout, level=logging.INFO, - format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') - -router = APIRouter( - prefix="/table" -) - - -@router.post("/add", response_model=ResponseData) -async def add_database_info(request: TableAddRequest): - database_id = request.database_id - database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) - if database_url is None: - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="当前数据库配置不存在", - result={} - ) - database_type = DiffDatabaseService.get_database_type_from_url(database_url) - flag = await DiffDatabaseService.get_database_service(database_type).test_database_connection(database_url) - if not flag: - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="无法连接当前数据库", - result={} - ) - table_name = request.table_name - table_name_list = await DiffDatabaseService.get_database_service(database_type).get_all_table_name_from_database_url(database_url) - if table_name not in table_name_list: - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="表格不存在", - result={} - ) - tmp_dict = await DiffDatabaseService.get_database_service(database_type).get_table_info(database_url, table_name) - table_note = tmp_dict['table_note'] - table_note_vector = await Vectorize.vectorize_embedding(table_note) - table_id = await TableInfoManager.add_table_info(database_id, table_name, table_note, table_note_vector) - if table_id is None: - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="表格添加失败,当前存在重复表格", - result={} - ) - column_info_list = await DiffDatabaseService.get_database_service(database_type).get_column_info(database_url, table_name) - for column_info in column_info_list: - await ColumnInfoManager.add_column_info_with_table_id( - table_id, column_info['column_name'], - column_info['column_type'], - column_info['column_note']) - return ResponseData( - code=status.HTTP_200_OK, - message="success", - result={'table_id': table_id} - ) - - -@router.post("/del", response_model=ResponseData) -async def del_table_info(request: TableDelRequest): - table_id = request.table_id - flag = await TableInfoManager.del_table_by_id(table_id) - if not flag: - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="表格不存在", - result={} - ) - return ResponseData( - code=status.HTTP_200_OK, - message="删除表格成功", - result={} - ) - - -@router.get("/query", response_model=ResponseData) -async def query_table_info(database_id: uuid.UUID): - database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) - if database_url is None: - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="当前数据库配置不存在", - result={} - ) - table_info_list = await TableInfoManager.get_table_info_by_database_id(database_id) - return ResponseData( - code=status.HTTP_200_OK, - 
message="查询表格成功", - result={'table_info_list': table_info_list} - ) - - -@router.get("/column/query", response_model=ResponseData) -async def query_column(table_id: uuid.UUID): - column_info_list = await ColumnInfoManager.get_column_info_by_table_id(table_id) - return ResponseData( - code=status.HTTP_200_OK, - message="", - result={'column_info_list': column_info_list} - ) - - -@router.post("/column/enable", response_model=ResponseData) -async def enable_column(request: EnableColumnRequest): - column_id = request.column_id - enable = request.enable - flag = await ColumnInfoManager.update_column_info_enable(column_id, enable) - if not flag: - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="列不存在", - result={} - ) - column_info = await ColumnInfoManager.get_column_info_by_column_id(column_id) - column_name = column_info['column_name'] - table_id = column_info['table_id'] - table_info = await TableInfoManager.get_table_info_by_table_id(table_id) - database_id = table_info['database_id'] - if enable: - flag = await keyword_service.add(database_id, table_id, column_name) - else: - flag = await keyword_service.del_by_column_name(database_id, table_id, column_name) - if not flag: - return ResponseData( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - message="列关键字功能开启/关闭失败", - result={} - ) - return ResponseData( - code=status.HTTP_200_OK, - message="列关键字功能开启/关闭成功", - result={} - ) diff --git a/chat2db/app/service/diff_database_service.py b/chat2db/app/service/diff_database_service.py deleted file mode 100644 index bb9f979679339182e5213d782c5ba0b5be3d047f..0000000000000000000000000000000000000000 --- a/chat2db/app/service/diff_database_service.py +++ /dev/null @@ -1,28 +0,0 @@ -import re -from urllib.parse import urlparse -from chat2db.app.base.mysql import Mysql -from chat2db.app.base.postgres import Postgres - - -class DiffDatabaseService(): - database_types = ["mysql", "postgresql", "opengauss"] - database_map = {"mysql": Mysql, "postgresql": Postgres, "opengauss": Postgres} - - @staticmethod - def get_database_service(database_type): - if database_type not in DiffDatabaseService.database_types: - raise f"不支持当前数据库类型{database_type}" - return DiffDatabaseService.database_map[database_type] - - @staticmethod - def get_database_type_from_url(database_url): - result = urlparse(database_url) - try: - database_type = result.scheme.split('+')[0] - except Exception as e: - raise e - return database_type.lower() - - @staticmethod - def is_database_type_allow(database_type): - return database_type in DiffDatabaseService.database_types diff --git a/chat2db/app/service/keyword_service.py b/chat2db/app/service/keyword_service.py deleted file mode 100644 index 685c341b5f106943b3e21eb5fe7367a4d4b7669b..0000000000000000000000000000000000000000 --- a/chat2db/app/service/keyword_service.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
-import asyncio -import copy -import uuid -import sys -import threading -from concurrent.futures import ThreadPoolExecutor -from chat2db.app.service.diff_database_service import DiffDatabaseService -from chat2db.app.base.ac_automation import DictTree -from chat2db.manager.database_info_manager import DatabaseInfoManager -from chat2db.manager.table_info_manager import TableInfoManager -from chat2db.manager.column_info_manager import ColumnInfoManager -import logging - -logging.basicConfig(stream=sys.stdout, level=logging.INFO, - format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') - - -class KeywordManager(): - def __init__(self): - self.keyword_asset_dict = {} - self.lock = threading.Lock() - self.data_frame_dict = {} - - async def load_keywords(self): - database_info_list = await DatabaseInfoManager.get_all_database_info() - for database_info in database_info_list: - database_id = database_info['database_id'] - table_info_list = await TableInfoManager.get_table_info_by_database_id(database_id) - cnt=0 - for table_info in table_info_list: - table_id = table_info['table_id'] - column_info_list = await ColumnInfoManager.get_column_info_by_table_id(table_id, True) - for i in range(len(column_info_list)): - column_info = column_info_list[i] - cnt+=1 - try: - column_name = column_info['column_name'] - await self.add(database_id, table_id, column_name) - except Exception as e: - logging.error('关键字数据结构生成失败') - def add_excutor(self, rd_id, database_id, table_id, table_info, column_info_list, column_name): - tmp_dict = self.data_frame_dict[rd_id] - tmp_dict_tree = DictTree() - tmp_dict_tree.load_data(tmp_dict['keyword_value_dict']) - if database_id not in self.keyword_asset_dict.keys(): - self.keyword_asset_dict[database_id] = {} - with self.lock: - if table_id not in self.keyword_asset_dict[database_id].keys(): - self.keyword_asset_dict[database_id][table_id] = {} - self.keyword_asset_dict[database_id][table_id]['table_info'] = table_info - self.keyword_asset_dict[database_id][table_id]['column_info_list'] = column_info_list - self.keyword_asset_dict[database_id][table_id]['primary_key_list'] = copy.deepcopy( - tmp_dict['primary_key_list']) - self.keyword_asset_dict[database_id][table_id]['dict_tree_dict'] = {} - self.keyword_asset_dict[database_id][table_id]['dict_tree_dict'][column_name] = tmp_dict_tree - del self.data_frame_dict[rd_id] - - async def add(self, database_id, table_id, column_name): - database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) - database_type = DiffDatabaseService.get_database_type_from_url(database_url) - table_info = await TableInfoManager.get_table_info_by_table_id(table_id) - table_name = table_info['table_name'] - tmp_dict = await DiffDatabaseService.get_database_service( - database_type).select_primary_key_and_keyword_from_table(database_url, table_name, column_name) - if tmp_dict is None: - return - rd_id = str(uuid.uuid4) - self.data_frame_dict[rd_id] = tmp_dict - del database_url - column_info_list = await ColumnInfoManager.get_column_info_by_table_id(table_id) - try: - thread = threading.Thread(target=self.add_excutor, args=(rd_id, database_id, table_id, - table_info, column_info_list, column_name,)) - thread.start() - except Exception as e: - logging.error(f'创建增加线程失败由于{e}') - return False - return True - - async def update_keyword_asset(self): - database_info_list = DatabaseInfoManager.get_all_database_info() - for database_info in database_info_list: - database_id = database_info['database_id'] - 
table_info_list = TableInfoManager.get_table_info_by_database_id(database_id) - for table_info in table_info_list: - table_id = table_info['table_id'] - column_info_list = ColumnInfoManager.get_column_info_by_table_id(table_id, True) - for column_info in column_info_list: - await self.add(database_id, table_id, column_info['column_name']) - - async def del_by_column_name(self, database_id, table_id, column_name): - try: - with self.lock: - if database_id in self.keyword_asset_dict.keys(): - if table_id in self.keyword_asset_dict[database_id].keys(): - if column_name in self.keyword_asset_dict[database_id][table_id]['dict_tree_dict'].keys(): - del self.keyword_asset_dict[database_id][table_id]['dict_tree_dict'][column_name] - except Exception as e: - logging.error(f'字典树删除失败由于{e}') - return False - return True - - async def generate_sql(self, question, database_id, table_id_list=None): - with self.lock: - results = [] - if database_id in self.keyword_asset_dict.keys(): - database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) - database_type = DiffDatabaseService.get_database_type_from_url(database_url) - for table_id in self.keyword_asset_dict[database_id].keys(): - if table_id_list is None or table_id in table_id_list: - table_info = self.keyword_asset_dict[database_id][table_id]['table_info'] - primary_key_list = self.keyword_asset_dict[database_id][table_id]['primary_key_list'] - primary_key_value_list = [] - try: - for dict_tree in self.keyword_asset_dict[database_id][table_id]['dict_tree_dict'].values(): - primary_key_value_list += dict_tree.get_results(question) - except Exception as e: - logging.error(f'从字典树中获取结果失败由于{e}') - continue - for i in range(len(primary_key_value_list)): - sql_str = await DiffDatabaseService.get_database_service(database_type).assemble_sql_query_base_on_primary_key( - table_info['table_name'], primary_key_list, primary_key_value_list[i]) - tmp_dict = {'database_id': database_id, 'table_id': table_id, 'sql': sql_str} - results.append(tmp_dict) - del database_url - return results - - -keyword_service = KeywordManager() -asyncio.run(keyword_service.load_keywords()) diff --git a/chat2db/app/service/sql_generate_service.py b/chat2db/app/service/sql_generate_service.py deleted file mode 100644 index f20f97706650424d862ad7d9a6036b439795379c..0000000000000000000000000000000000000000 --- a/chat2db/app/service/sql_generate_service.py +++ /dev/null @@ -1,363 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
-import asyncio -import yaml -import re -import json -import random -import sys -import uuid -import logging -from pandas.core.api import DataFrame as DataFrame - -from chat2db.manager.database_info_manager import DatabaseInfoManager -from chat2db.manager.table_info_manager import TableInfoManager -from chat2db.manager.column_info_manager import ColumnInfoManager -from chat2db.manager.sql_example_manager import SqlExampleManager -from chat2db.app.service.diff_database_service import DiffDatabaseService -from chat2db.llm.chat_with_model import LLM -from chat2db.config.config import config -from chat2db.app.base.vectorize import Vectorize - - -logging.basicConfig(stream=sys.stdout, level=logging.INFO, - format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') - - -class SqlGenerateService(): - - @staticmethod - async def merge_table_and_column_info(table_info, column_info_list): - table_name = table_info.get('table_name', '') - table_note = table_info.get('table_note', '') - note = '\n' - note += '\n'+'\n'+'\n' - note += '\n'+f'\n'+'\n' - note += '\n'+'\n'+'\n' - note += '\n'+f'\n'+'\n' - note += '\n'+' \n\n\n'+'\n' - for column_info in column_info_list: - column_name = column_info.get('column_name', '') - column_type = column_info.get('column_type', '') - column_note = column_info.get('column_note', '') - note += '\n'+f' \n\n\n'+'\n' - note += '
表名
{table_name}
表的注释
{table_note}
字段字段类型字段注释
{column_name}{column_type}{column_note}
' - return note - - @staticmethod - def extract_list_statements(list_string): - pattern = r'\[.*?\]' - matches = re.findall(pattern, list_string) - if len(matches) == 0: - return '' - tmp = matches[0] - tmp = tmp.replace('\'', '\"') - tmp = tmp.replace(',', ',') - return tmp - - @staticmethod - async def get_most_similar_table_id_list(database_id, question, table_choose_cnt): - table_info_list = await TableInfoManager.get_table_info_by_database_id(database_id) - random.shuffle(table_info_list) - table_id_set = set() - for table_info in table_info_list: - table_id = table_info['table_id'] - table_id_set.add(str(table_id)) - try: - with open('./chat2db/templetes/prompt.yaml', 'r', encoding='utf-8') as f: - prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) - prompt = prompt_dict.get('table_choose_prompt', '') - table_entries = '\n' - table_entries += '\n'+' \n\n'+'\n' - token_upper = 2048 - for table_info in table_info_list: - table_id = table_info['table_id'] - table_note = table_info['table_note'] - if len(table_entries) + len( - '\n' + f' \n\n' + '\n') > token_upper: - break - table_entries += '\n'+f' \n\n'+'\n' - table_entries += '
主键表注释
{table_id}{table_note}
{table_id}{table_note}
' - prompt = prompt.format(table_cnt=table_choose_cnt, table_entries=table_entries, question=question) - # logging.info(f'在大模型增强模式下,选择表的prompt构造成功:{prompt}') - except Exception as e: - logging.error(f'在大模型增强模式下,选择表的prompt构造失败由于:{e}') - return [] - try: - llm = LLM(model_name=config['LLM_MODEL'], - openai_api_base=config['LLM_URL'], - openai_api_key=config['LLM_KEY'], - max_tokens=config['LLM_MAX_TOKENS'], - request_timeout=60, - temperature=0.5) - except Exception as e: - llm = None - logging.error(f'在大模型增强模式下,选择表的过程中,与大模型建立连接失败由于:{e}') - table_id_list = [] - if llm is not None: - for i in range(2): - content = await llm.chat_with_model(prompt, '请输包含选择表主键的列表') - try: - sub_table_id_list = json.loads(SqlGenerateService.extract_list_statements(content)) - except: - sub_table_id_list = [] - for j in range(len(sub_table_id_list)): - if sub_table_id_list[j] in table_id_set and uuid.UUID(sub_table_id_list[j]) not in table_id_list: - table_id_list.append(uuid.UUID(sub_table_id_list[j])) - if len(table_id_list) < table_choose_cnt: - table_choose_cnt -= len(table_id_list) - for i in range(min(table_choose_cnt, len(table_info_list))): - table_id = table_info_list[i]['table_id'] - if table_id is not None and table_id not in table_id_list: - table_id_list.append(table_id) - return table_id_list - - @staticmethod - async def find_most_similar_sql_example( - database_id, table_id_list, question, use_llm_enhancements=False, table_choose_cnt=2, sql_example_choose_cnt=10, - topk=5): - try: - database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) - except Exception as e: - logging.error(f'数据库{database_id}信息获取失败由于{e}') - return [] - database_type = DiffDatabaseService.get_database_type_from_url(database_url) - del database_url - try: - question_vector = await Vectorize.vectorize_embedding(question) - except Exception as e: - logging.error(f'问题向量化失败由于:{e}') - return {} - sql_example = [] - data_frame_list = [] - if table_id_list is None: - if use_llm_enhancements: - table_id_list = await SqlGenerateService.get_most_similar_table_id_list(database_id, question, table_choose_cnt) - else: - try: - table_info_list = await TableInfoManager.get_table_info_by_database_id(database_id) - table_id_list = [] - for table_info in table_info_list: - table_id_list.append(table_info['table_id']) - max_retry = 3 - sql_example_list = [] - for _ in range(max_retry): - try: - sql_example_list = await asyncio.wait_for(SqlExampleManager.get_topk_sql_example_by_cos_dis( - question_vector=question_vector, - table_id_list=table_id_list, topk=table_choose_cnt * 2), - timeout=5 - ) - break - except Exception as e: - logging.error(f'非增强模式下,sql_example获取失败:{e}') - table_id_list = [] - for sql_example in sql_example_list: - table_id_list.append(sql_example['table_id']) - except Exception as e: - logging.error(f'非增强模式下,表id获取失败由于:{e}') - return [] - table_id_list = list(set(table_id_list)) - if len(table_id_list) < table_choose_cnt: - try: - expand_table_id_list = await asyncio.wait_for(TableInfoManager.get_topk_table_by_cos_dis( - database_id, question_vector, table_choose_cnt - len(table_id_list)), timeout=5 - ) - table_id_list += expand_table_id_list - except Exception as e: - logging.error(f'非增强模式下,表id补充失败由于:{e}') - exist_table_id = set() - note_list = [] - for i in range(min(2, len(table_id_list))): - table_id = table_id_list[i] - if table_id in exist_table_id: - continue - exist_table_id.add(table_id) - try: - table_info = await TableInfoManager.get_table_info_by_table_id(table_id) - column_info_list = await 
ColumnInfoManager.get_column_info_by_table_id(table_id) - except Exception as e: - logging.error(f'表{table_id}注释获取失败由于{e}') - continue - note = await SqlGenerateService.merge_table_and_column_info(table_info, column_info_list) - note_list.append(note) - max_retry = 3 - sql_example_list = [] - for _ in range(max_retry): - try: - sql_example_list = await asyncio.wait_for(SqlExampleManager.get_topk_sql_example_by_cos_dis( - question_vector, - table_id_list=[table_id], - topk=sql_example_choose_cnt), - timeout=5 - ) - break - except Exception as e: - logging.error(f'获取id为{table_id}的表的最相近的{topk}条sql案例失败由于:{e}') - question_sql_list = [] - for i in range(len(sql_example_list)): - question_sql_list.append( - {'question': sql_example_list[i]['question'], - 'sql': sql_example_list[i]['sql']}) - data_frame_list.append({'table_id': table_id, 'table_info': table_info, - 'column_info_list': column_info_list, 'sql_example_list': question_sql_list}) - return data_frame_list - - @staticmethod - async def merge_sql_example(sql_example_list): - sql_example = '' - for i in range(len(sql_example_list)): - sql_example += '问题'+str(i)+':\n'+sql_example_list[i].get('question', - '')+'\nsql'+str(i)+':\n'+sql_example_list[i].get('sql', '')+'\n' - return sql_example - - @staticmethod - async def extract_select_statements(sql_string): - pattern = r"(?i)select[^;]*;" - matches = re.findall(pattern, sql_string) - if len(matches) == 0: - return '' - sql = matches[0] - sql = sql.strip() - sql.replace(',', ',') - return sql - - @staticmethod - async def generate_sql_base_on_example( - database_id, question, table_id_list=None, sql_generate_cnt=1, use_llm_enhancements=False): - try: - database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) - except Exception as e: - logging.error(f'数据库{database_id}信息获取失败由于{e}') - return {} - if database_url is None: - raise Exception('数据库配置不存在') - database_type = DiffDatabaseService.get_database_type_from_url(database_url) - data_frame_list = await SqlGenerateService.find_most_similar_sql_example(database_id, table_id_list, question, use_llm_enhancements) - try: - with open('./chat2db/templetes/prompt.yaml', 'r', encoding='utf-8') as f: - prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) - llm = LLM(model_name=config['LLM_MODEL'], - openai_api_base=config['LLM_URL'], - openai_api_key=config['LLM_KEY'], - max_tokens=config['LLM_MAX_TOKENS'], - request_timeout=60, - temperature=0.5) - results = [] - for data_frame in data_frame_list: - prompt = prompt_dict.get('sql_generate_base_on_example_prompt', '') - table_info = data_frame.get('table_info', '') - table_id = table_info['table_id'] - column_info_list = data_frame.get('column_info_list', '') - note = await SqlGenerateService.merge_table_and_column_info(table_info, column_info_list) - sql_example = await SqlGenerateService.merge_sql_example(data_frame.get('sql_example_list', [])) - try: - prompt = prompt.format( - database_url=database_url, note=note, k=len(data_frame.get('sql_example_list', [])), - sql_example=sql_example, question=question) - except Exception as e: - logging.info(f'sql生成失败{e}') - return [] - ge_cnt = 0 - ge_sql_cnt = 0 - while ge_cnt < 10*sql_generate_cnt and ge_sql_cnt < sql_generate_cnt: - sql = await llm.chat_with_model(prompt, f'请输出一条在与{database_type}下能运行的sql,以分号结尾') - sql = await SqlGenerateService.extract_select_statements(sql) - if len(sql): - ge_sql_cnt += 1 - tmp_dict = {'database_id': database_id, 'table_id': table_id, 'sql': sql} - results.append(tmp_dict) - ge_cnt += 1 - if len(results) == 
sql_generate_cnt: - break - except Exception as e: - logging.error(f'sql生成失败由于:{e}') - return results - - @staticmethod - async def generate_sql_base_on_data(database_url, table_name, sql_var=False): - database_type = None - database_type = DiffDatabaseService.get_database_type_from_url(database_url) - flag = await DiffDatabaseService.get_database_service(database_type).test_database_connection(database_url) - if not flag: - return None - table_name_list = await DiffDatabaseService.get_database_service(database_type).get_all_table_name_from_database_url(database_url) - if table_name not in table_name_list: - return None - table_info = await DiffDatabaseService.get_database_service(database_type).get_table_info(database_url, table_name) - column_info_list = await DiffDatabaseService.get_database_service(database_type).get_column_info(database_url, table_name) - note = await SqlGenerateService.merge_table_and_column_info(table_info, column_info_list) - - def count_char(str, char): - return sum(1 for c in str if c == char) - llm = LLM(model_name=config['LLM_MODEL'], - openai_api_base=config['LLM_URL'], - openai_api_key=config['LLM_KEY'], - max_tokens=config['LLM_MAX_TOKENS'], - request_timeout=60, - temperature=0.5) - for i in range(5): - data_frame = await DiffDatabaseService.get_database_service(database_type).get_rand_data(database_url, table_name) - try: - with open('./chat2db/templetes/prompt.yaml', 'r', encoding='utf-8') as f: - prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) - prompt = prompt_dict['question_generate_base_on_data_prompt'].format( - note=note, data_frame=data_frame) - question = await llm.chat_with_model(prompt, '请输出一个问题') - if count_char(question, '?') > 1 or count_char(question, '?') > 1: - continue - except Exception as e: - logging.error(f'问题生成失败由于{e}') - continue - try: - with open('./chat2db/templetes/prompt.yaml', 'r', encoding='utf-8') as f: - prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) - prompt = prompt_dict['sql_generate_base_on_data_prompt'].format( - database_type=database_type, - note=note, data_frame=data_frame, question=question) - sql = await llm.chat_with_model(prompt, f'请输出一条可以用于查询{database_type}的sql,要以分号结尾') - sql = await SqlGenerateService.extract_select_statements(sql) - if not sql: - continue - except Exception as e: - logging.error(f'sql生成失败由于{e}') - continue - try: - if sql_var: - await DiffDatabaseService.get_database_service(database_type).try_excute(database_url, sql) - except Exception as e: - logging.error(f'生成的sql执行失败由于{e}') - continue - return { - 'question': question, - 'sql': sql - } - return None - - @staticmethod - async def repair_sql(database_type, table_info, column_info_list, sql_failed, sql_failed_message, question): - try: - with open('./chat2db/templetes/prompt.yaml', 'r', encoding='utf-8') as f: - prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) - llm = LLM(model_name=config['LLM_MODEL'], - openai_api_base=config['LLM_URL'], - openai_api_key=config['LLM_KEY'], - max_tokens=config['LLM_MAX_TOKENS'], - request_timeout=60, - temperature=0.5) - try: - note = await SqlGenerateService.merge_table_and_column_info(table_info, column_info_list) - prompt = prompt_dict.get('sql_expand_prompt', '') - prompt = prompt.format( - database_type=database_type, note=note, sql_failed=sql_failed, - sql_failed_message=sql_failed_message, - question=question) - except Exception as e: - logging.error(f'sql修复失败由于{e}') - return '' - sql = await llm.chat_with_model(prompt, f'请输出一条在与{database_type}下能运行的sql,要以分号结尾') - sql = await 
SqlGenerateService.extract_select_statements(sql) - logging.info(f"修复前的sql为{sql_failed}修复后的sql为{sql}") - except Exception as e: - logging.error(f'sql生成失败由于:{e}') - return '' - return sql diff --git a/chat2db/apps/base/__init__.py b/chat2db/apps/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..80e944c1264ecdda6d74967bf74c4ebbb71603d4 --- /dev/null +++ b/chat2db/apps/base/__init__.py @@ -0,0 +1,7 @@ +from apps.base.database_base import MetaDatabase +from apps.base.mysql import MySQL +from apps.base.mongodb import MongoDB +from apps.base.opengauss import OpenGauss +from apps.base.postgres import Postgres + +__all__ = ['MySQL', 'MongoDB', 'OpenGauss', 'Postgres', 'MetaDatabase'] \ No newline at end of file diff --git a/chat2db/apps/base/database_base.py b/chat2db/apps/base/database_base.py new file mode 100644 index 0000000000000000000000000000000000000000..05fbbaccd723f771d41c8da9f131afedfdf21a40 --- /dev/null +++ b/chat2db/apps/base/database_base.py @@ -0,0 +1,26 @@ +from typing import Any + +class MetaDatabase: + @staticmethod + async def get_database_url(host: str, port: int, username: str, password: str, database: str): + raise NotImplementedError + + @staticmethod + async def connect(host: str, port: int, username: str, password: str, database: str) -> Any: + raise NotImplementedError + + @staticmethod + async def list_tables(connection: Any) -> list[str]: + raise NotImplementedError + + @staticmethod + async def get_table_ddl(table_name: str, connection: Any) -> str: + raise NotImplementedError + + @staticmethod + async def sample_table_rows(table_name: str, n: int, connection: Any) -> list[dict]: + raise NotImplementedError + + @staticmethod + async def execute_sql(sql: str | dict, connection: Any) -> list[dict]: + raise NotImplementedError diff --git a/chat2db/apps/base/mongodb.py b/chat2db/apps/base/mongodb.py new file mode 100644 index 0000000000000000000000000000000000000000..148c919133d64f0131828dfd283b22e59ee8eb2e --- /dev/null +++ b/chat2db/apps/base/mongodb.py @@ -0,0 +1,163 @@ +import logging +from typing import Any +from bson import ObjectId +from copy import deepcopy +import motor.motor_asyncio +import urllib.parse + +from apps.base.database_base import MetaDatabase + +class MongoDB(MetaDatabase): + + @staticmethod + async def get_database_url(host: str, port: int, username: str, password: str, database: str): + try: + user = urllib.parse.quote_plus(username) + pwd = urllib.parse.quote_plus(password) + return f"mongodb://{user}:{pwd}@{host}:{port}/{database}" + except Exception as e: + logging.error(f"\n[获取数据库url失败]\n\n{e}") + return "" + + @staticmethod + async def connect(host: str, port: int, username: str, password: str, database: str) -> Any: + try: + user = urllib.parse.quote_plus(username) + pwd = urllib.parse.quote_plus(password) + mongo_uri = f"mongodb://{user}:{pwd}@{host}:{port}/{database}" + client = motor.motor_asyncio.AsyncIOMotorClient(mongo_uri) + return client[database] + except Exception as e: + logging.error(f"\n[连接MongoDB数据库失败]\n\n{e}") + raise e + + @staticmethod + async def list_tables(connection: Any) -> list[str]: + try: + return await connection.list_collection_names() + except Exception as e: + logging.error(f"\n[获取集合失败]\n\n{e}") + raise e + + @staticmethod + async def get_table_ddl(table_name: str, connection: Any) -> str: + """ + 将 MongoDB 集合信息格式化为类似 SQL DDL 的文本,用于大模型输入。 + 包括索引信息和部分示例字段。 + """ + try: + # 获取索引信息 + indexes = await connection[table_name].index_information() + + # 尝试获取部分文档字段类型 + sample_doc = await 
connection[table_name].find_one() or {} + fields_ddl = [] + for field, value in sample_doc.items(): + dtype = type(value).__name__ + fields_ddl.append(f" {field} {dtype.upper()}") + + # 格式化索引信息 + indexes_ddl = [] + for index_name, index_info in indexes.items(): + keys = ", ".join([f"{k[0]}({k[1]})" for k in index_info['key']]) + unique = " UNIQUE" if index_info.get('unique') else "" + indexes_ddl.append(f" INDEX {index_name} ON ({keys}){unique}") + + ddl = f"CREATE COLLECTION {table_name} (\n" + ddl += ",\n".join(fields_ddl) + ddl += "\n);\n" + if indexes_ddl: + ddl += "\n".join(indexes_ddl) + + return ddl + + except Exception as e: + logging.error(f"\n[获取集合 {table_name} DDL失败]\n\n{e}") + raise e + + @staticmethod + async def sample_table_rows(table_name: str, n: int, connection: Any) -> list[dict]: + """ + 随机获取 n 条数据 + """ + try: + cursor = connection[table_name].aggregate([{"$sample": {"size": n}}]) + result = [doc async for doc in cursor] + return result + except Exception as e: + logging.error(f"\n[获取集合 {table_name} 样本数据失败]\n\n{e}") + raise e + + @staticmethod + async def execute_sql(sql: dict, connection: Any) -> list[dict]: + """ + 执行 MongoDB 操作,传入 dict 格式指令 + 支持 find/insertOne/insertMany/updateOne/updateMany/deleteOne/deleteMany/aggregate + 返回值中所有 ObjectId 自动转换为 str + """ + command = deepcopy(sql) # mongodb会修改输入的dict,所以这里需要深拷贝 + try: + coll_name = command.get("collection") + operation = command.get("operation", "find") + filter_ = command.get("filter", {}) + data = command.get("data", {}) + pipeline = command.get("pipeline", []) + many = command.get("many", False) + + collection = connection[coll_name] + + # 查询 + if operation == "find": + cursor = collection.find(filter_) + result = [doc async for doc in cursor] + return MongoDB.transform_objectid(result) + + # 聚合 + elif operation == "aggregate": + cursor = collection.aggregate(pipeline) + result = [doc async for doc in cursor] + return MongoDB.transform_objectid(result) + + # 插入 + elif operation in ("insert", "insertOne", "insertMany"): + if many or operation == "insertMany": + res = await collection.insert_many(data) + return [{"inserted_ids": [str(_id) for _id in res.inserted_ids]}] + else: + res = await collection.insert_one(data) + return [{"inserted_id": str(res.inserted_id)}] + + # 更新 + elif operation in ("update", "updateOne", "updateMany"): + if many or operation == "updateMany": + res = await collection.update_many(filter_, {"$set": data}) + else: + res = await collection.update_one(filter_, {"$set": data}) + return [{"matched": res.matched_count, "modified": res.modified_count}] + + # 删除 + elif operation in ("delete", "deleteOne", "deleteMany"): + if many or operation == "deleteMany": + res = await collection.delete_many(filter_) + else: + res = await collection.delete_one(filter_) + return [{"deleted": res.deleted_count}] + + else: + raise ValueError(f"Unsupported MongoDB operation: {operation}") + + except Exception as e: + logging.error(f"\n[执行MongoDB指令失败]\n\n{e}") + raise e + + @staticmethod + def transform_objectid(doc): + """递归将 dict/list 中的 ObjectId 转为 str""" + if isinstance(doc, list): + return [MongoDB.transform_objectid(d) for d in doc] + elif isinstance(doc, dict): + return {k: MongoDB.transform_objectid(v) for k, v in doc.items()} + elif isinstance(doc, ObjectId): + return str(doc) + else: + return doc \ No newline at end of file diff --git a/chat2db/apps/base/mysql.py b/chat2db/apps/base/mysql.py new file mode 100644 index 0000000000000000000000000000000000000000..ab1b1025caeec8767df77a1cd506ca9c8fdeb060 --- 
/dev/null +++ b/chat2db/apps/base/mysql.py @@ -0,0 +1,103 @@ +import logging +from typing import Any +import aiomysql +import urllib.parse + +from apps.base.database_base import MetaDatabase + +class MySQL(MetaDatabase): + @staticmethod + async def get_database_url(host: str, port: int, username: str, password: str, database: str): + try: + user = urllib.parse.quote_plus(username) + pwd = urllib.parse.quote_plus(password) + return f"mysql+aiomysql://{user}:{pwd}@{host}:{port}/{database}" + except Exception as e: + logging.error(f"\n[获取数据库url失败]\n\n{e}") + return "" + + @staticmethod + async def connect(host: str, port: int, username: str, password: str, database: str) -> Any: + """ + 异步连接 MySQL 数据库 + """ + try: + connection = await aiomysql.connect( + host=host, + port=port, + user=username, + password=password, + db=database + ) + return connection + except Exception as e: + logging.error(f"\n[连接MySQL数据库失败]\n\n{e}") + raise e + + @staticmethod + async def list_tables(connection: Any) -> list[str]: + """ + 获取数据库中所有表名 + """ + try: + async with connection.cursor() as cursor: + await cursor.execute("SHOW TABLES") + tables = [table[0] for table in await cursor.fetchall()] + return tables + except Exception as e: + logging.error(f"\n[获取表名失败]\n\n{e}") + raise e + + @staticmethod + async def get_table_ddl(table_name: str, connection: Any) -> str: + """ + 获取指定表的 DDL(建表语句) + """ + try: + async with connection.cursor() as cursor: + await cursor.execute(f"SHOW CREATE TABLE `{table_name}`") + result = await cursor.fetchone() + return result[1] if result else "" + except Exception as e: + logging.error(f"\n[获取表 {table_name} DDL失败]\n\n{e}") + raise e + + @staticmethod + async def sample_table_rows(table_name: str, n: int, connection: Any) -> list[dict]: + """ + 随机获取表中 n 条数据 + """ + try: + async with connection.cursor(aiomysql.DictCursor) as cursor: + await cursor.execute(f"SELECT * FROM `{table_name}` ORDER BY RAND() LIMIT {n}") + rows = await cursor.fetchall() + return rows + except Exception as e: + logging.error(f"\n[获取表 {table_name} 样本数据失败]\n\n{e}") + raise e + + @staticmethod + async def execute_sql(sql: str, connection: Any) -> list[dict]: + """ + 异步执行 SQL, 自动返回查询结果或影响行数。 + + 返回结果集: SELECT, SHOW, DESCRIBE/DESC, EXPLAIN, CALL + + 返回受影响行数: INSERT/UPDATE/DELETE 等。 + """ + try: + async with connection.cursor(aiomysql.DictCursor) as cursor: + result = await cursor.execute(sql) + await connection.commit() + + # 针对返回结果集的操作 + if sql.strip().upper().startswith(("SELECT", "SHOW", "DESCRIBE", "DESC", "EXPLAIN")): + rows = await cursor.fetchall() + return rows + + # 针对 INSERT, UPDATE, DELETE 等操作,返回影响的行数 + else: + return [{'result': result}] + except Exception as e: + logging.error(f"\n[执行SQL失败]\n\n{e}") + raise e diff --git a/chat2db/apps/base/opengauss.py b/chat2db/apps/base/opengauss.py new file mode 100644 index 0000000000000000000000000000000000000000..f22eb1dcc15d67f7a7af7e140db78c765c600ba0 --- /dev/null +++ b/chat2db/apps/base/opengauss.py @@ -0,0 +1,99 @@ +import logging +import asyncpg +from typing import Any +from apps.base.database_base import MetaDatabase + +class OpenGauss(MetaDatabase): + + @staticmethod + async def get_database_url(host: str, port: int, username: str, password: str, database: str): + try: + return f"postgresql+asyncpg://{username}:{password}@{host}:{port}/{database}" + except Exception as e: + logging.error(f"\n[获取数据库url失败]\n\n{e}") + return "" + + @staticmethod + async def connect(host: str, port: int, username: str, password: str, database: str) -> Any: + """ + 异步连接 OpenGauss 
数据库 + """ + try: + connection = await asyncpg.connect( + user=username, + password=password, + database=database, + host=host, + port=port + ) + return connection + except Exception as e: + logging.error(f"\n[连接OpenGauss数据库失败]\n\n{e}") + raise e + + @staticmethod + async def list_tables(connection: Any) -> list[str]: + """ + 获取数据库中的所有表名 + """ + query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'" + try: + tables = await connection.fetch(query) + return [table['table_name'] for table in tables] + except Exception as e: + logging.error(f"\n[获取表名失败]\n\n{e}") + raise e + + @staticmethod + async def get_table_ddl(table_name: str, connection: Any) -> str: + """ + 获取指定表的 DDL(建表语句) + """ + try: + # OpenGauss/Postgres 可以使用 pg_get_tabledef 获取 DDL + sql = f"SELECT pg_get_tabledef('{table_name}'::regclass);" + ddl = await connection.fetchval(sql) + return ddl or "" + except Exception as e: + logging.error(f"\n[获取表 {table_name} DDL失败]\n\n{e}") + raise e + + @staticmethod + async def sample_table_rows(table_name: str, num_rows: int, connection: Any) -> list[dict]: + """ + 随机获取表中 n 条数据 + """ + try: + sql = f"SELECT * FROM {table_name} ORDER BY random() LIMIT {num_rows};" + rows = await connection.fetch(sql) + return [dict(row) for row in rows] + except Exception as e: + logging.error(f"\n[获取表 {table_name} 样本数据失败]\n\n{e}") + raise e + + @staticmethod + async def execute_sql(sql: str, connection: Any) -> list[dict]: + """ + 异步执行 SQL, 自动返回查询结果或原始输出。 + + 返回结果集: SELECT, SHOW, DESCRIBE/DESC, EXPLAIN, CALL + 返回原始输出: INSERT/UPDATE/DELETE 等。 + """ + try: + async with connection.transaction(): + + sql_type = sql.strip().split()[0].upper() + # 返回结果集的语句类型 + result_set_statements = {"SELECT", "SHOW", "DESCRIBE", "DESC", "EXPLAIN", "CALL"} + + if sql_type in result_set_statements: + rows = await connection.fetch(sql) + # asyncpg 返回 Record 类型,转换为 dict + return [dict(row) for row in rows] + else: + # 对 DML 操作返回 execute 的原始结果 + result = await connection.execute(sql) + return [{'result': result}] + except Exception as e: + logging.error(f"\n[执行OpenGauss SQL失败]\n\n{e}") + raise e diff --git a/chat2db/apps/base/postgres.py b/chat2db/apps/base/postgres.py new file mode 100644 index 0000000000000000000000000000000000000000..e679bae687f85c780b47388d432381ed600e06d5 --- /dev/null +++ b/chat2db/apps/base/postgres.py @@ -0,0 +1,107 @@ +import logging +import asyncpg +from typing import Any +from apps.base.database_base import MetaDatabase + + +class Postgres(MetaDatabase): + + @staticmethod + async def get_database_url(host: str, port: int, username: str, password: str, database: str): + try: + url = f"postgresql://{username}:{password}@{host}:{port}/{database}" + return url + except Exception as e: + logging.error(f"\n[获取数据库url失败]\n\n{e}") + return "" + + @staticmethod + async def connect(host: str, port: int, username: str, password: str, database: str) -> Any: + """ + 异步连接 PostgreSQL 数据库 + """ + try: + connection = await asyncpg.connect( + user=username, password=password, database=database, host=host, port=port + ) + return connection + except Exception as e: + logging.error(f"\n[连接PostgreSQL数据库失败]\n\n{e}") + raise e + + @staticmethod + async def list_tables(connection: Any) -> list[str]: + """ + 获取数据库中所有表名 + """ + try: + tables = await connection.fetch( + "SELECT table_name FROM information_schema.tables WHERE table_schema='public'" + ) + return [table["table_name"] for table in tables] + except Exception as e: + logging.error(f"\n[获取表名失败]\n\n{e}") + raise e + + @staticmethod + async def 
get_table_ddl(table_name: str, connection: Any) -> str: + try: + sql = f""" + SELECT column_name, data_type, is_nullable, column_default + FROM information_schema.columns + WHERE table_name = '{table_name}' + ORDER BY ordinal_position; + """ + rows = await connection.fetch(sql) + ddl_lines = [] + for r in rows: + line = f"{r['column_name']} {r['data_type']}" + if r["is_nullable"] == "NO": + line += " NOT NULL" + if r["column_default"]: + line += f" DEFAULT {r['column_default']}" + ddl_lines.append(line) + ddl = f"CREATE TABLE {table_name} (\n " + ",\n ".join(ddl_lines) + "\n);" + return ddl + + except Exception as e: + logging.error(f"\n[获取表 {table_name} DDL失败]\n\n{e}") + raise e + + @staticmethod + async def sample_table_rows(table_name: str, n: int, connection: Any) -> list[dict]: + try: + sql = f"SELECT * FROM {table_name} ORDER BY random() LIMIT {n};" + rows = await connection.fetch(sql) + return [dict(row) for row in rows] + except Exception as e: + logging.error(f"\n[获取表 {table_name} 样本数据失败]\n\n{e}") + raise e + + @staticmethod + async def execute_sql(sql: str, connection: Any) -> list[dict]: + """ + 异步执行 SQL, 自动返回查询结果或原始输出。 + + 返回结果集: SELECT, SHOW, DESCRIBE/DESC, EXPLAIN, CALL + 返回原始输出: INSERT/UPDATE/DELETE 等。 + """ + try: + async with connection.transaction(): + # 获取 SQL 类型 + sql_type = sql.strip().split()[0].upper() + + # 返回结果集的语句类型 + result_set_statements = {"SELECT", "SHOW", "DESCRIBE", "DESC", "EXPLAIN", "CALL"} + + if sql_type in result_set_statements: + rows = await connection.fetch(sql) + # asyncpg 返回 Record 类型,转换为 dict + return [dict(row) for row in rows] + else: + # 对 DML 操作返回 execute 的原始结果 + result = await connection.execute(sql) + return [{"result": result}] + except Exception as e: + logging.error(f"\n[执行PostgreSQL SQL失败]\n\n{e}") + raise e diff --git a/chat2db/apps/llm/__init__.py b/chat2db/apps/llm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d6d683e03ebe006bcd38b9f72610bad51ef5477f --- /dev/null +++ b/chat2db/apps/llm/__init__.py @@ -0,0 +1,4 @@ +from apps.llm.llm import LLM +from apps.llm.prompt import GENERATE_SQL_PROMPT, REPAIR_SQL_PROMPT, RISK_EVALUATE_SQL + +__all__ = ['LLM', 'GENERATE_SQL_PROMPT', 'REPAIR_SQL_PROMPT', 'RISK_EVALUATE_SQL'] \ No newline at end of file diff --git a/chat2db/apps/llm/llm.py b/chat2db/apps/llm/llm.py new file mode 100644 index 0000000000000000000000000000000000000000..48c90e37584ba8853ee0e0a81ac3f3ef1e003f76 --- /dev/null +++ b/chat2db/apps/llm/llm.py @@ -0,0 +1,85 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+import asyncio +from openai import AsyncOpenAI + + +class LLM: + def __init__(self, openai_api_key, openai_api_base, model_name, max_tokens, request_timeout=60, temperature=0.1): + self.openai_api_key = openai_api_key + self.openai_api_base = openai_api_base + self.model_name = model_name + self.max_tokens = max_tokens + self.request_timeout = request_timeout + self.temperature = temperature + self._client = AsyncOpenAI( + api_key=self.openai_api_key, + base_url=self.openai_api_base, + ) + + def assemble_chat(self, chat=None, system_call='', user_call=''): + if chat is None: + chat = [] + chat.append({"role": "system", "content": system_call}) + chat.append({"role": "user", "content": user_call}) + return chat + + async def create_stream( + self, message): + return await self._client.chat.completions.create( + model=self.model_name, + messages=message, # type: ignore[] + max_completion_tokens=self.max_tokens, + temperature=self.temperature, + stream=True, + stream_options={"include_usage": True}, + timeout=300, + extra_body={"enable_thinking": False} + ) # type: ignore[] + + async def data_producer(self, q: asyncio.Queue, history, system_call, user_call): + message = self.assemble_chat(history, system_call, user_call) + stream = await self.create_stream(message) + try: + async for chunk in stream: + if len(chunk.choices) == 0: + continue + if chunk.choices[0].delta.content is not None: + content = chunk.choices[0].delta.content + else: + continue + await q.put(content) + except Exception as e: + await q.put(None) + err = f"[LLM] 流式输出生产者任务异常: {e}" + raise e + await q.put(None) + + async def stream(self, chat, system_call, user_call): + q = asyncio.Queue(maxsize=10) + + # 启动生产者任务 + asyncio.create_task(self.data_producer(q, chat, system_call, user_call)) + while True: + data = await q.get() + if data is None: + break + yield data + + async def nostream(self, chat, system_call, user_call, st_str: str = None, en_str: str = None): + try: + content = '' + async for chunk in self.stream(chat, system_call, user_call): + content += chunk + content = content.strip() + if st_str is not None: + index = content.find(st_str) + if index != -1: + content = content[index:] + if en_str is not None: + index = content[::-1].find(en_str[::-1]) + if index != -1: + content = content[:len(content)-index] + except Exception as e: + err = f"[LLM] 非流式输出异常: {e}" + return '' + return content diff --git a/chat2db/apps/llm/prompt.py b/chat2db/apps/llm/prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..db4e56a1b3b69c5a6bb071a2decc26fe96ee3508 --- /dev/null +++ b/chat2db/apps/llm/prompt.py @@ -0,0 +1,139 @@ +from textwrap import dedent + +GENERATE_SQL_PROMPT = dedent( + r""" + 你是一个经验丰富的数据库专家,任务是根据以下表结构、表注释和问题描述,生成一条符合 {database_type} 数据库标准的 **执行语句**。 + **你不需要访问、操作或执行数据库,只需生成指令。** + + 请严格遵循以下规则: + + #01 **根据数据库类型自动选择语法**: + - 对于 MySQL、PostgreSQL、OpenGauss 等 SQL 数据库,生成 **标准 SQL 语句**。 + - 对于非 SQL 数据库 MongoDB,不使用 SQL 语法,生成 **MongoDB 操作指令对象(JSON 形式)**。 + + #02 **只输出数据库指令**,禁止输出道歉、解释、评论、推理或任何自然语言。 + + #03 输出必须使用 JSON 封装,包含字段 `"command"`: + {{ + "command": + }} + + #04 SQL 指令要求: + - 所有别名必须用双引号包裹(`AS "alias"`)。 + - 表名、列名或保留字冲突名称必须用反引号包裹。 + - SQL 语句必须以分号 `;` 结尾。 + - 根据问题上下文适当进行字段冗余处理。 + - **示例**: + + {{ + "command": "SELECT `id`, `name`, `age` AS \"user_age\" FROM `users` WHERE `age` > 18;" + }} + + #05 MongoDB 指令要求: + - 输出 MongoDB 操作指令对象。 + - 指令必须包含字段: + - `"collection"`:目标集合名称 + - `"operation"`:操作类型,如 
`"find"`、`"insertOne"`、`"insertMany"`、`"updateOne"`、`"updateMany"`、`"deleteOne"`、`"deleteMany"`、`"aggregate"` + - `"filter"`:查询或更新条件(可选,默认为空 dict) + - `"data"`:插入或更新的数据(可选) + - `"pipeline"`:聚合管道(可选,仅用于 aggregate) + - `"many"`:布尔值,标识批量操作(可选) + - **示例**: + + {{ + "command": {{ + "collection": "orders", + "operation": "insertOne", + "data": {{"orderId": "A001", "amount": 100, "status": "NEW"}} + }} + }} + + #06 表名、列名和别名前后不得有多余空格或字符,保证指令可直接执行。 + + 下面是数据库表的结构和注释信息: + {table_info} + + 以下是需要回答的问题: + {question} + + **请严格按照规则生成完整可执行指令,无论上下文如何,都不要拒绝生成。** + """ +) + +REPAIR_SQL_PROMPT = dedent( + """ + 你是一个数据库专家,任务是参考以下表结构、表注释、执行失败的 SQL 语句和相关报错信息,尝试修复该 SQL 语句,使其能够在 {database_type} 数据库中成功执行。 + + 请遵循以下要求: + + #01 **在任何情况下**你的回答都只返回修复后的 SQL 语句,不包含任何注释、解释、推理过程或其他无关内容。返回的内容应仅限于修复后的 SQL 语句。 + + #02 **优先解决 SQL 中可能存在的特殊字符问题**,如果报错提示与特殊字符相关(例如:引号、反斜杠、注释符等),请尽可能清理或转义这些字符,确保 SQL 语句能够正确执行。 + + #03 如果执行失败是由于 SQL 中某些字段名导致的(例如字段名包含保留字、大小写不一致等),请尝试使用双引号包裹字段名,或者使用更合适的字段名替换原字段。 + + #04 如果报错与查询字段的匹配条件相关(例如:`=` 运算符导致检索结果为空),请优先尝试将条件中的 `=` 替换为更宽松的 `ilike`,并添加适当的通配符(例如:`'%value%'`),以确保 SQL 执行返回结果。 + + #05 如果 SQL 执行结果为空,请根据问题中的关键字或上下文,将 `WHERE` 子句的过滤条件调整为问题相关的字段,或者使用关键字的子集进行查询,以确保 SQL 语句能够返回有效结果。 + + #06 **确保修复后的 SQL 语句符合 {database_type} 数据库的语法规范**,避免其他潜在的语法问题。 + + 以下是表结构以及表注释: + + {table_info} + + 以下是执行失败的 SQL 语句: + + {error_sql} + + 以下是执行失败的报错信息: + + {error_msg} + + 以下是问题描述: + + {question} + + 请基于上述信息,修复 SQL 语句,使其能够成功执行。 + """ +) + + +RISK_EVALUATE_SQL = dedent(r""" + 你是一个SQL执行风险评估器。 + + 你的任务是根据当前给出的生成或修复的SQL语句、数据库类型、数据库配置和执行环境,在不直接访问或执行SQL语句的情况下,判断执行SQL时的风险并输出提示。 + + 严格遵守以下要求: + #00 你不需要执行任何实际的指令和访问或操作任何数据库,只需要对指令运行的风险进行预测和评估。 + + #01 **在任何情况下**你的回答中都只有 json 形式的风险等级评估结果,不要包含任何**评估理由、推理过程或其他无关的内容**。 + + #02 JSON 内容**必须**包含两个字段: + - "risk":取值为 "low"、"medium" 或 "high" + - "message":风险提示信息 + + #03 对于 MongoDB 数据库,其不是标准 SQL 数据库,不要输出对 SQL 的解释或说明,同样仅分析指令运行风险。 + + #04 你的工作是仅分析 SQL 语句的风险等级,不涉及任何具体数据库访问、执行操作以及获取结果。 + + #05 数据库类型: {database_type} + + #06 语句执行的目标是:{goal} + + #07 需要执行或修复的SQL语句是:{sql} + + #08 目标的表信息是:{table_info} + + #09 如果生成SQL,可能涉及数据库中的敏感表/数据 + + #10 如果是修复SQL,错误SQL语句是:{error_sql},错误信息是:{error_msg} + + #11 结果格式如下 + {{ + "risk": "low/medium/high", + "message": "提示信息" + }} + + """ +) diff --git a/chat2db/apps/routers/sql.py b/chat2db/apps/routers/sql.py new file mode 100644 index 0000000000000000000000000000000000000000..9a6df2d5b7611db9efb705c6926faa38fa66877f --- /dev/null +++ b/chat2db/apps/routers/sql.py @@ -0,0 +1,160 @@ +import logging +from fastapi import APIRouter, status +import sys + +from chat2db.apps.schemas.enum_var import RiskLevel, DatabaseType +from chat2db.apps.schemas.request import SqlGenerateRequest, SqlExecuteRequest, SqlRepairRequest +from chat2db.apps.schemas.response import ResponseData, SqlGenerateRsp, SqlExecuteRsp, SqlRepairRsp +from chat2db.apps.services import database_service +from chat2db.apps.services.sql_service import SqlService + +router = APIRouter(prefix="/sql") + + +@router.post("/generate", response_model=ResponseData) +async def generate_sql(request: SqlGenerateRequest): + try: + _, table_info = await SqlService.get_connection_and_table_info( + database_type=request.type, + host=request.host, + port=request.port, + username=request.username, + password=request.password, + database=request.database, + table_list=request.table_list, + ) + + sql = await SqlService.generator( + database_type=request.type, + goal=request.goal, + table_info=table_info, + ) + + risk = await SqlService.risk_analysis( + database_type=request.type, goal=request.goal, sql=sql, table_info=table_info + ) + + 
except Exception as e: + logging.error(f"[SQL 生成失败]") + return ResponseData(code=status.HTTP_400_BAD_REQUEST, message="SQL 生成失败", result={}) + + return ResponseData( + code=status.HTTP_200_OK, + message="success", + result=SqlGenerateRsp( + risk=risk, + sql=sql, + ), + ) + + +@router.post("/repair", response_model=ResponseData) +async def repair_sql(request: SqlRepairRequest): + try: + _, table_info = await SqlService.get_connection_and_table_info( + database_type=request.type, + host=request.host, + port=request.port, + username=request.username, + password=request.password, + database=request.database, + table_list=request.table_list, + ) + + repair_sql = await SqlService.repairer( + database_type=request.type, + goal=request.goal, + table_info=table_info, + error_sql=request.error_sql, + error_msg=request.error_msg, + ) + + risk = await SqlService.risk_analysis( + database_type=request.type, + goal=request.goal, + sql=repair_sql, + table_info=table_info, + error_sql=request.error_sql, + error_msg=request.error_msg, + ) + + except Exception as e: + logging.error(f"[SQL 修复失败]") + return ResponseData( + code=status.HTTP_400_BAD_REQUEST, message="SQL 修复失败", result={"Error": str(e)} + ) + + return ResponseData( + code=status.HTTP_200_OK, + message="success", + result=SqlRepairRsp( + risk=risk, + sql=repair_sql, + ), + ) + + +@router.post("/execute", response_model=ResponseData) +async def execute_sql(request: SqlExecuteRequest): + try: + connection = await database_service.connect_database( + database_type=request.type, + host=request.host, + port=request.port, + username=request.username, + password=request.password, + database=request.database, + ) + execute_result = await SqlService.executer( + database_type=request.type, + sql=request.sql, + connection=connection, + ) + + except Exception as e: + logging.error(f"[SQL 执行失败]") + return ResponseData( + code=status.HTTP_400_BAD_REQUEST, message="SQL 执行失败", result={"Error": str(e)} + ) + + return ResponseData( + code=status.HTTP_200_OK, + message="success", + result=SqlExecuteRsp( + execute_result=execute_result, + ), + ) + + +@router.post("/handler", response_model=ResponseData) +async def sql_handler(request: SqlGenerateRequest): + try: + connection, table_info = await SqlService.get_connection_and_table_info( + database_type=request.type, + host=request.host, + port=request.port, + username=request.username, + password=request.password, + database=request.database, + table_list=request.table_list, + ) + + execute_result, sql, risk = await SqlService.sql_handler( + database_type=request.type, + goal=request.goal, + table_info=table_info, + connection=connection, + ) + except Exception as e: + logging.error(f"[查询失败]") + return ResponseData(code=status.HTTP_400_BAD_REQUEST, message="查询失败", result={"Error": str(e)}) + + return ResponseData( + code=status.HTTP_200_OK, + message="success", + result=SqlExecuteRsp( + sql=sql, + execute_result=execute_result, + risk=risk, + ), + ) diff --git a/chat2db/apps/schemas/enum_var.py b/chat2db/apps/schemas/enum_var.py new file mode 100644 index 0000000000000000000000000000000000000000..ff5edb76f7b7506d4ab058ad3044efb0d9ecd068 --- /dev/null +++ b/chat2db/apps/schemas/enum_var.py @@ -0,0 +1,18 @@ +from enum import Enum + + +class RiskLevel(str, Enum): + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + + +class DatabaseType(str, Enum): + MYSQL = "mysql" + POSTGRES = "postgres" + OPENGAUSS = "opengauss" + MONGODB = "mongodb" + +if __name__ == "__main__": + print(DatabaseType.MYSQL) + 
print(DatabaseType.MYSQL.value)
\ No newline at end of file
diff --git a/chat2db/apps/schemas/request.py b/chat2db/apps/schemas/request.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc01bdbe667608775d5606b3cc55f234578f5456
--- /dev/null
+++ b/chat2db/apps/schemas/request.py
@@ -0,0 +1,50 @@
+import uuid
+from pydantic import BaseModel, Field
+from typing import Optional
+
+from chat2db.apps.schemas.enum_var import DatabaseType
+
+class SqlGenerateRequest(BaseModel):
+    """
+    生成SQL请求
+    """
+    type: DatabaseType = Field(..., description="数据库类型")
+    host: str = Field(..., description="数据库地址")
+    port: int = Field(..., description="数据库端口")
+    username: str = Field(..., description="数据库用户名")
+    password: str = Field(..., description="数据库密码")
+    database: str = Field(..., description="数据库名称")
+    goal: str = Field(..., description="生成目标")
+
+    table_list: Optional[list[str]] = Field(None, description="表名列表")
+
+class SqlRepairRequest(BaseModel):
+    """
+    修复SQL请求
+    """
+    type: DatabaseType = Field(..., description="数据库类型")
+
+    host: str = Field(..., description="数据库地址")
+    port: int = Field(..., description="数据库端口")
+    username: str = Field(..., description="数据库用户名")
+    password: str = Field(..., description="数据库密码")
+    database: str = Field(..., description="数据库名称")
+    goal: str = Field(..., description="生成目标")
+
+    error_sql: str = Field(..., description="错误 SQL 语句")
+    error_msg: str = Field(..., description="错误信息")
+    table_list: Optional[list[str]] = Field(None, description="表名列表")
+
+
+class SqlExecuteRequest(BaseModel):
+    """
+    执行SQL请求
+    """
+    type: DatabaseType = Field(..., description="数据库类型")
+
+    host: str = Field(..., description="数据库地址")
+    port: int = Field(..., description="数据库端口")
+    username: str = Field(..., description="数据库用户名")
+    password: str = Field(..., description="数据库密码")
+    database: str = Field(..., description="数据库名称")
+    sql: str = Field(..., description="执行SQL")
diff --git a/chat2db/apps/schemas/response.py b/chat2db/apps/schemas/response.py
new file mode 100644
index 0000000000000000000000000000000000000000..af4b2a568b7a9de1117878396fb553e242f15d72
--- /dev/null
+++ b/chat2db/apps/schemas/response.py
@@ -0,0 +1,38 @@
+from pydantic import BaseModel, Field
+from typing import Any
+
+from chat2db.apps.schemas.enum_var import RiskLevel
+
+class ResponseData(BaseModel):
+    code: int
+    message: str
+    result: Any
+
+class RiskInfo(BaseModel):
+    risk: RiskLevel = Field(..., description="风险等级")
+    message: str = Field(..., description="风险提示信息")
+
+class SqlGenerateRsp(BaseModel):
+    """
+    SQL生成响应
+    """
+    sql: str | dict = Field(..., description="生成的SQL")
+    risk: RiskInfo = Field(..., description="SQL 风险等级")
+
+
+class SqlRepairRsp(BaseModel):
+    """
+    修复SQL响应
+    """
+    sql: str | dict = Field(..., description="修复的SQL")
+    risk: RiskInfo = Field(..., description="SQL 风险等级")
+
+
+class SqlExecuteRsp(BaseModel):
+    """
+    执行SQL响应
+    """
+    execute_result: list[dict[str, Any]] = Field(..., description="执行结果")
+    sql: str | dict | None = Field(None, description="执行的SQL")
+    risk: RiskInfo | None = Field(None, description="SQL 风险等级")
+
diff --git a/chat2db/apps/services/database_service.py b/chat2db/apps/services/database_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ca956a7d2c2378ae7038d64bf6b6d2cc8943647
--- /dev/null
+++ b/chat2db/apps/services/database_service.py
@@ -0,0 +1,125 @@
+from typing import Any, Type
+from chat2db.apps.schemas.enum_var import DatabaseType
+from chat2db.apps.base import MySQL, MongoDB, OpenGauss, Postgres, MetaDatabase
+
+class DatabaseService:
+
+    DatabaseMap: 
dict[DatabaseType, Type[MetaDatabase]] = {
+        DatabaseType.MYSQL: MySQL,
+        DatabaseType.MONGODB: MongoDB,
+        DatabaseType.OPENGAUSS: OpenGauss,
+        DatabaseType.POSTGRES: Postgres,
+    }
+
+    @staticmethod
+    async def get_database_url(
+        database_type: DatabaseType, host: str, port: int, username: str, password: str, database: str
+    ):
+        """
+        根据数据库类型和连接信息生成数据库 URL。
+
+        :return: 数据库连接 URL 字符串
+        """
+        db_class = DatabaseService.DatabaseMap[database_type]
+        return await db_class.get_database_url(host, port, username, password, database)
+
+    @staticmethod
+    async def connect_database(
+        database_type: DatabaseType, host: str, port: int, username: str, password: str, database: str
+    ):
+        """
+        根据数据库类型和连接信息建立数据库连接。
+
+        :return: 数据库连接对象
+        """
+        db_class = DatabaseService.DatabaseMap[database_type]
+        return await db_class.connect(host, port, username, password, database)
+
+    @staticmethod
+    async def list_tables(database_type: DatabaseType, connection: Any) -> list[str]:
+        """
+        获取指定数据库中所有表名。
+
+        :param database_type: 数据库类型枚举
+        :param connection: 数据库连接对象
+        :return: 表名列表
+        """
+        db_module = DatabaseService.DatabaseMap[database_type]
+        return await db_module.list_tables(connection)
+
+    @staticmethod
+    async def get_table_ddl(database_type: DatabaseType, table_name: str, connection: Any) -> str:
+        """
+        获取指定表的建表语句 DDL。
+
+        :param database_type: 数据库类型枚举
+        :param table_name: 表名
+        :param connection: 数据库连接对象
+
+        :return: 表的 DDL 字符串
+        """
+        db_module = DatabaseService.DatabaseMap[database_type]
+        return await db_module.get_table_ddl(table_name, connection)
+
+    @staticmethod
+    async def sample_table_rows(
+        database_type: DatabaseType, table_name: str, num_rows: int, connection: Any
+    ) -> list[dict]:
+        """
+        随机获取指定表的 num_rows 条示例数据。
+
+        :param database_type: 数据库类型枚举
+        :param table_name: 表名
+        :param num_rows: 返回的行数
+        :param connection: 数据库连接对象
+
+        :return: 示例行列表,每行为字典
+        """
+        db_module = DatabaseService.DatabaseMap[database_type]
+        return await db_module.sample_table_rows(table_name, num_rows, connection)
+
+    @staticmethod
+    async def execute_sql(database_type: DatabaseType, sql: str | dict, connection: Any) -> list[dict]:
+        """
+        执行 SQL 语句或 MongoDB 指令。
+
+        :param database_type: 数据库类型枚举
+        :param sql: SQL 语句字符串(非 MongoDB)或 MongoDB dict 指令
+        :param connection: 数据库连接对象
+
+        :return: 执行结果列表,每条记录为字典
+        """
+        db_module = DatabaseService.DatabaseMap[database_type]
+        return await db_module.execute_sql(sql, connection)
+
+
+if __name__ == "__main__":
+    import asyncio
+
+    async def main():
+        db_type = DatabaseType.MYSQL
+        conn = await DatabaseService.connect_database(
+            db_type,
+            host="localhost",
+            port=3306,
+            username="chat2db",
+            password="123456",
+            database="chat2db",
+        )
+        print("\n[Connection]\n:", conn)
+
+        tables = await DatabaseService.list_tables(db_type, conn)
+        print("\n[Tables]:\n", tables)
+
+        ddl = await DatabaseService.get_table_ddl(db_type, tables[0], conn)
+        print("\n[DDL]\n:", ddl)
+
+        sql = "SELECT DISTINCT `TABLE_NAME` FROM `information_schema`.`TABLES` WHERE `TABLE_SCHEMA` = DATABASE();"
+        execute_res = await DatabaseService.execute_sql(
+            db_type,
+            sql,
+            conn,
+        )
+        print("\n[Execute]:\n", execute_res)
+
+    asyncio.run(main())
diff --git a/chat2db/apps/services/sql_service.py b/chat2db/apps/services/sql_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdd30d0cf1db5d108434f8dee3fb49b07c5f413a
--- /dev/null
+++ b/chat2db/apps/services/sql_service.py
@@ -0,0 +1,240 @@
+from typing import Any
+import logging
+import re
+import json
+
+from chat2db.apps.llm import LLM, GENERATE_SQL_PROMPT, REPAIR_SQL_PROMPT, 
RISK_EVALUATE_SQL +from chat2db.apps.services.database_service import DatabaseService +from chat2db.apps.schemas.enum_var import DatabaseType + +from config.config import config + + +class SqlService: + + @staticmethod + async def get_connection_and_table_info( + database_type: DatabaseType, + host: str, + port: int, + username: str, + password: str, + database: str, + table_list: list[str] | None = None + ) -> str: + try: + conn = await DatabaseService.connect_database(database_type, host, port, username, password, database) + + if table_list is None or len(table_list) == 0: + table_list = await DatabaseService.list_tables(database_type, conn) + table_ddls = {} + for table in table_list: + ddl = await DatabaseService.get_table_ddl(database_type, table, conn) + table_ddls[table] = ddl + + table_info = "\n".join([f"表: {table}\nDDL:\n{ddl}" for table, ddl in table_ddls.items()]) + + return conn, table_info + except Exception as e: + logging.error(f"\n[获取数据库连接和表信息失败]\n\n{e}") + raise e + + @staticmethod + async def generator( + database_type: DatabaseType, + goal: str, + table_info: str, + llm: LLM | None = None, + ) -> str: + """ + 核心业务逻辑:生成 SQL + - 传入 table_info 作为表信息。 + - 或提供数据库连接信息 host, port, username, password, database + """ + logging.info(f"\n[生成目标]\n\n{goal}") + + if llm == None: + llm = LLM( + model_name=config["LLM_MODEL"], + openai_api_base=config["LLM_URL"], + openai_api_key=config["LLM_KEY"], + max_tokens=config["LLM_MAX_TOKENS"], + request_timeout=60, + temperature=0.5, + ) + + prompt = GENERATE_SQL_PROMPT.format( + database_type=database_type.value, table_info=table_info, question=goal + ) + + try: + result = await llm.nostream([], prompt, "请给出你生成的 SQL 语句") + sql = (await SqlService._extract_json(result))['command'] + logging.info(f"\n[生成SQL成功]\n\n{sql}") + return sql + + except Exception as e: + logging.error(f"\n[生成SQL失败]\n\n{e}") + raise e + + @staticmethod + async def repairer( + database_type: DatabaseType, + goal: str, + table_info: str, + error_sql: str, + error_msg: str, + llm: LLM | None = None, + ) -> str: + """ + 核心业务逻辑:生成修复 SQL + - 传入 table_info 作为表信息。 + - 或提供数据库连接信息 host, port, username, password, database + """ + if llm == None: + llm = LLM( + model_name=config["LLM_MODEL"], + openai_api_base=config["LLM_URL"], + openai_api_key=config["LLM_KEY"], + max_tokens=config["LLM_MAX_TOKENS"], + request_timeout=60, + temperature=0.5, + ) + + prompt = REPAIR_SQL_PROMPT.format( + database_type=database_type.value, + table_info=table_info, + error_sql=error_sql, + error_msg=error_msg, + question=goal, + ) + try: + repair_sql = await llm.nostream([], prompt, "请给出你修复的 SQL 语句") + logging.info(f"\n[修复SQL成功]\n\n{repair_sql}") + return repair_sql + + except Exception as e: + logging.error(f"\n[修复SQL失败]\n\n{e}") + raise e + + @staticmethod + async def executer( + database_type: DatabaseType, + sql: str, + connection=None, + ) -> list[dict]: + """ + 核心业务逻辑:执行 SQL + """ + try: + result = await DatabaseService.execute_sql(database_type, sql, connection) + logging.info(f"\n[执行SQL]\n\n{sql}\n\n[执行结果]\n\n{result}") + return result + except Exception as e: + logging.error(f"\n[执行失败]\n") + raise e + + @staticmethod + async def sql_handler( + database_type: DatabaseType, + goal: str, + table_info: str, + connection: Any, + max_retries: int = 3, + ) -> list[dict]: + """ + 核心业务逻辑:智能查询,支持语句异常自动修复 + """ + + llm = LLM( + model_name=config["LLM_MODEL"], + openai_api_base=config["LLM_URL"], + openai_api_key=config["LLM_KEY"], + max_tokens=config["LLM_MAX_TOKENS"], + request_timeout=60, + 
temperature=0.5, + ) + + # 生成 SQL 查询语句 + sql = await SqlService.generator(database_type, goal, table_info, llm) + + risk = await SqlService.risk_analysis(database_type, goal, sql, table_info, llm=llm) + + # 故意产生错误 + # sql = sql.replace("SELECT", "SELCT") + ### + + # 初次尝试执行 SQL 查询 + retries = 0 + while retries <= max_retries: + try: + execute_result = await SqlService.executer(database_type, sql, connection=connection) + return execute_result, sql, risk + except Exception as e: + if retries == max_retries: + logging.error(f"\n[重试次数已达到最大值]\n\nSQL 执行失败,最终错误:{e}") + raise e + logging.error(f"\n[执行失败 - 尝试修复 {retries + 1}/{max_retries}]\n") + repair_sql = await SqlService.repairer( + database_type=database_type, + goal=goal, + table_info=table_info, + error_sql=sql, + error_msg=str(e), + llm=llm, + ) + + sql = repair_sql + retries += 1 + + return [] + + @staticmethod + async def risk_analysis( + database_type: DatabaseType, + goal: str, + sql: str, + table_info: str, + error_sql: str | None = None, + error_msg: str | None = None, + llm: LLM | None = None, + ): + + if llm == None: + llm = LLM( + model_name=config["LLM_MODEL"], + openai_api_base=config["LLM_URL"], + openai_api_key=config["LLM_KEY"], + max_tokens=config["LLM_MAX_TOKENS"], + request_timeout=60, + temperature=0.5, + ) + + prompt = RISK_EVALUATE_SQL.format( + database_type=database_type.value, + table_info=table_info, + error_sql=error_sql, + error_msg=error_msg, + goal=goal, + sql=sql, + ) + + try: + result = await llm.nostream([], prompt, "请给出你评估的风险结果") + risk = await SqlService._extract_json(result) + logging.info(f"\n[风险分析成功]\n\n{risk}") + return risk + + except Exception as e: + logging.error(f"\n[风险分析失败]\n\n{type(e)}: {e}") + raise e + + @staticmethod + async def _extract_json(text: str): + try: + match = re.search(r"\{.*?\}\s*$", text, re.DOTALL) + if match: + return json.loads(match.group()) + except json.JSONDecodeError as e: + logging.error("\n[JSON解析失败]\n\n{e}") + raise e diff --git a/chat2db/common/.env.example b/chat2db/common/.env.example index 999e50afd1abbb79ed28eaf009874e397e337d94..a74baf1f13a38847f25ec2da45c881bbd201e64e 100644 --- a/chat2db/common/.env.example +++ b/chat2db/common/.env.example @@ -1,31 +1,4 @@ -# FastAPI -UVICORN_IP = 0.0.0.0 -UVICORN_PORT = 9015 -# SSL_CERTFILE = -# SSL_KEYFILE = -# SSL_ENABLE = - -# Postgres -DATABASE_TYPE = -DATABASE_HOST = -DATABASE_PORT = -DATABASE_USER = -DATABASE_PASSWORD = -DATABASE_DB = - -# QWEN LLM_KEY = LLM_URL = LLM_MAX_TOKENS = -LLM_MODEL = - -# Vectorize -EMBEDDING_TYPE = -EMBEDDING_API_KEY = -EMBEDDING_ENDPOINT = -EMBEDDING_MODEL_NAME = - -# security -HALF_KEY1 = R4UsZgLB -HALF_KEY2 = zRTvYV8N -HALF_KEY3 = 4eQ1wAGA \ No newline at end of file +LLM_MODEL = \ No newline at end of file diff --git a/chat2db/common/init_sql_example.py b/chat2db/common/init_sql_example.py deleted file mode 100644 index e3f09ed4bdd0afe89086c3121ad215e5e0fa7617..0000000000000000000000000000000000000000 --- a/chat2db/common/init_sql_example.py +++ /dev/null @@ -1,114 +0,0 @@ -import yaml -from fastapi import status -import requests -import uuid -import urllib.parse -from typing import Optional -from pydantic import BaseModel, Field -from chat2db.config.config import config -ip = config['UVICORN_IP'] -port = config['UVICORN_PORT'] -base_url = f'http://{ip}:{port}' -password = config['DATABASE_PASSWORD'] -encoded_password = urllib.parse.quote_plus(password) - -if config['DATABASE_TYPE'].lower() == 'opengauss': - database_url = 
f"opengauss+psycopg2://{config['DATABASE_USER']}:{encoded_password}@{config['DATABASE_HOST']}:{config['DATABASE_PORT']}/{config['DATABASE_DB']}" -else: - database_url = f"postgresql+psycopg2://{config['DATABASE_USER']}:{encoded_password}@{config['DATABASE_HOST']}:{config['DATABASE_PORT']}/{config['DATABASE_DB']}" - - -class DatabaseDelRequest(BaseModel): - database_id: Optional[str] = Field(default=None, description="数据库id") - database_url: Optional[str] = Field(default=None, description="数据库url") - - -def del_database_url(base_url, database_url): - server_url = f'{base_url}/database/del' - try: - request_data = DatabaseDelRequest(database_url=database_url).dict() - response = requests.post(server_url, json=request_data) - if response.json()['code'] != status.HTTP_200_OK: - print(response.json()['message']) - except Exception as e: - print(f"删除数据库配置失败: {e}") - exit(0) - return None - - -class DatabaseAddRequest(BaseModel): - database_url: str - - -def add_database_url(base_url, database_url): - server_url = f'{base_url}/database/add' - try: - request_data = DatabaseAddRequest(database_url=database_url).dict() - - response = requests.post(server_url, json=request_data) - response.raise_for_status() - if response.json()['code'] != status.HTTP_200_OK: - raise Exception(response.json()['message']) - except Exception as e: - print(f"增加数据库配置失败: {e}") - exit(0) - return response.json()['result']['database_id'] - - -class TableAddRequest(BaseModel): - database_id: str - table_name: str - - -def add_table(base_url, database_id, table_name): - server_url = f'{base_url}/table/add' - try: - request_data = TableAddRequest(database_id=database_id, table_name=table_name).dict() - response = requests.post(server_url, json=request_data) - response.raise_for_status() - if response.json()['code'] != status.HTTP_200_OK: - raise Exception(response.json()['message']) - except Exception as e: - print(f"增加表配置失败: {e}") - return - return response.json()['result']['table_id'] - - -class SqlExampleAddRequest(BaseModel): - table_id: str - question: str - sql: str - - -def add_sql_example(base_url, table_id, question, sql): - server_url = f'{base_url}/sql/example/add' - try: - request_data = SqlExampleAddRequest(table_id=table_id, question=question, sql=sql).dict() - response = requests.post(server_url, json=request_data) - if response.json()['code'] != status.HTTP_200_OK: - raise Exception(response.json()['message']) - except Exception as e: - print(f"增加sql案例失败: {e}") - return - return response.json()['result']['sql_example_id'] - - -database_id = del_database_url(base_url, database_url) -database_id = add_database_url(base_url, database_url) -with open('./chat2db/common/table_name.yaml') as f: - table_name_list = yaml.load(f, Loader=yaml.SafeLoader) -table_name_id = {} -for table_name in table_name_list: - table_id = add_table(base_url, database_id, table_name) - if table_id: - table_name_id[table_name] = table_id -with open('./chat2db/common/table_name_sql_exmple.yaml') as f: - table_name_sql_example_list = yaml.load(f, Loader=yaml.SafeLoader) -for table_name_sql_example in table_name_sql_example_list: - table_name = table_name_sql_example['table_name'] - if table_name not in table_name_id: - continue - table_id = table_name_id[table_name] - sql_example_list = table_name_sql_example['sql_example_list'] - for sql_example in sql_example_list: - add_sql_example(base_url, table_id, sql_example['question'], sql_example['sql']) diff --git a/chat2db/common/table_name.yaml b/chat2db/common/table_name.yaml deleted file mode 
100644 index 553cf1b2a4a780d1c731eb87e285e3fd75b5fc04..0000000000000000000000000000000000000000 --- a/chat2db/common/table_name.yaml +++ /dev/null @@ -1,10 +0,0 @@ -- oe_community_openeuler_version -- oe_community_organization_structure -- oe_compatibility_card -- oe_compatibility_commercial_software -- oe_compatibility_cve_database -- oe_compatibility_oepkgs -- oe_compatibility_osv -- oe_compatibility_overall_unit -- oe_compatibility_security_notice -- oe_compatibility_solution diff --git a/chat2db/common/table_name_sql_exmple.yaml b/chat2db/common/table_name_sql_exmple.yaml deleted file mode 100644 index 8e87a1100ebebd05e579c244395ec6e97698f094..0000000000000000000000000000000000000000 --- a/chat2db/common/table_name_sql_exmple.yaml +++ /dev/null @@ -1,490 +0,0 @@ -- keyword_list: - - test_organization - - product_name - - company_name - sql_example_list: - - question: openEuler支持的哪些商业软件在江苏鲲鹏&欧拉生态创新中心测试通过 - sql: SELECT product_name, product_version, openeuler_version FROM public.oe_compatibility_commercial_software - WHERE test_organization ILIKE '%江苏鲲鹏&欧拉生态创新中心%'; - - question: 哪个版本的openEuler支持的商业软件最多 - sql: SELECT openeuler_version, COUNT(*) AS software_count FROM public.oe_compatibility_commercial_software GROUP - BY openeuler_version ORDER BY software_count DESC LIMIT 1; - - question: openEuler支持测试商业软件的机构有哪些? - sql: SELECT DISTINCT test_organization FROM public.oe_compatibility_commercial_software; - - question: openEuler支持的商业软件有哪些类别 - sql: SELECT DISTINCT "type" FROM public.oe_compatibility_commercial_software; - - question: openEuler有哪些虚拟化类别的商业软件 - sql: SELECT product_name FROM public.oe_compatibility_commercial_software WHERE - "type" ILIKE '%虚拟化%'; - - question: openEuler支持哪些ISV商业软件呢,请列出10个 - sql: SELECT product_name FROM public.oe_compatibility_commercial_software; - - question: openEuler支持的适配Kunpeng 920的互联网商业软件有哪些? - sql: SELECT product_name, openeuler_version,platform_type_and_server_model FROM - public.oe_compatibility_commercial_software WHERE platform_type_and_server_model - ILIKE '%Kunpeng 920%' AND "type" ILIKE '%互联网%' limit 30; - - question: openEuler-22.03版本支持哪些商业软件? - sql: SELECT product_name, openeuler_version FROM oe_compatibility_commercial_software - WHERE openeuler_version ILIKE '%22.03%'; - - question: openEuler支持的数字政府类型的商业软件有哪些 - sql: SELECT product_name, product_version FROM oe_compatibility_commercial_software - WHERE type ILIKE '%数字政府%'; - - question: 有哪些商业软件支持超过一种服务器平台 - sql: SELECT product_name FROM public.oe_compatibility_commercial_software WHERE - platform_type_and_server_model ILIKE '%Intel%' AND platform_type_and_server_model - ILIKE '%Kunpeng%'; - - question: 每个openEuler版本有多少种类型的商业软件支持 - sql: SELECT openeuler_version, COUNT(DISTINCT type) AS type_count FROM public.oe_compatibility_commercial_software GROUP - BY openeuler_version; - - question: openEuler支持的哪些商业ISV在江苏鲲鹏&欧拉生态创新中心测试通过 - sql: SELECT product_name, product_version, openeuler_version FROM public.oe_compatibility_commercial_software - WHERE test_organization ILIKE '%江苏鲲鹏&欧拉生态创新中心%'; - - question: 哪个版本的openEuler支持的商业ISV最多 - sql: SELECT openeuler_version, COUNT(*) AS software_count FROM public.oe_compatibility_commercial_software GROUP - BY openeuler_version ORDER BY software_count DESC LIMIT 1; - - question: openEuler支持测试商业ISV的机构有哪些? 
- sql: SELECT DISTINCT test_organization FROM public.oe_compatibility_commercial_software; - - question: openEuler支持的商业ISV有哪些类别 - sql: SELECT DISTINCT "type" FROM public.oe_compatibility_commercial_software; - - question: openEuler有哪些虚拟化类别的商业ISV - sql: SELECT product_name FROM public.oe_compatibility_commercial_software WHERE - "type" ILIKE '%虚拟化%'; - - question: openEuler支持哪些ISV商业ISV呢,请列出10个 - sql: SELECT product_name FROM public.oe_compatibility_commercial_software; - - question: openEuler支持的适配Kunpeng 920的互联网商业ISV有哪些? - sql: SELECT product_name, openeuler_version,platform_type_and_server_model FROM - public.oe_compatibility_commercial_software WHERE platform_type_and_server_model - ILIKE '%Kunpeng 920%' AND "type" ILIKE '%互联网%' limit 30; - - question: openEuler-22.03版本支持哪些商业ISV? - sql: SELECT product_name, openeuler_version FROM oe_compatibility_commercial_software - WHERE openeuler_version ILIKE '%22.03%'; - - question: openEuler支持的数字政府类型的商业ISV有哪些 - sql: SELECT product_name, product_version FROM oe_compatibility_commercial_software - WHERE type ILIKE '%数字政府%'; - - question: 有哪些商业ISV支持超过一种服务器平台 - sql: SELECT product_name FROM public.oe_compatibility_commercial_software WHERE - platform_type_and_server_model ILIKE '%Intel%' AND platform_type_and_server_model - ILIKE '%Kunpeng%'; - - question: 每个openEuler版本有多少种类型的商业ISV支持 - sql: SELECT openeuler_version, COUNT(DISTINCT type) AS type_count FROM public.oe_compatibility_commercial_software GROUP - BY openeuler_version; - - question: 卓智校园网接入门户系统基于openeuelr的什么版本? - sql: select * from oe_compatibility_commercial_software where product_name ilike - '%卓智校园网接入门户系统%'; - table_name: oe_compatibility_commercial_software -- keyword_list: - - softwareName - sql_example_list: - - question: openEuler-20.03-LTS-SP1支持哪些开源软件? - sql: SELECT DISTINCT openeuler_version,"softwareName" FROM public.oe_compatibility_open_source_software WHERE - openeuler_version ILIKE '%20.03-LTS-SP1%'; - - question: openEuler的aarch64下支持开源软件 - sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software WHERE - "arch" ILIKE '%aarch64%'; - - question: openEuler支持开源软件使用了GPLv2+许可证 - sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software WHERE - "license" ILIKE '%GPLv2+%'; - - question: tcplay支持的架构是什么 - sql: SELECT "arch" FROM public.oe_compatibility_open_source_software WHERE "softwareName" - ILIKE '%tcplay%'; - - question: openEuler支持哪些开源软件,请列出10个 - sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software LIMIT - 10; - - question: openEuler支持开源软件支持哪些结构 - sql: SELECT "arch" FROM public.oe_compatibility_open_source_software group by - "arch"; - - question: openEuler支持多少个开源软件? - sql: select tmp_table.openeuler_version,count(*) as open_source_software_cnt from - (select DISTINCT openeuler_version,"softwareName" from oe_compatibility_open_source_software) - as tmp_table group by tmp_table.openeuler_version; - - question: openEuler-20.03-LTS-SP1支持哪些开源ISV? 
- sql: SELECT DISTINCT openeuler_version,"softwareName" FROM public.oe_compatibility_open_source_software WHERE - openeuler_version ILIKE '%20.03-LTS-SP1%'; - - question: openEuler的aarch64下支持开源ISV - sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software WHERE - "arch" ILIKE '%aarch64%'; - - question: openEuler支持开源ISV使用了GPLv2+许可证 - sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software WHERE - "license" ILIKE '%GPLv2+%'; - - question: tcplay支持的架构是什么 - sql: SELECT "arch" FROM public.oe_compatibility_open_source_software WHERE "softwareName" - ILIKE '%tcplay%'; - - question: openEuler支持哪些开源ISV,请列出10个 - sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software LIMIT - 10; - - question: openEuler支持开源ISV支持哪些结构 - sql: SELECT "arch" FROM public.oe_compatibility_open_source_software group by - "arch"; - - question: openEuler-20.03-LTS-SP1支持多少个开源ISV? - sql: select tmp_table.openeuler_version,count(*) as open_source_software_cnt from - (select DISTINCT openeuler_version,"softwareName" from oe_compatibility_open_source_software - where openeuler_version ilike 'openEuler-20.03-LTS-SP1') as tmp_table group - by tmp_table.openeuler_version; - - question: openEuler支持多少个开源ISV? - sql: select tmp_table.openeuler_version,count(*) as open_source_software_cnt from - (select DISTINCT openeuler_version,"softwareName" from oe_compatibility_open_source_software) - as tmp_table group by tmp_table.openeuler_version; - table_name: oe_compatibility_open_source_software -- keyword_list: [] - sql_example_list: - - question: 在openEuler技术委员会担任委员的人有哪些 - sql: SELECT name FROM oe_community_organization_structure WHERE committee_name - ILIKE '%技术委员会%' AND role = '委员'; - - question: openEuler的委员会中哪些人是教授 - sql: SELECT name FROM oe_community_organization_structure WHERE personal_message - ILIKE '%教授%'; - - question: openEuler各委员会中担任主席有多少个? - sql: SELECT committee_name, COUNT(*) FROM oe_community_organization_structure - WHERE role = '主席' GROUP BY committee_name; - - question: openEuler 用户委员会中有多少位成员 - sql: SELECT count(*) FROM oe_community_organization_structure WHERE committee_name - ILIKE '%用户委员会%'; - - question: openEuler 技术委员会有多少位成员 - sql: SELECT count(*) FROM oe_community_organization_structure WHERE committee_name - ILIKE '%技术委员会%'; - - question: openEuler委员会的委员常务委员会委员有哪些人 - sql: SELECT name FROM oe_community_organization_structure WHERE committee_name - ILIKE '%委员会%' AND role ILIKE '%常务委员会委员%'; - - question: openEuler委员会有哪些人属于华为技术有限公司? - sql: SELECT DISTINCT name FROM oe_community_organization_structure WHERE personal_message - ILIKE '%华为技术有限公司%'; - - question: openEuler每个委员会有多少人? - sql: SELECT committee_name, COUNT(*) FROM oe_community_organization_structure - GROUP BY committee_name; - - question: openEuler的执行总监是谁 - sql: SELECT name FROM oe_community_organization_structure WHERE role = '执行总监'; - - question: openEuler委员会有哪些组织? - sql: SELECT DISTINCT committee_name from oe_community_organization_structure; - - question: openEuler技术委员会的主席是谁? - sql: SELECT committee_name,name FROM oe_community_organization_structure WHERE - role = '主席' and committee_name ilike '%技术委员会%'; - - question: openEuler品牌委员会的主席是谁? - sql: SELECT committee_name,name FROM oe_community_organization_structure WHERE - role = '主席' and committee_name ilike '%品牌委员会%'; - - question: openEuler委员会的主席是谁? - sql: SELECT committee_name,name FROM oe_community_organization_structure WHERE - role = '主席' and committee_name ilike '%openEuler 委员会%'; - - question: openEuler委员会的执行总监是谁? 
- sql: SELECT committee_name,name FROM oe_community_organization_structure WHERE - role = '执行总监' and committee_name ilike '%openEuler 委员会%'; - - question: openEuler委员会的执行秘书是谁? - sql: SELECT committee_name,name FROM oe_community_organization_structure WHERE - role = '执行秘书' and committee_name ilike '%openEuler 委员会%'; - table_name: oe_community_organization_structure -- keyword_list: - - cve_id - sql_example_list: - - question: 安全公告openEuler-SA-2024-2059的详细信息在哪里? - sql: select DISTINCT security_notice_no,details from oe_compatibility_security_notice - where security_notice_no='openEuler-SA-2024-2059'; - table_name: oe_compatibility_security_notice -- keyword_list: - - hardware_model - sql_example_list: - - question: openEuler-22.03 LTS支持哪些整机? - sql: SELECT main_board_model, cpu, ram FROM oe_compatibility_overall_unit WHERE - openeuler_version ILIKE '%openEuler-22.03-LTS%'; - - question: 查询所有支持`openEuler-22.09`,并且提供详细产品介绍链接的整机型号和它们的内存配置? - sql: SELECT hardware_model, ram FROM oe_compatibility_overall_unit WHERE openeuler_version - ILIKE '%openEuler-22.09%' AND product_information IS NOT NULL; - - question: 显示所有由新华三生产,支持`openEuler-20.03 LTS SP2`版本的整机,列出它们的型号和架构类型 - sql: SELECT hardware_model, architecture FROM oe_compatibility_overall_unit WHERE - hardware_factory = '新华三' AND openeuler_version ILIKE '%openEuler-20.03 LTS SP2%'; - - question: openEuler支持多少种整机? - sql: SELECT count(DISTINCT main_board_model) FROM oe_compatibility_overall_unit; - - question: openEuler每个版本支持多少种整机? - sql: select openeuler_version,count(*) from (SELECT DISTINCT openeuler_version,main_board_model - FROM oe_compatibility_overall_unit) as tmp_table group by openeuler_version; - - question: openEuler每个版本多少种架构的整机? - sql: select openeuler_version,architecture,count(*) from (SELECT DISTINCT openeuler_version,architecture,main_board_model - FROM oe_compatibility_overall_unit) as tmp_table group by openeuler_version,architecture; - table_name: oe_compatibility_overall_unit -- keyword_list: - - osv_name - - os_version - sql_example_list: - - question: 深圳开鸿数字产业发展有限公司基于openEuler的什么版本发行了什么商用版本? - sql: select os_version,openeuler_version,os_download_link from oe_compatibility_osv - where osv_name='深圳开鸿数字产业发展有限公司'; - - question: 统计各个openEuler版本下的商用操作系统数量 - sql: SELECT openeuler_version, COUNT(*) AS os_count FROM public.oe_compatibility_osv GROUP - BY openeuler_version; - - question: 哪个OS厂商基于openEuler发布的商用操作系统最多 - sql: SELECT osv_name, COUNT(*) AS os_count FROM public.oe_compatibility_osv GROUP - BY osv_name ORDER BY os_count DESC LIMIT 1; - - question: 不同OS厂商基于openEuler发布不同架构的商用操作系统数量是多少? 
- sql: SELECT arch, osv_name, COUNT(*) AS os_count FROM public.oe_compatibility_osv GROUP - BY arch, osv_name ORDER BY arch, os_count DESC; - - question: 深圳开鸿数字产业发展有限公司的商用操作系统是基于什么openEuler版本发布的 - sql: SELECT os_version, openeuler_version FROM public.oe_compatibility_osv WHERE - osv_name ILIKE '%深圳开鸿数字产业发展有限公司%'; - - question: openEuler有哪些OSV伙伴 - sql: SELECT DISTINCT osv_name FROM public.oe_compatibility_osv; - - question: 有哪些OSV友商的操作系统是x86_64架构的 - sql: SELECT osv_name, os_version FROM public.oe_compatibility_osv WHERE arch ILIKE - '%x86_64%'; - - question: 哪些OSV友商操作系统是嵌入式类型的 - sql: SELECT osv_name, os_version,openeuler_version FROM public.oe_compatibility_osv - WHERE type ILIKE '%嵌入式%'; - - question: 成都鼎桥的商用操作系统版本是基于openEuler 22.03的版本吗 - sql: SELECT osv_name, os_version,"openeuler_version" FROM public.oe_compatibility_osv WHERE - osv_name ILIKE '%成都鼎桥通信技术有限公司%' AND openeuler_version ILIKE '%22.03%'; - - question: 最近发布的基于openEuler 23.09的商用系统有哪些 - sql: SELECT osv_name, os_version,"openeuler_version" FROM public.oe_compatibility_osv WHERE - openeuler_version ILIKE '%23.09%' ORDER BY date DESC limit 10; - - question: 帮我查下成都智明达发布的所有嵌入式系统 - sql: SELECT osv_name, os_version,"openeuler_version" FROM public.oe_compatibility_osv WHERE - osv_name ILIKE '%成都智明达电子股份有限公司%' AND type = '嵌入式'; - - question: 基于openEuler发布的商用操作系统有哪些类型 - sql: SELECT DISTINCT type FROM public.oe_compatibility_osv; - - question: 江苏润和系统版本HopeOS-V22-x86_64-dvd.iso基于openEuler哪个版本 - sql: SELECT DISTINCT osv_name, os_version,"openeuler_version" FROM public.oe_compatibility_osv - WHERE "osv_name" ILIKE '%江苏润和%' AND os_version ILIKE '%HopeOS-V22-x86_64-dvd.iso%' - ; - - question: 浙江大华DH-IVSS-OSV-22.03-LTS-SP2-x86_64-dvd.iso系统版本基于openEuler哪个版本 - sql: SELECT DISTINCT osv_name, os_version,"openeuler_version" FROM public.oe_compatibility_osv - WHERE "osv_name" ILIKE '%浙江大华%' AND os_version ILIKE '%DH-IVSS-OSV-22.03-LTS-SP2-x86_64-dvd.iso%' - ; - table_name: oe_compatibility_osv -- keyword_list: - - board_model - - chip_model - - chip_vendor - - product - sql_example_list: - - question: openEuler 22.03支持哪些网络接口卡型号? - sql: SELECT board_model, chip_model,type FROM oe_compatibility_card WHERE type - ILIKE '%NIC%' AND openeuler_version ILIKE '%22.03%' limit 30; - - question: 请列出openEuler支持的所有Renesas公司的密码卡 - sql: SELECT * FROM oe_compatibility_card WHERE chip_vendor ILIKE '%Renesas%' AND - type ILIKE '%密码卡%' limit 30; - - question: openEuler各种架构支持的板卡数量是多少 - sql: SELECT architecture, COUNT(*) AS total_cards FROM oe_compatibility_card GROUP - BY architecture limit 30; - - question: 每个openEuler版本支持了多少种板卡 - sql: SELECT openeuler_version, COUNT(*) AS number_of_cards FROM oe_compatibility_card - GROUP BY openeuler_version limit 30; - - question: openEuler总共支持多少种不同的板卡型号 - sql: SELECT COUNT(DISTINCT board_model) AS board_model_cnt FROM oe_compatibility_card - limit 30; - - question: openEuler支持的GPU型号有哪些? 
- sql: SELECT chip_model, openeuler_version,type FROM public.oe_compatibility_card WHERE - type ILIKE '%GPU%' ORDER BY driver_date DESC limit 30; - - question: openEuler 20.03 LTS-SP4版本支持哪些类型的设备 - sql: SELECT DISTINCT openeuler_version,type FROM public.oe_compatibility_card WHERE - openeuler_version ILIKE '%20.03-LTS-SP4%' limit 30; - - question: openEuler支持的板卡驱动在2023年后发布 - sql: SELECT board_model, driver_date, driver_name FROM oe_compatibility_card WHERE - driver_date >= '2023-01-01' limit 30; - - question: 给些支持openEuler的aarch64架构下支持的的板卡的驱动下载链接 - sql: SELECT openeuler_version,board_model, download_link FROM oe_compatibility_card - WHERE architecture ILIKE '%aarch64%' AND download_link IS NOT NULL limit 30; - - question: openEuler-22.03-LTS-SP1支持的存储卡有哪些? - sql: SELECT openeuler_version,board_model, chip_model,type FROM oe_compatibility_card - WHERE type ILIKE '%SSD%' AND openeuler_version ILIKE '%openEuler-22.03-LTS-SP1%' - limit 30; - table_name: oe_compatibility_card -- keyword_list: - - cve_id - sql_example_list: - - question: CVE-2024-41053的详细信息在哪里可以看到? - sql: select DISTINCT cve_id,details from oe_compatibility_cve_database where cve_id='CVE-2024-41053'; - - question: CVE-2024-41053是个怎么样的漏洞? - sql: select DISTINCT cve_id,summary from oe_compatibility_cve_database where cve_id='CVE-2024-41053'; - - question: CVE-2024-41053影响了哪些包? - sql: select DISTINCT cve_id,package_name from oe_compatibility_cve_database where - cve_id='CVE-2024-41053'; - - question: CVE-2024-41053的cvss评分是多少? - sql: select DISTINCT cve_id,cvsss_core_nvd from oe_compatibility_cve_database - where cve_id='CVE-2024-41053'; - - question: CVE-2024-41053现在修复了么? - sql: select DISTINCT cve_id, status from oe_compatibility_cve_database where cve_id='CVE-2024-41053'; - - question: CVE-2024-41053影响了openEuler哪些版本? - sql: select DISTINCT cve_id, affected_product from oe_compatibility_cve_database - where cve_id='CVE-2024-41053'; - - question: CVE-2024-41053发布时间是? - sql: select DISTINCT cve_id, announcement_time from oe_compatibility_cve_database - where cve_id='CVE-2024-41053'; - - question: openEuler-20.03-LTS-SP4在2024年8月发布哪些漏洞? - sql: select DISTINCT affected_product,cve_id,announcement_time from oe_compatibility_cve_database - where cve_id='CVE-2024-41053' and affected_product='openEuler-20.03-LTS-SP4' - and EXTRACT(MONTH FROM announcement_time)=8; - - question: openEuler-20.03-LTS-SP4在2024年发布哪些漏洞? - sql: select DISTINCT affected_product,cve_id,announcement_time from oe_compatibility_cve_database - where cve_id='CVE-2024-41053' and affected_product='openEuler-20.03-LTS-SP4' - and EXTRACT(YEAR FROM announcement_time)=2024; - - question: CVE-2024-41053的威胁程度是怎样的? - sql: select DISTINCT affected_product,cve_id,cvsss_core_nvd,attack_complexity_nvd,attack_complexity_oe,attack_vector_nvd,attack_vector_oe - from oe_compatibility_cve_database where cve_id='CVE-2024-41053'; - table_name: oe_compatibility_cve_database -- keyword_list: - - name - sql_example_list: - - question: openEuler-20.03-LTS的非官方软件包有多少个? - sql: SELECT COUNT(*) FROM oe_compatibility_oepkgs WHERE repotype = 'openeuler_compatible' - AND openeuler_version ILIKE '%openEuler-20.03-LTS%'; - - question: openEuler支持的nginx版本有哪些? - sql: SELECT DISTINCT name,version, srcrpmpackurl FROM oe_compatibility_oepkgs - WHERE name ILIKE 'nginx'; - - question: openEuler的支持哪些架构的glibc? 
- sql: SELECT DISTINCT name,arch FROM oe_compatibility_oepkgs WHERE name ILIKE 'glibc'; - - question: openEuler-22.03-LTS带GPLv2许可的软件包有哪些 - sql: SELECT name,rpmlicense FROM oe_compatibility_oepkgs WHERE openeuler_version - ILIKE '%openEuler-22.03-LTS%' AND rpmlicense = 'GPLv2'; - - question: openEuler支持的python3这个软件包是用来干什么的? - sql: SELECT DISTINCT name,summary FROM oe_compatibility_oepkgs WHERE name ILIKE - 'python3'; - - question: 哪些版本的openEuler的zlib中有官方源的? - sql: SELECT DISTINCT openeuler_version,name,version FROM oe_compatibility_oepkgs - WHERE name ILIKE '%zlib%' AND repotype = 'openeuler_official'; - - question: 请以表格的形式提供openEuler-20.09的gcc软件包的下载链接 - sql: SELECT DISTINCT openeuler_version,name, rpmpackurl FROM oe_compatibility_oepkgs - WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'gcc'; - - question: 请以表格的形式提供openEuler-20.09的glibc软件包的下载链接 - sql: SELECT DISTINCT openeuler_version,name, rpmpackurl FROM oe_compatibility_oepkgs - WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'glibc'; - - question: 请以表格的形式提供openEuler-20.09的redis软件包的下载链接 - sql: SELECT DISTINCT openeuler_version,name, rpmpackurl FROM oe_compatibility_oepkgs - WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'redis'; - - question: openEuler-20.09的支持多少个软件包? - sql: select tmp_table.openeuler_version,count(*) as oepkgs_cnt from (select DISTINCT - openeuler_version,name from oe_compatibility_oepkgs WHERE openeuler_version - ILIKE '%openEuler-20.09') as tmp_table group by tmp_table.openeuler_version; - - question: openEuler支持多少个软件包? - sql: select tmp_table.openeuler_version,count(*) as oepkgs_cnt from (select DISTINCT - openeuler_version,name from oe_compatibility_oepkgs) as tmp_table group by tmp_table.openeuler_version; - - question: 请以表格的形式提供openEuler-20.09的gcc的版本 - sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs - WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'gcc'; - - question: 请以表格的形式提供openEuler-20.09的glibc的版本 - sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs - WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'glibc'; - - question: 请以表格的形式提供openEuler-20.09的redis的版本 - sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs - WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'redis'; - - question: openEuler-20.09支持哪些gcc的版本 - sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs - WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'gcc'; - - question: openEuler-20.09支持哪些glibc的版本 - sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs - WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'glibc'; - - question: openEuler-20.09支持哪些redis的版本 - sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs - WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'redis'; - - question: openEuler-20.09支持的gcc版本有哪些 - sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs - WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'gcc'; - - question: openEuler-20.09支持的glibc版本有哪些 - sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs - WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'glibc'; - - question: openEuler-20.09支持的redis版本有哪些 - sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs - WHERE
openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'redis'; - - question: openEuler-20.09支持gcc 9.3.1么? - sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs - WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'gcc' AND version - ilike '9.3.1'; - table_name: oe_compatibility_oepkgs -- keyword_list: [] - sql_example_list: - - question: openEuler社区创新版本有哪些 - sql: SELECT DISTINCT openeuler_version,version_type FROM oe_community_openeuler_version - where version_type ILIKE '%社区创新版本%'; - - question: openEuler有哪些版本 - sql: SELECT openeuler_version FROM public.oe_community_openeuler_version; - - question: 查询openeuler各版本对应的内核版本 - sql: SELECT DISTINCT openeuler_version, kernel_version FROM public.oe_community_openeuler_version; - - question: openEuler有多少个长期支持版本(LTS) - sql: SELECT COUNT(*) as publish_version_count FROM public.oe_community_openeuler_version - WHERE version_type ILIKE '%长期支持版本%'; - - question: 查询openEuler-20.03的所有SP版本 - sql: SELECT openeuler_version FROM public.oe_community_openeuler_version WHERE - openeuler_version ILIKE '%openEuler-20.03-LTS-SP%'; - - question: openEuler最新的社区创新版本内核是啥 - sql: SELECT kernel_version FROM public.oe_community_openeuler_version WHERE version_type - ILIKE '%社区创新版本%' ORDER BY publish_time DESC LIMIT 1; - - question: 最早的openEuler版本是什么时候发布的 - sql: SELECT openeuler_version,publish_time FROM public.oe_community_openeuler_version - ORDER BY publish_time ASC LIMIT 1; - - question: 最新的openEuler版本是哪个 - sql: SELECT openeuler_version,publish_time FROM public.oe_community_openeuler_version - ORDER BY publish_time DESC LIMIT 1; - - question: openEuler有哪些版本使用了Linux 5.10.0内核 - sql: SELECT openeuler_version,kernel_version FROM public.oe_community_openeuler_version - WHERE kernel_version ILIKE '5.10.0%'; - - question: 哪个openEuler版本是最近更新的长期支持版本 - sql: SELECT openeuler_version,publish_time FROM public.oe_community_openeuler_version - WHERE version_type ILIKE '%长期支持版本%' ORDER BY publish_time DESC LIMIT 1; - - question: openEuler每个年份发布了多少个版本 - sql: SELECT EXTRACT(YEAR FROM publish_time) AS year, COUNT(*) AS publish_version_count - FROM oe_community_openeuler_version group by EXTRACT(YEAR FROM publish_time); - - question: openEuler-20.03-LTS版本的linux内核是多少? - sql: SELECT openeuler_version,kernel_version FROM public.oe_community_openeuler_version - WHERE openeuler_version = 'openEuler-20.03-LTS'; - - question: openEuler-24.09版本的linux内核是多少?
- sql: SELECT openeuler_version,kernel_version FROM public.oe_community_openeuler_version - WHERE openeuler_version = 'openEuler-24.09'; - table_name: oe_community_openeuler_version -- keyword_list: - - product - sql_example_list: - - question: 哪些openEuler版本支持使用至强6338N的解决方案 - sql: SELECT DISTINCT openeuler_version FROM oe_compatibility_solution WHERE cpu - ILIKE '%6338N%'; - - question: 使用intel XXV710作为网卡的解决方案对应的是哪些服务器型号 - sql: SELECT DISTINCT server_model FROM oe_compatibility_solution WHERE network_card - ILIKE '%intel XXV710%'; - - question: 哪些解决方案的硬盘驱动为SATA-SSD Skhynix - sql: SELECT DISTINCT product FROM oe_compatibility_solution WHERE hard_disk_drive - ILIKE 'SATA-SSD Skhynix'; - - question: 查询所有使用6230R系列CPU且支持磁盘阵列支持PERC H740P Adapter的解决方案的产品名 - sql: SELECT DISTINCT product FROM oe_compatibility_solution WHERE cpu ILIKE '%6230R%' - AND raid ILIKE '%PERC H740P Adapter%'; - - question: R4900-G3有哪些驱动版本 - sql: SELECT DISTINCT driver FROM oe_compatibility_solution WHERE product ILIKE - '%R4900-G3%'; - - question: DL380 Gen10支持哪些架构 - sql: SELECT DISTINCT architecture FROM oe_compatibility_solution WHERE server_model - ILIKE '%DL380 Gen10%'; - - question: 列出所有使用Intel(R) Xeon(R)系列cpu且磁盘冗余阵列为LSI SAS3408的解决方案的服务器厂家 - sql: SELECT DISTINCT server_vendor FROM oe_compatibility_solution WHERE cpu ILIKE - '%Intel(R) Xeon(R)%' AND raid ILIKE '%LSI SAS3408%'; - - question: 哪些解决方案提供了针对SEAGATE ST4000NM0025硬盘驱动的支持 - sql: SELECT * FROM oe_compatibility_solution WHERE hard_disk_drive ILIKE '%SEAGATE - ST4000NM0025%'; - - question: 查询所有使用4316系列CPU的解决方案 - sql: SELECT * FROM oe_compatibility_solution WHERE cpu ILIKE '%4316%'; - - question: 支持openEuler-22.03-LTS-SP2版本的解决方案中,哪款服务器型号出现次数最多 - sql: SELECT server_model, COUNT(*) as count FROM oe_compatibility_solution WHERE - openeuler_version ILIKE '%openEuler-22.03-LTS-SP2%' GROUP BY server_model ORDER - BY count DESC LIMIT 1; - - question: HPE提供的解决方案的介绍链接是什么 - sql: SELECT DISTINCT introduce_link FROM oe_compatibility_solution WHERE server_vendor - ILIKE '%HPE%'; - - question: 列出所有使用intel XXV710网络卡接口的解决方案的CPU型号 - sql: SELECT DISTINCT cpu FROM oe_compatibility_solution WHERE network_card ILIKE - '%intel XXV710%'; - - question: 服务器型号为2288H V5的解决方案支持哪些不同的openEuler版本 - sql: SELECT DISTINCT openeuler_version FROM oe_compatibility_solution WHERE server_model - ILIKE '%NF5180M5%'; - - question: 使用6230R系列CPU的解决方案内存最小是多少GB - sql: SELECT MIN(ram) FROM oe_compatibility_solution WHERE cpu ILIKE '%6230R%'; - - question: 哪些解决方案的磁盘驱动为MegaRAID 9560-8i - sql: SELECT * FROM oe_compatibility_solution WHERE hard_disk_drive LIKE '%MegaRAID - 9560-8i%'; - - question: 列出所有使用6330N系列CPU且服务器厂家为Dell的解决方案的产品名 - sql: SELECT DISTINCT product FROM oe_compatibility_solution WHERE cpu ILIKE '%6330N%' - AND server_vendor ILIKE '%Dell%'; - - question: R4900-G3的驱动版本是多少 - sql: SELECT driver FROM oe_compatibility_solution WHERE product ILIKE '%R4900-G3%'; - - question: 哪些解决方案的服务器型号为2288H V7 - sql: SELECT * FROM oe_compatibility_solution WHERE server_model ILIKE '%2288H - V7%'; - - question: 使用Intel i350网卡且硬盘驱动为ST4000NM0025的解决方案的服务器厂家有哪些 - sql: SELECT DISTINCT server_vendor FROM oe_compatibility_solution WHERE network_card - ILIKE '%Intel i350%' AND hard_disk_drive ILIKE '%ST4000NM0025%'; - - question: 有多少种不同的驱动版本被用于支持openEuler-22.03-LTS-SP2版本的解决方案 - sql: SELECT COUNT(DISTINCT driver) FROM oe_compatibility_solution WHERE openeuler_version - ILIKE '%openEuler-22.03-LTS-SP2%'; - table_name: oe_compatibility_solution diff --git a/chat2db/config/config.py b/chat2db/config/config.py index 
2c2c1a56f977fecd46535d6a296a8920b857a052..2f664d31741763daf748df674da957807a607ec8 100644 --- a/chat2db/config/config.py +++ b/chat2db/config/config.py @@ -6,39 +6,13 @@ from pydantic import BaseModel, Field class ConfigModel(BaseModel): - # FastAPI - UVICORN_IP: str = Field(None, description="FastAPI 服务的IP地址") - UVICORN_PORT: int = Field(None, description="FastAPI 服务的端口号") - SSL_CERTFILE: str = Field(None, description="SSL证书文件的路径") - SSL_KEYFILE: str = Field(None, description="SSL密钥文件的路径") - SSL_ENABLE: str = Field(None, description="是否启用SSL连接") - - # Postgres - DATABASE_TYPE: str = Field(default="postgres", description="数据库类型") - DATABASE_HOST: str = Field(None, description="数据库地址") - DATABASE_PORT: int = Field(None, description="数据库端口") - DATABASE_USER: str = Field(None, description="数据库用户名") - DATABASE_PASSWORD: str = Field(None, description="数据库密码") - DATABASE_DB: str = Field(None, description="数据库名称") - - # QWEN + + # LLM LLM_KEY: str = Field(None, description="语言模型访问密钥") LLM_URL: str = Field(None, description="语言模型服务的基础URL") LLM_MAX_TOKENS: int = Field(None, description="单次请求中允许的最大Token数") LLM_MODEL: str = Field(None, description="使用的语言模型名称或版本") - # Vectorize - EMBEDDING_TYPE: str = Field("openai", description="embedding 服务的类型") - EMBEDDING_API_KEY: str = Field(None, description="embedding服务api key") - EMBEDDING_ENDPOINT: str = Field(None, description="embedding服务url地址") - EMBEDDING_MODEL_NAME: str = Field(None, description="embedding模型名称") - - # security - HALF_KEY1: str = Field(None, description='加密的密钥组件1') - HALF_KEY2: str = Field(None, description='加密的密钥组件2') - HALF_KEY3: str = Field(None, description='加密的密钥组件3') - - class Config: config: ConfigModel @@ -46,7 +20,7 @@ class Config: if os.getenv("CONFIG"): config_file = os.getenv("CONFIG") else: - config_file = "./chat2db/common/.env" + config_file = "chat2db/common/.env" self.config = ConfigModel(**(dotenv_values(config_file))) if os.getenv("PROD"): os.remove(config_file) diff --git a/chat2db/database/postgres.py b/chat2db/database/postgres.py deleted file mode 100644 index ea4470d49368fe289d8fbc1e07498191aa8ec6a2..0000000000000000000000000000000000000000 --- a/chat2db/database/postgres.py +++ /dev/null @@ -1,135 +0,0 @@ -import logging -from uuid import uuid4 -import urllib.parse -from pgvector.sqlalchemy import Vector -from sqlalchemy.orm import sessionmaker, declarative_base -from sqlalchemy import TIMESTAMP, UUID, Column, String, Boolean, ForeignKey, create_engine, func, Index -import sys -from chat2db.config.config import config - -logging.basicConfig(stream=sys.stdout, level=logging.INFO, - format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') -Base = declarative_base() - - -class DatabaseInfo(Base): - __tablename__ = 'database_info_table' - id = Column(UUID(), default=uuid4, primary_key=True) - encrypted_database_url = Column(String()) - encrypted_config = Column(String()) - hashmac = Column(String()) - created_at = Column(TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp()) - - -class TableInfo(Base): - __tablename__ = 'table_info_table' - id = Column(UUID(), default=uuid4, primary_key=True) - database_id = Column(UUID(), ForeignKey('database_info_table.id', ondelete='CASCADE')) - table_name = Column(String()) - table_note = Column(String()) - table_note_vector = Column(Vector(1024)) - enable = Column(Boolean, default=False) - created_at = Column(TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp()) - updated_at = Column( - 
TIMESTAMP(timezone=True), - server_default=func.current_timestamp(), - onupdate=func.current_timestamp()) - __table_args__ = ( - Index( - 'table_note_vector_index', - table_note_vector, - postgresql_using='hnsw', - postgresql_with={'m': 16, 'ef_construction': 200}, - postgresql_ops={'table_note_vector': 'vector_cosine_ops'} - ), - ) - - -class ColumnInfo(Base): - __tablename__ = 'column_info_table' - id = Column(UUID(), default=uuid4, primary_key=True) - table_id = Column(UUID(), ForeignKey('table_info_table.id', ondelete='CASCADE')) - column_name = Column(String) - column_type = Column(String) - column_note = Column(String) - enable = Column(Boolean, default=False) - - -class SqlExample(Base): - __tablename__ = 'sql_example_table' - id = Column(UUID(), default=uuid4, primary_key=True) - table_id = Column(UUID(), ForeignKey('table_info_table.id', ondelete='CASCADE')) - question = Column(String()) - sql = Column(String()) - question_vector = Column(Vector(1024)) - created_at = Column(TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp()) - updated_at = Column( - TIMESTAMP(timezone=True), - server_default=func.current_timestamp(), - onupdate=func.current_timestamp()) - __table_args__ = ( - Index( - 'question_vector_index', - question_vector, - postgresql_using='hnsw', - postgresql_with={'m': 16, 'ef_construction': 200}, - postgresql_ops={'question_vector': 'vector_cosine_ops'} - ), - ) - - -class PostgresDB: - _engine = None - - @classmethod - def get_mysql_engine(cls): - if not cls._engine: - password = config['DATABASE_PASSWORD'] - encoded_password = urllib.parse.quote_plus(password) - - if config['DATABASE_TYPE'].lower() == 'opengauss': - database_url = f"opengauss+psycopg2://{config['DATABASE_USER']}:{encoded_password}@{config['DATABASE_HOST']}:{config['DATABASE_PORT']}/{config['DATABASE_DB']}" - else: - database_url = f"postgresql+psycopg2://{config['DATABASE_USER']}:{encoded_password}@{config['DATABASE_HOST']}:{config['DATABASE_PORT']}/{config['DATABASE_DB']}" - cls.engine = create_engine( - database_url, - hide_parameters=True, - echo=False, - pool_recycle=300, - pool_pre_ping=True) - - Base.metadata.create_all(cls.engine) - if config['DATABASE_TYPE'].lower() == 'opengauss': - from sqlalchemy import event - from opengauss_sqlalchemy.register_async import register_vector - - @event.listens_for(cls.engine.sync_engine, "connect") - def connect(dbapi_connection, connection_record): - dbapi_connection.run_async(register_vector) - return cls._engine - - @classmethod - def get_session(cls): - connection = None - try: - connection = sessionmaker(bind=cls.engine)() - except Exception as e: - logging.error(f"Error creating a postgres sessiondue to error: {e}") - return None - return cls._ConnectionManager(connection) - - class _ConnectionManager: - def __init__(self, connection): - self.connection = connection - - def __enter__(self): - return self.connection - - def __exit__(self, exc_type, exc_val, exc_tb): - try: - self.connection.close() - except Exception as e: - logging.error(f"Postgres connection close failed due to error: {e}") - - -PostgresDB.get_mysql_engine() diff --git "a/chat2db/docs/chat2db\345\267\245\345\205\267\350\257\246\347\273\206\350\257\264\346\230\216.md" "b/chat2db/docs/chat2db\345\267\245\345\205\267\350\257\246\347\273\206\350\257\264\346\230\216.md" deleted file mode 100644 index e4d475885fd4242a9b69f7c631e7ca5a6ee5338d..0000000000000000000000000000000000000000 --- 
"a/chat2db/docs/chat2db\345\267\245\345\205\267\350\257\246\347\273\206\350\257\264\346\230\216.md" +++ /dev/null @@ -1,391 +0,0 @@ -# 1. 背景说明 -工具聚焦于利用大模型能力智能生成SQL语句,查询数据库数据,为最终的模型拟合提供能力增强。工具可增强RAG多路召回能力,增强RAG对本地用户的数据适应性,同时对于服务器、硬件型号等关键字场景,在不训练模型的情况下,RAG也具备一定检索能力 - -# 2. 工具设计框架 -## 2.1 目录结构 -``` -chat2db -|-- app # 应用主入口及相关功能模块 -|-- |-- app.py # 服务请求入口,处理用户请求并返回结果 -|-- |-- __init__.py # 初始化 -|-- | -|-- |-- base # 基础功能模块 -|-- |-- |-- ac_automation.py # AC 自动机 -|-- |-- |-- mysql.py # MySQL 数据库操作封装 -|-- |-- |-- postgres.py # PostgreSQL 数据库操作封装 -|-- |-- |-- vectorize.py # 数据向量化处理模块 -|-- | -|-- |-- router # 路由模块,负责分发请求到具体服务 -|-- |-- |-- database.py # 数据库相关路由逻辑 -|-- |-- |-- sql_example.py # SQL 示例管理路由 -|-- |-- |-- sql_generate.py # SQL 生成相关路由 -|-- |-- |-- table.py # 表信息管理路由 -|-- | -|-- |-- service # 核心服务模块 -|-- |-- |-- diff_database_service.py # 不同数据库类型的服务适配 -|-- |-- |-- keyword_service.py # 关键字检索服务 -|-- |-- |-- sql_generate_service.py # SQL 生成服务逻辑 -| -|-- common # 公共资源及配置 -|-- |-- .env # 环境变量配置文件 -|-- |-- init_sql_example.py # 初始化 SQL 示例数据脚本 -|-- |-- table_name_id.yaml # 表名与 ID 映射配置 -|-- |-- table_name_sql_example.yaml # 表名与 SQL 示例映射配置 -| -|-- config # 配置模块 -|-- |-- config.py # 工具全局配置文件 -| -|-- database # 数据库相关模块 -|-- |-- postgres.py # PostgreSQL 数据库连接及操作封装 -| -|-- llm # 大模型交互模块 -|-- |-- chat_with_model.py # 与大模型交互的核心逻辑 -| -|-- manager # 数据管理模块 -|-- |-- column_info_manager.py # 列信息管理逻辑 -|-- |-- database_info_manager.py # 数据库信息管理逻辑 -|-- |-- sql_example_manager.py # SQL 示例管理逻辑 -|-- |-- table_info_manager.py # 表信息管理逻辑 -| -|-- model # 数据模型模块 -|-- |-- request.py # 请求数据模型定义 -|-- |-- response.py # 响应数据模型定义 -| -|-- scripts # 脚本工具模块 -|-- |-- chat2db_config # 工具配置相关脚本 -|-- |-- |-- config.yaml # 工具配置文件模板 -|-- |-- output_example # 输出示例相关脚本 -|-- |-- |-- output_examples.txt # 输出示例文件 -|-- |-- run_chat2db.py # 启动工具进行交互的主脚本 -| -|-- security # 安全模块 -|-- |-- security.py # 安全相关逻辑(如权限校验、加密等) -| -|-- template # 模板及提示词相关模块 -|-- |-- change_txt_to_yaml.py # 将文本提示转换为 YAML 格式的脚本 -|-- |-- prompt.yaml # 提示词模板文件,用于生成 SQL 或问题 -``` -# 3. 
主要功能介绍 -## **3.1 智能生成 SQL 查询** -- **功能描述**: - - 工具的核心功能是利用大模型(如 LLM)智能生成符合用户需求的 SQL 查询语句。 - - 用户可以通过自然语言提问,工具会根据问题内容、表结构、示例数据等信息生成对应的 SQL 查询。 -- **实现模块**: - - **路由模块**:`router/sql_generate.py` 负责接收用户请求并调用相关服务。 - - **服务模块**:`service/sql_generate_service.py` 提供 SQL 生成的核心逻辑。 - - **提示词模板**:`template/prompt.yaml` 中定义了生成 SQL 的提示词模板。 - - **数据库适配**:`base/postgres.py` 和 `base/mysql.py` 提供不同数据库的操作封装。 -- **应用场景**: - - 用户无需掌握复杂的 SQL 语法,只需通过自然语言即可完成查询。 - - 支持多种数据库类型(如 PostgreSQL 和 MySQL) - ---- - -## **3.2 关键字检索与多路召回** -- **功能描述**: - - 工具支持基于关键字的检索功能,增强 RAG 的多路召回能力。 - - 对于服务器、硬件型号等特定场景,即使未训练模型,也能通过关键字匹配快速检索相关数据。 -- **实现模块**: - - **路由模块**:`router/keyword.py` 负责处理关键字检索请求。 - - **服务模块**:`service/keyword_service.py` 提供关键字检索的核心逻辑。 - - **AC 自动机**:`base/ac_automation.py` 实现高效的多模式字符串匹配。 -- **应用场景**: - - 在不依赖大模型的情况下,快速检索与关键字相关的 SQL 示例或表信息。 - - 适用于硬件型号、服务器配置等特定场景的快速查询。 - ---- - -## **3.3 数据库表与列信息管理** -- **功能描述**: - - 工具提供对数据库表和列信息的管理功能,包括元数据存储、查询和更新。 - - 用户可以通过工具查看表结构、列注释等信息,并将其用于 SQL 查询生成。 -- **实现模块**: - - **路由模块**:`router/table.py` 负责表信息相关的请求分发。 - - **管理模块**: - - `manager/table_info_manager.py`:管理表信息。 - - `manager/column_info_manager.py`:管理列信息。 - - **数据模型**:`model/request.py` 和 `model/response.py` 定义了表和列信息的数据结构。 -- **应用场景**: - - 用户可以快速了解数据库的表结构,辅助生成更准确的 SQL 查询。 - - 支持动态更新表和列信息,适应本地数据的变化。 - ---- - -## **3.4 SQL 示例管理** -- **功能描述**: - - 工具支持对 SQL 示例的增删改查操作,并结合向量相似度检索最相关的 SQL 示例。 - - 用户可以通过问题向量找到与当前问题最相似的历史 SQL 示例,从而加速查询生成。 -- **实现模块**: - - **路由模块**:`router/sql_example.py` 负责 SQL 示例相关的请求分发。 - - **管理模块**:`manager/sql_example_manager.py` 提供 SQL 示例的管理逻辑。 - - **向量化处理**:`base/vectorize.py` 将问题文本转换为向量表示。 - - **余弦距离排序**:利用 PostgreSQL 的向量计算能力,按余弦距离排序检索最相似的 SQL 示例。 -- **应用场景**: - - 在生成新 SQL 查询时,参考历史 SQL 示例,提高查询的准确性和效率。 - - 支持对 SQL 示例的灵活管理,便于维护和扩展。 - -# 4. 工具使用 - -## 4.1 服务启动与配置 - -### 服务环境配置 - -- 在common/.env文件中配置数据库连接信息,大模型API密钥等必要参数 - -### 数据库配置 - -```bash -# 进行数据库初始化,例如 -postgres=# CREATE EXTENSION zhparser; -postgres=# CREATE EXTENSION vector; -postgres=# CREATE TEXT SEARCH CONFIGURATION zhparser (PARSER = zhparser); -postgres=# ALTER TEXT SEARCH CONFIGURATION zhparser ADD MAPPING FOR n,v,a,i,e,l WITH simple; -postgres=# exit -``` - -### 启动服务 - -```bash -# 读取.env 环境配置,app.py入口启动服务 -python3 chat2db/app/app.py -# 配置run_chat2db.py端口 -python3 chat2db/scripts/run_chat2db.py config --ip xxx --port xxx -``` - ---- - -## 4.2 命令行工具操作指南 - -### 1. 
数据库操作 - -#### 添加数据库 -```bash -python3 run_chat2db.py add_db --database_url "postgresql+psycopg2://user:password@localhost:5444/mydb" - -# 成功返回示例 ->> success ->> database_id: 27fa7fd3-949b-41f9-97bc-530f498c0b57 -``` - -#### 删除数据库 - -```bash -python3 run_chat2db.py del_db --database_id mydb_database_id -``` - -#### 查询已配置数据库 - -```bash -python3 run_chat2db.py query_db - -# 返回示例 ----------------------------------------- -查询数据库配置成功 ----------------------------------------- -database_id: 27fa7fd3-949b-41f9-97bc-530f498c0b57 -database_url: postgresql+psycopg2://postgres:123456@0.0.0.0:5444/mydb -created_at: 2025-04-08T01:49:27.544521Z ----------------------------------------- -``` - -#### 查询在数据库中的表 - -```bash -python3 run_chat2db.py list_tb_in_db --database_id mydb_database_id -# 返回示例 ----------------------------------------- -{'database_id': '27fa7fd3-949b-41f9-97bc-530f498c0b57', 'table_filter': None} -查询数据库配置成功 -my_table ----------------------------------------- -# 可过滤表名 -python3 run_chat2db.py list_tb_in_db --database_id mydb_database_id --table_filter my_table -# 返回示例 ----------------------------------------- -{'database_id': '27fa7fd3-949b-41f9-97bc-530f498c0b57', 'table_filter': 'my_table'} -查询数据库配置成功 -my_table ----------------------------------------- -``` - ---- - -### 2. 表操作 - -#### 添加数据表 -```bash -python3 run_chat2db.py add_tb --database_id mydb_database_id --table_name users - -# 成功返回示例 ->> 数据表添加成功 ->> table_id: tb_0987654321 -``` - -#### 查询已添加的表 - -```bash -python3 run_chat2db.py query_tb --database_id mydb_database_id -# 返回示例 -查询表格成功 ----------------------------------------- -table_id: 984d1c82-c6d5-4d3d-93d9-8d5bc11254ba -table_name: oe_compatibility_cve_database -table_note: openEuler社区组cve漏洞信息表,存储了cve漏洞的公告时间、id、关联的软件包名称、简介、cvss评分 -created_at: 2025-03-16T12:13:51.920663Z ----------------------------------------- -``` - -#### 删除数据表 - -```bash -python3 run_chat2db.py del_tb --table_id my_table_id -# 返回示例 -删除表格成功 -``` - -#### 查询表的列信息 - -```bash -python run_chat2db.py query_col --table_id my_table_id - -# 返回示例 --------------------------------------------------------- -column_id: 5ef50ebb-310b-48cc-bbc7-cf161c779055 -column_name: id -column_note: None -column_type: bigint -enable: False --------------------------------------------------------- -column_id: 69cf3c00-8e3c-4b99-83a5-6942278a60f3 -column_name: architecture -column_note: openEuler支持的板卡信息的支持架构 -column_type: character varying -enable: False --------------------------------------------------------- -``` - -#### 启用禁用指定列 - -```bash -python3 run_chat2db.py enable_col --column_id my_column_id --enable False -# 返回示例 -列关键字功能开启/关闭成功 -``` - ---- - -### 3. SQL示例操作 - -#### 生成SQL示例 - -```bash -python3 run_chat2db.py add_sql_exp --table_id "your_table_id" --question "查询所有用户" --sql "SELECT * FROM users" -# 返回示例 -success -sql_example_id: 4282bce7-f2fd-42b0-a63b-7afd53d9e704 -``` - -#### 批量添加SQL示例 - -1. 创建Excel文件(示例格式): - - | question | sql | - |----------|----------------------------------------------| - | 查询所有用户 | SELECT * FROM users | - | 统计北京地区用户 | SELECT COUNT(*) FROM users WHERE region='北京' | - -2. 
执行导入命令: - -```bash -python3 run_chat2db.py add_sql_exp --table_id "your_table_id" --dir "path/to/examples.xlsx" -# 成功返回示例 ->> 成功添加示例:查询所有用户 ->> sql_example_id: exp_556677 ->> 成功添加示例:统计北京地区用户 ->> sql_example_id: exp_778899 -``` - ---- - -#### 删除SQL示例 - -```bash -python3 run_chat2db.py del_sql_exp --sql_example_id "your_example_id" -# 返回示例 -sql案例删除成功 -``` - -#### 查询指定表的SQL示例 - -```bash -python3 run_chat2db.py query_sql_exp --table_id "your_table_id" -# 返回示例 -查询SQL案例成功 --------------------------------------------------------- -sql_example_id: 5ab552db-b122-4653-bfdc-085c0b8557d6 -question: 查询所有用户 -sql: SELECT * FROM users --------------------------------------------------------- -``` - -#### 更新SQL示例 - -```bash -python3 run_chat2db.py update_sql_exp --sql_example_id "your_example_id" --question "新问题" --sql "新SQL语句" -# 返回示例 -sql案例更新成功 -``` - -#### 生成指定数据表SQL示例 - -```bash -python run_chat2db.py generate_sql_exp --table_id "your_table_id" --generate_cnt 5 --sql_var True --dir "output.xlsx" -# --generate_cnt 参数: 生成sql对的数量 ;--sql_var: 是否验证生成的sql对,True为验证,False不验证 -# 返回示例 -sql案例生成成功 -Data written to Excel file successfully. -``` - -### 4. 智能查询 - -#### 通过自然语言生成SQL(需配合前端或API调用) - -```python -# 示例API请求 -import requests - -url = "http://localhost:8000/sql/generate" -payload = { - "question": "显示最近7天注册的用户", - "table_id": "tb_0987654321" -} - -response = requests.post(url, json=payload) -print(response.json()) - -# 返回示例 -{ - "sql": "SELECT * FROM users WHERE registration_date >= CURRENT_DATE - INTERVAL '7 days'", - "confidence": 0.92 -} -``` - ---- - -5. **执行智能查询** -```http -POST /sql/generate -Content-Type: application/json - -{ - "question": "找出过去一个月销售额超过1万元的商品", - "table_id": "tb_yyyy" -} -``` - - - - - - - diff --git a/chat2db/llm/chat_with_model.py b/chat2db/llm/chat_with_model.py deleted file mode 100644 index 9cc1ad2d60bd40e79318ef4df29fc9d47f15c250..0000000000000000000000000000000000000000 --- a/chat2db/llm/chat_with_model.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
-from langchain_openai import ChatOpenAI -from langchain.schema import SystemMessage, HumanMessage -import re - -class LLM: - def __init__(self, model_name, openai_api_base, openai_api_key, request_timeout, max_tokens, temperature): - self.client = ChatOpenAI(model_name=model_name, - openai_api_base=openai_api_base, - openai_api_key=openai_api_key, - request_timeout=request_timeout, - max_tokens=max_tokens, - temperature=temperature) - - def assemble_chat(self, system_call, user_call): - chat = [] - chat.append(SystemMessage(content=system_call)) - chat.append(HumanMessage(content=user_call)) - return chat - - async def chat_with_model(self, system_call, user_call): - chat = self.assemble_chat(system_call, user_call) - response = await self.client.ainvoke(chat) - content = re.sub(r'.*?\n\n', '', response.content, flags=re.DOTALL) - return content diff --git a/chat2db/app/base/meta_databbase.py b/chat2db/main.py similarity index 31% rename from chat2db/app/base/meta_databbase.py rename to chat2db/main.py index b21b1f1c2ae6fa6c82495fc6775b52533d9e1e2c..d41e4d8d67e803850fcc98f622aaa87509adc8d6 100644 --- a/chat2db/app/base/meta_databbase.py +++ b/chat2db/main.py @@ -1,18 +1,21 @@ +import uvicorn +from fastapi import FastAPI import sys import logging + +from chat2db.apps.routers import sql + logging.basicConfig(stream=sys.stdout, level=logging.INFO, format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') -class MetaDatabase: - @staticmethod - def result_to_json(results): - """ - 将 SQL 查询结果解析为 JSON 格式的数据结构,支持多种数据类型 - """ - try: - results = [result._asdict() for result in results] - return results - except Exception as e: - logging.error(f"数据库查询结果解析失败由于: {e}") - raise e +app = FastAPI() + +app.include_router(sql.router) + +if __name__ == "__main__": + try: + uvicorn.run(app, host="127.0.0.1", port=9015, log_level="info") + + except Exception as e: + exit(1) diff --git a/chat2db/manager/column_info_manager.py b/chat2db/manager/column_info_manager.py deleted file mode 100644 index 789f49949d70fa0c78ecf9f2c37dcd7a36dc6f04..0000000000000000000000000000000000000000 --- a/chat2db/manager/column_info_manager.py +++ /dev/null @@ -1,69 +0,0 @@ -from sqlalchemy import and_ -import sys -from chat2db.database.postgres import ColumnInfo, PostgresDB - - -class ColumnInfoManager(): - @staticmethod - async def add_column_info_with_table_id(table_id, column_name, column_type, column_note): - column_info_entry = ColumnInfo(table_id=table_id, column_name=column_name, - column_type=column_type, column_note=column_note) - with PostgresDB.get_session() as session: - session.add(column_info_entry) - session.commit() - - @staticmethod - async def del_column_info_by_column_id(column_id): - with PostgresDB.get_session() as session: - column_info_to_delete = session.query(ColumnInfo).filter(ColumnInfo.id == column_id) - session.delete(column_info_to_delete) - session.commit() - - @staticmethod - async def get_column_info_by_column_id(column_id): - tmp_dict = {} - with PostgresDB.get_session() as session: - result = session.query(ColumnInfo).filter(ColumnInfo.id == column_id).first() - session.commit() - if not result: - return None - tmp_dict = { - 'column_id': result.id, - 'table_id': result.table_id, - 'column_name': result.column_name, - 'column_type': result.column_type, - 'column_note': result.column_note, - 'enable': result.enable - } - return tmp_dict - - @staticmethod - async def update_column_info_enable(column_id, enable=True): - with PostgresDB.get_session() as session: - column_info 
= session.query(ColumnInfo).filter(ColumnInfo.id == column_id).first() - if column_info is not None: - column_info.enable = True - session.commit() - else: - return False - return True - - @staticmethod - async def get_column_info_by_table_id(table_id, enable=None): - column_info_list = [] - with PostgresDB.get_session() as session: - if enable is None: - results = session.query(ColumnInfo).filter(ColumnInfo.table_id == table_id).all() - else: - results = session.query(ColumnInfo).filter( - and_(ColumnInfo.table_id == table_id, ColumnInfo.enable == enable)).all() - for result in results: - tmp_dict = { - 'column_id': result.id, - 'column_name': result.column_name, - 'column_type': result.column_type, - 'column_note': result.column_note, - 'enable': result.enable - } - column_info_list.append(tmp_dict) - return column_info_list diff --git a/chat2db/manager/database_info_manager.py b/chat2db/manager/database_info_manager.py deleted file mode 100644 index cc234fb12a0c72c6459261444a7ecbf0f99ea098..0000000000000000000000000000000000000000 --- a/chat2db/manager/database_info_manager.py +++ /dev/null @@ -1,98 +0,0 @@ -import json -import hashlib -import sys -import logging -from chat2db.database.postgres import DatabaseInfo, PostgresDB -from chat2db.security.security import Security - -logging.basicConfig(stream=sys.stdout, level=logging.INFO, - format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') - - -class DatabaseInfoManager(): - @staticmethod - async def add_database(database_url: str): - id = None - with PostgresDB.get_session() as session: - encrypted_database_url, encrypted_config = Security.encrypt(database_url) - hashmac = hashlib.sha256(database_url.encode('utf-8')).hexdigest() - counter = session.query(DatabaseInfo).filter(DatabaseInfo.hashmac == hashmac).first() - if counter: - return id - encrypted_config = json.dumps(encrypted_config) - database_info_entry = DatabaseInfo(encrypted_database_url=encrypted_database_url, - encrypted_config=encrypted_config, hashmac=hashmac) - session.add(database_info_entry) - session.commit() - id = database_info_entry.id - return id - - @staticmethod - async def del_database_by_id(id): - with PostgresDB.get_session() as session: - database_info_to_delete = session.query(DatabaseInfo).filter(DatabaseInfo.id == id).first() - if database_info_to_delete: - session.delete(database_info_to_delete) - else: - return False - session.commit() - return True - - @staticmethod - async def del_database_by_url(database_url): - with PostgresDB.get_session() as session: - hashmac = hashlib.sha256(database_url.encode('utf-8')).hexdigest() - database_info_entry = session.query(DatabaseInfo).filter(DatabaseInfo.hashmac == hashmac).first() - if database_info_entry: - database_info_to_delete = session.query(DatabaseInfo).filter(DatabaseInfo.id == database_info_entry.id).first() - if database_info_to_delete: - session.delete(database_info_to_delete) - else: - return False - else: - return False - session.commit() - return True - - @staticmethod - async def get_database_url_by_id(id): - with PostgresDB.get_session() as session: - result = session.query( - DatabaseInfo.encrypted_database_url, DatabaseInfo.encrypted_config).filter( - DatabaseInfo.id == id).first() - if result is None: - return None - try: - encrypted_database_url, encrypted_config = result - encrypted_config = json.loads(encrypted_config) - except Exception as e: - logging.error(f'数据库url解密失败由于{e}') - return None - if encrypted_database_url: - database_url = 
Security.decrypt(encrypted_database_url, encrypted_config) - else: - return None - return database_url - @staticmethod - async def get_database_id_by_url(database_url: str): - with PostgresDB.get_session() as session: - hashmac = hashlib.sha256(database_url.encode('utf-8')).hexdigest() - database_info_entry = session.query(DatabaseInfo).filter(DatabaseInfo.hashmac == hashmac).first() - if database_info_entry: - return database_info_entry.id - return None - @staticmethod - async def get_all_database_info(): - with PostgresDB.get_session() as session: - results = session.query(DatabaseInfo).order_by(DatabaseInfo.created_at).all() - database_info_list = [] - for i in range(len(results)): - database_id = results[i].id - encrypted_database_url = results[i].encrypted_database_url - encrypted_config = json.loads(results[i].encrypted_config) - created_at = results[i].created_at - if encrypted_database_url: - database_url = Security.decrypt(encrypted_database_url, encrypted_config) - tmp_dict = {'database_id': database_id, 'database_url': database_url, 'created_at': created_at} - database_info_list.append(tmp_dict) - return database_info_list diff --git a/chat2db/manager/sql_example_manager.py b/chat2db/manager/sql_example_manager.py deleted file mode 100644 index 67ccbcacc215ba538bd68c2a14f29e6d0f0d7d04..0000000000000000000000000000000000000000 --- a/chat2db/manager/sql_example_manager.py +++ /dev/null @@ -1,76 +0,0 @@ -import json -from sqlalchemy import and_ -import sys -from chat2db.database.postgres import SqlExample, PostgresDB -from chat2db.security.security import Security - - -class SqlExampleManager(): - @staticmethod - async def add_sql_example(question, sql, table_id, question_vector): - id = None - sql_example_entry = SqlExample(question=question, sql=sql, - table_id=table_id, question_vector=question_vector) - with PostgresDB.get_session() as session: - session.add(sql_example_entry) - session.commit() - id = sql_example_entry.id - return id - - @staticmethod - async def del_sql_example_by_id(id): - with PostgresDB.get_session() as session: - sql_example_to_delete = session.query(SqlExample).filter(SqlExample.id == id).first() - if sql_example_to_delete: - session.delete(sql_example_to_delete) - else: - return False - session.commit() - return True - - @staticmethod - async def update_sql_example_by_id(id, question, sql, question_vector): - with PostgresDB.get_session() as session: - sql_example_to_update = session.query(SqlExample).filter(SqlExample.id == id).first() - if sql_example_to_update: - sql_example_to_update.sql = sql - sql_example_to_update.question = question - sql_example_to_update.question_vector = question_vector - session.commit() - else: - return False - return True - - @staticmethod - async def query_sql_example_by_table_id(table_id): - with PostgresDB.get_session() as session: - results = session.query(SqlExample).filter(SqlExample.table_id == table_id).all() - sql_example_list = [] - for result in results: - tmp_dict = { - 'sql_example_id': result.id, - 'question': result.question, - 'sql': result.sql - } - sql_example_list.append(tmp_dict) - return sql_example_list - - @staticmethod - async def get_topk_sql_example_by_cos_dis(question_vector, table_id_list=None, topk=3): - with PostgresDB.get_session() as session: - if table_id_list is not None: - sql_example_list = session.query( - SqlExample - ).filter(SqlExample.table_id.in_(table_id_list)).order_by( - SqlExample.question_vector.cosine_distance(question_vector) - ).limit(topk).all() - else: - sql_example_list 
= session.query( - SqlExample - ).order_by( - SqlExample.question_vector.cosine_distance(question_vector) - ).limit(topk).all() - sql_example_list = [ - {'table_id': sql_example.table_id, 'question': sql_example.question, 'sql': sql_example.sql} - for sql_example in sql_example_list] - return sql_example_list diff --git a/chat2db/manager/table_info_manager.py b/chat2db/manager/table_info_manager.py deleted file mode 100644 index fcf4f6668c11e56d8e8de92bb4a34abcbe070ba8..0000000000000000000000000000000000000000 --- a/chat2db/manager/table_info_manager.py +++ /dev/null @@ -1,87 +0,0 @@ -from sqlalchemy import and_ -import sys -from chat2db.database.postgres import TableInfo, PostgresDB - - -class TableInfoManager(): - @staticmethod - async def add_table_info(database_id, table_name, table_note, table_note_vector): - id = None - with PostgresDB.get_session() as session: - counter = session.query(TableInfo).filter( - and_(TableInfo.database_id == database_id, TableInfo.table_name == table_name)).first() - if counter: - return id - table_info_entry = TableInfo(database_id=database_id, table_name=table_name, - table_note=table_note, table_note_vector=table_note_vector) - session.add(table_info_entry) - session.commit() - id = table_info_entry.id - return id - - @staticmethod - async def del_table_by_id(id): - with PostgresDB.get_session() as session: - table_info_to_delete = session.query(TableInfo).filter(TableInfo.id == id).first() - if table_info_to_delete: - session.delete(table_info_to_delete) - else: - return False - session.commit() - return True - - @staticmethod - async def get_table_info_by_table_id(table_id): - with PostgresDB.get_session() as session: - table_id, database_id, table_name, table_note = session.query( - TableInfo.id, TableInfo.database_id, TableInfo.table_name, TableInfo.table_note).filter( - TableInfo.id == table_id).first() - if table_id is None: - return None - return { - 'table_id': table_id, - 'database_id': database_id, - 'table_name': table_name, - 'table_note': table_note - } - - @staticmethod - async def get_table_id_by_database_id_and_table_name(database_id, table_name): - with PostgresDB.get_session() as session: - table_info_entry = session.query( - TableInfo).filter( - TableInfo.database_id == database_id, - TableInfo.table_name == table_name, - ).first() - if table_info_entry: - return table_info_entry.id - return None - - @staticmethod - async def get_table_info_by_database_id(database_id, enable=None): - with PostgresDB.get_session() as session: - if enable is None: - results = session.query( - TableInfo).filter(TableInfo.database_id == database_id).all() - else: - results = session.query( - TableInfo).filter( - and_(TableInfo.database_id == database_id, - TableInfo.enable == enable - )).all() - table_info_list = [] - for result in results: - table_info_list.append({'table_id': result.id, 'table_name': result.table_name, - 'table_note': result.table_note, 'created_at': result.created_at}) - return table_info_list - - @staticmethod - async def get_topk_table_by_cos_dis(database_id, tmp_vector, topk=3): - with PostgresDB.get_session() as session: - results = session.query( - TableInfo.id - ).filter(TableInfo.database_id == database_id).order_by( - TableInfo.table_note_vector.cosine_distance(tmp_vector) - ).limit(topk).all() - table_id_list = [result[0] for result in results] - return table_id_list diff --git a/chat2db/model/request.py b/chat2db/model/request.py deleted file mode 100644 index 
6d8c9550d9380aca8e8edb814018a204a626b2d6..0000000000000000000000000000000000000000 --- a/chat2db/model/request.py +++ /dev/null @@ -1,89 +0,0 @@ -import uuid -from pydantic import BaseModel, Field -from typing import Optional - -class QueryRequest(BaseModel): - question: str - topk_sql: int = 5 - topk_answer: int = 15 - use_llm_enhancements: bool = False - - -class DatabaseAddRequest(BaseModel): - database_url: str - - -class DatabaseDelRequest(BaseModel): - database_id: Optional[uuid.UUID] = Field(default=None, description="数据库id") - database_url: Optional[str] = Field(default=None, description="数据库url") - -class DatabaseSqlGenerateRequest(BaseModel): - database_url: str - table_name_list: Optional[list[str]] = Field(default=[]) - question: str - topk: int = 5 - use_llm_enhancements: Optional[bool] = Field(default=False) - -class TableAddRequest(BaseModel): - database_id: uuid.UUID - table_name: str - - -class TableDelRequest(BaseModel): - table_id: uuid.UUID - - -class TableQueryRequest(BaseModel): - database_id: uuid.UUID - - -class EnableColumnRequest(BaseModel): - column_id: uuid.UUID - enable: bool - - -class SqlExampleAddRequest(BaseModel): - table_id: uuid.UUID - question: str - sql: str - - -class SqlExampleDelRequest(BaseModel): - sql_example_id: uuid.UUID - - -class SqlExampleQueryRequest(BaseModel): - table_id: uuid.UUID - - -class SqlExampleUpdateRequest(BaseModel): - sql_example_id: uuid.UUID - question: str - sql: str - - -class SqlGenerateRequest(BaseModel): - database_id: uuid.UUID - table_id_list: list[uuid.UUID] = [] - question: str - topk: int = 5 - use_llm_enhancements: bool = True - - -class SqlRepairRequest(BaseModel): - database_id: uuid.UUID - table_id: uuid.UUID - sql: str - message: str = Field(..., max_length=2048) - question: str - - -class SqlExcuteRequest(BaseModel): - database_id: uuid.UUID - sql: str - - -class SqlExampleGenerateRequest(BaseModel): - table_id: uuid.UUID - generate_cnt: int = 1 - sql_var: bool = False diff --git a/chat2db/model/response.py b/chat2db/model/response.py deleted file mode 100644 index fd7c2e7a489405410be5a5f3331915fd9c8eda0c..0000000000000000000000000000000000000000 --- a/chat2db/model/response.py +++ /dev/null @@ -1,6 +0,0 @@ -from pydantic import BaseModel -from typing import Any -class ResponseData(BaseModel): - code: int - message: str - result: Any \ No newline at end of file diff --git a/chat2db/scripts/chat2db_config/config.yaml b/chat2db/scripts/chat2db_config/config.yaml deleted file mode 100644 index 78e3719e8a65cf870cad6994a66bf4120dc113e4..0000000000000000000000000000000000000000 --- a/chat2db/scripts/chat2db_config/config.yaml +++ /dev/null @@ -1,2 +0,0 @@ -UVICORN_IP: 0.0.0.0 -UVICORN_PORT: '9015' diff --git a/chat2db/scripts/docs/output_examples.xlsx b/chat2db/scripts/docs/output_examples.xlsx deleted file mode 100644 index 599501ca6c0f1d2b88fe5235d40fb11a56fbf005..0000000000000000000000000000000000000000 Binary files a/chat2db/scripts/docs/output_examples.xlsx and /dev/null differ diff --git a/chat2db/scripts/run_chat2db.py b/chat2db/scripts/run_chat2db.py deleted file mode 100644 index 5da9e4872c6a9fff674e3eb6233ea3f5201a6eb9..0000000000000000000000000000000000000000 --- a/chat2db/scripts/run_chat2db.py +++ /dev/null @@ -1,436 +0,0 @@ -import argparse -import os -import pandas as pd -import requests -import yaml -from fastapi import FastAPI -import shutil - -terminal_width = shutil.get_terminal_size().columns -app = FastAPI() - -CHAT2DB_CONFIG_PATH = './chat2db_config' -CONFIG_YAML_PATH = 
'./chat2db_config/config.yaml' -DEFAULT_CHAT2DB_CONFIG = { - "UVICORN_IP": "127.0.0.1", - "UVICORN_PORT": "8000" -} - - -# 修改 -def update_config(uvicorn_ip, uvicorn_port): - try: - yml = {'UVICORN_IP': uvicorn_ip, 'UVICORN_PORT': uvicorn_port} - with open(CONFIG_YAML_PATH, 'w') as file: - yaml.dump(yml, file) - return {"message": "修改成功"} - except Exception as e: - return {"message": f"修改失败,由于:{e}"} - - -# 增加数据库 -def call_add_database_info(database_url): - url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/database/add" - request_body = { - "database_url": database_url - } - response = requests.post(url, json=request_body) - return response.json() - - -# 删除数据库 -def call_del_database_info(database_id): - url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/database/del" - request_body = { - "database_id": database_id - } - response = requests.post(url, json=request_body) - return response.json() - - -# 查询数据库配置 -def call_query_database_info(): - url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/database/query" - response = requests.get(url) - return response.json() - - -# 查询数据库内表格配置 -def call_list_table_in_database(database_id, table_filter=''): - url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/database/list" - params = { - "database_id": database_id, - "table_filter": table_filter - } - print(params) - response = requests.get(url, params=params) - return response.json() - - -# 增加数据表 -def call_add_table_info(database_id, table_name): - url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/table/add" - request_body = { - "database_id": database_id, - "table_name": table_name - } - response = requests.post(url, json=request_body) - return response.json() - - -# 删除数据表 -def call_del_table_info(table_id): - url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/table/del" - request_body = { - "table_id": table_id - } - response = requests.post(url, json=request_body) - return response.json() - - -# 查询数据表配置 -def call_query_table_info(database_id): - url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/table/query" - params = { - "database_id": database_id - } - response = requests.get(url, params=params) - return response.json() - - -# 查询数据表列信息 -def call_query_column(table_id): - url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/table/column/query" - params = { - "table_id": table_id - } - response = requests.get(url, params=params) - return response.json() - - -# 启用禁用列 -def call_enable_column(column_id, enable): - url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/table/column/enable" - request_body = { - "column_id": column_id, - "enable": enable - } - response = requests.post(url, json=request_body) - return response.json() - - -# 增加sql_example案例 -def call_add_sql_example(table_id, question, sql): - url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/sql/example/add" - request_body = { - "table_id": table_id, - "question": question, - "sql": sql - } - response = requests.post(url, json=request_body) - return response.json() - - -# 删除sql_example案例 -def call_del_sql_example(sql_example_id): - url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/sql/example/del" - request_body = { - "sql_example_id": sql_example_id - } - response = requests.post(url, json=request_body) - return response.json() - - -# 查询sql_example案例 -def call_query_sql_example(table_id): - url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/sql/example/query" - params = { - "table_id": table_id - } - 
response = requests.get(url, params=params) - return response.json() - - -# 更新sql_example案例 -def call_update_sql_example(sql_example_id, question, sql): - url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/sql/example/update" - request_body = { - "sql_example_id": sql_example_id, - "question": question, - "sql": sql - } - response = requests.post(url, json=request_body) - return response.json() - - -# 生成sql_example案例 -def call_generate_sql_example(table_id, generate_cnt=1, sql_var=False): - url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/sql/example/generate" - response_body = { - "table_id": table_id, - "generate_cnt": generate_cnt, - "sql_var": sql_var - } - response = requests.post(url, json=response_body) - return response.json() - - -def write_sql_example_to_excel(dir, sql_example_list): - try: - if not os.path.exists(os.path.dirname(dir)): - os.makedirs(os.path.dirname(dir)) - data = { - 'question': [], - 'sql': [] - } - for sql_example in sql_example_list: - data['question'].append(sql_example['question']) - data['sql'].append(sql_example['sql']) - - df = pd.DataFrame(data) - df.to_excel(dir, index=False) - - print("Data written to Excel file successfully.") - except Exception as e: - print("Error writing data to Excel file:", str(e)) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="chat2DB脚本") - subparsers = parser.add_subparsers(dest="command", help="子命令列表") - - # 修改config.yaml - parser_config = subparsers.add_parser("config", help="修改config.yaml") - parser_config.add_argument("--ip", type=str, required=True, help="uvicorn ip") - parser_config.add_argument("--port", type=str, required=True, help="uvicorn port") - - # 增加数据库 - parser_add_database = subparsers.add_parser("add_db", help="增加指定数据库") - parser_add_database.add_argument("--database_url", type=str, required=True, - help="数据库连接地址,如postgresql+psycopg2://postgres:123456@0.0.0.0:5432/postgres") - - # 删除数据库 - parser_del_database = subparsers.add_parser("del_db", help="删除指定数据库") - parser_del_database.add_argument("--database_id", type=str, required=True, help="数据库id") - - # 查询数据库配置 - parser_query_database = subparsers.add_parser("query_db", help="查询指定数据库配置") - - # 查询数据库内表格配置 - parser_list_table_in_database = subparsers.add_parser("list_tb_in_db", help="查询数据库内表格配置") - parser_list_table_in_database.add_argument("--database_id", type=str, required=True, help="数据库id") - parser_list_table_in_database.add_argument("--table_filter", type=str, required=False, help="表格名称过滤条件") - - # 增加数据表 - parser_add_table = subparsers.add_parser("add_tb", help="增加指定数据库内的数据表") - parser_add_table.add_argument("--database_id", type=str, required=True, help="数据库id") - parser_add_table.add_argument("--table_name", type=str, required=True, help="数据表名称") - - # 删除数据表 - parser_del_table = subparsers.add_parser("del_tb", help="删除指定数据表") - parser_del_table.add_argument("--table_id", type=str, required=True, help="数据表id") - - # 查询数据表配置 - parser_query_table = subparsers.add_parser("query_tb", help="查询指定数据表配置") - parser_query_table.add_argument("--database_id", type=str, required=True, help="数据库id") - - # 查询数据表列信息 - parser_query_column = subparsers.add_parser("query_col", help="查询指定数据表详细列信息") - parser_query_column.add_argument("--table_id", type=str, required=True, help="数据表id") - - # 启用禁用列 - parser_enable_column = subparsers.add_parser("enable_col", help="启用禁用指定列") - parser_enable_column.add_argument("--column_id", type=str, required=True, help="列id") - parser_enable_column.add_argument("--enable", type=bool, 
required=True, help="是否启用") - - # 增加sql案例 - parser_add_sql_example = subparsers.add_parser("add_sql_exp", help="增加指定数据表sql案例") - parser_add_sql_example.add_argument("--table_id", type=str, required=True, help="数据表id") - parser_add_sql_example.add_argument("--question", type=str, required=False, help="问题") - parser_add_sql_example.add_argument("--sql", type=str, required=False, help="sql") - parser_add_sql_example.add_argument("--dir", type=str, required=False, help="输入路径") - - # 删除sql_exp - parser_del_sql_example = subparsers.add_parser("del_sql_exp", help="删除指定sql案例") - parser_del_sql_example.add_argument("--sql_example_id", type=str, required=True, help="sql案例id") - - # 查询sql案例 - parser_query_sql_example = subparsers.add_parser("query_sql_exp", help="查询指定数据表sql对案例") - parser_query_sql_example.add_argument("--table_id", type=str, required=True, help="数据表id") - - # 更新sql案例 - parser_update_sql_example = subparsers.add_parser("update_sql_exp", help="更新sql对案例") - parser_update_sql_example.add_argument("--sql_example_id", type=str, required=True, help="sql案例id") - parser_update_sql_example.add_argument("--question", type=str, required=True, help="sql语句对应的问题") - parser_update_sql_example.add_argument("--sql", type=str, required=True, help="sql语句") - - # 生成sql案例 - parser_generate_sql_example = subparsers.add_parser("generate_sql_exp", help="生成指定数据表sql对案例") - parser_generate_sql_example.add_argument("--table_id", type=str, required=True, help="数据表id") - parser_generate_sql_example.add_argument("--generate_cnt", type=int, required=False, help="生成sql对数量", - default=1) - parser_generate_sql_example.add_argument("--sql_var", type=bool, required=False, - help="是否验证生成的sql对,True为验证,False不验证", - default=False) - parser_generate_sql_example.add_argument("--dir", type=str, required=False, help="生成的sql对输出路径", - default="templetes/output_examples.xlsx") - - args = parser.parse_args() - - if os.path.exists(CONFIG_YAML_PATH): - exist = True - with open(CONFIG_YAML_PATH, 'r') as file: - yml = yaml.safe_load(file) - config = { - 'UVICORN_IP': yml.get('UVICORN_IP'), - 'UVICORN_PORT': yml.get('UVICORN_PORT'), - } - else: - exist = False - - if args.command == "config": - if not exist: - os.makedirs(CHAT2DB_CONFIG_PATH, exist_ok=True) - with open(CONFIG_YAML_PATH, 'w') as file: - yaml.dump(DEFAULT_CHAT2DB_CONFIG, file, default_flow_style=False) - response = update_config(args.ip, args.port) - with open(CONFIG_YAML_PATH, 'r') as file: - yml = yaml.safe_load(file) - config = { - 'UVICORN_IP': yml.get('UVICORN_IP'), - 'UVICORN_PORT': yml.get('UVICORN_PORT'), - } - print(response.get("message")) - elif not exist: - print("please update_config first") - - elif args.command == "add_db": - response = call_add_database_info(args.database_url) - database_id = response.get("result")['database_id'] - print(response.get("message")) - if response.get("code") == 200: - print(f'database_id: ', database_id) - - elif args.command == "del_db": - response = call_del_database_info(args.database_id) - print(response.get("message")) - - elif args.command == "query_db": - response = call_query_database_info() - print(response.get("message")) - if response.get("code") == 200: - database_info = response.get("result")['database_info_list'] - for database in database_info: - print('-' * terminal_width) - print("database_id:", database["database_id"]) - print("database_url:", database["database_url"]) - print("created_at:", database["created_at"]) - print('-' * terminal_width) - - elif args.command == "list_tb_in_db": - response = 
call_list_table_in_database(args.database_id, args.table_filter) - print(response.get("message")) - if response.get("code") == 200: - table_name_list = response.get("result")['table_name_list'] - for table_name in table_name_list: - print(table_name) - - elif args.command == "add_tb": - response = call_add_table_info(args.database_id, args.table_name) - print(response.get("message")) - table_id = response.get("result")['table_id'] - if response.get("code") == 200: - print('table_id: ', table_id) - - elif args.command == "del_tb": - response = call_del_table_info(args.table_id) - print(response.get("message")) - - elif args.command == "query_tb": - response = call_query_table_info(args.database_id) - print(response.get("message")) - if response.get("code") == 200: - table_list = response.get("result")['table_info_list'] - for table in table_list: - print('-' * terminal_width) - print("table_id:", table['table_id']) - print("table_name:", table['table_name']) - print("table_note:", table['table_note']) - print("created_at:", table['created_at']) - print('-' * terminal_width) - - elif args.command == "query_col": - response = call_query_column(args.table_id) - print(response.get("message")) - if response.get("code") == 200: - column_list = response.get("result")['column_info_list'] - for column in column_list: - print('-' * terminal_width) - print("column_id:", column['column_id']) - print("column_name:", column['column_name']) - print("column_note:", column['column_note']) - print("column_type:", column['column_type']) - print("enable:", column['enable']) - print('-' * terminal_width) - - elif args.command == "enable_col": - response = call_enable_column(args.column_id, args.enable) - print(response.get("message")) - - elif args.command == "add_sql_exp": - def get_sql_exp(dir): - if not os.path.exists(os.path.dirname(dir)): - return None - # 读取 xlsx 文件 - df = pd.read_excel(dir) - - # 遍历每一行数据 - for index, row in df.iterrows(): - question = row['question'] - sql = row['sql'] - - # 调用 call_add_sql_example 函数 - response = call_add_sql_example(args.table_id, question, sql) - print(response.get("message")) - sql_example_id = response.get("result")['sql_example_id'] - print('sql_example_id: ', sql_example_id) - print(question, sql) - - - if args.dir: - get_sql_exp(args.dir) - else: - response = call_add_sql_example(args.table_id, args.question, args.sql) - print(response.get("message")) - sql_example_id = response.get("result")['sql_example_id'] - print('sql_example_id: ', sql_example_id) - - elif args.command == "del_sql_exp": - response = call_del_sql_example(args.sql_example_id) - print(response.get("message")) - - elif args.command == "query_sql_exp": - response = call_query_sql_example(args.table_id) - print(response.get("message")) - if response.get("code") == 200: - sql_example_list = response.get("result")['sql_example_list'] - for sql_example in sql_example_list: - print('-' * terminal_width) - print("sql_example_id:", sql_example['sql_example_id']) - print("question:", sql_example['question']) - print("sql:", sql_example['sql']) - print('-' * terminal_width) - - elif args.command == "update_sql_exp": - response = call_update_sql_example(args.sql_example_id, args.question, args.sql) - print(response.get("message")) - - elif args.command == "generate_sql_exp": - response = call_generate_sql_example(args.table_id, args.generate_cnt, args.sql_var) - print(response.get("message")) - if response.get("code") == 200: - # 输出到execl中 - sql_example_list = response.get("result")['sql_example_list'] - 
write_sql_example_to_excel(args.dir, sql_example_list) - else: - print("未知命令,请检查输入的命令是否正确。") diff --git a/chat2db/security/security.py b/chat2db/security/security.py deleted file mode 100644 index 0909f27bf29fa5cc8c405ff0e4d998c7f1fbf03d..0000000000000000000000000000000000000000 --- a/chat2db/security/security.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. - -import base64 -import binascii -import hashlib -import secrets - -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes - -from chat2db.config.config import config - - -class Security: - - @staticmethod - def encrypt(plaintext: str) -> tuple[str, dict]: - """ - 加密公共方法 - :param plaintext: - :return: - """ - half_key1 = config['HALF_KEY1'] - - encrypted_work_key, encrypted_work_key_iv = Security._generate_encrypted_work_key( - half_key1) - encrypted_plaintext, encrypted_iv = Security._encrypt_plaintext(half_key1, encrypted_work_key, - encrypted_work_key_iv, plaintext) - del plaintext - secret_dict = { - "encrypted_work_key": encrypted_work_key, - "encrypted_work_key_iv": encrypted_work_key_iv, - "encrypted_iv": encrypted_iv, - "half_key1": half_key1 - } - return encrypted_plaintext, secret_dict - - @staticmethod - def decrypt(encrypted_plaintext: str, secret_dict: dict): - """ - 解密公共方法 - :param encrypted_plaintext: 待解密的字符串 - :param secret_dict: 存放工作密钥的dict - :return: - """ - plaintext = Security._decrypt_plaintext(half_key1=secret_dict.get("half_key1"), - encrypted_work_key=secret_dict.get( - "encrypted_work_key"), - encrypted_work_key_iv=secret_dict.get( - "encrypted_work_key_iv"), - encrypted_iv=secret_dict.get( - "encrypted_iv"), - encrypted_plaintext=encrypted_plaintext) - return plaintext - - @staticmethod - def _get_root_key(half_key1: str) -> bytes: - half_key2 = config['HALF_KEY2'] - key = (half_key1 + half_key2).encode("utf-8") - half_key3 = config['HALF_KEY3'].encode("utf-8") - hash_key = hashlib.pbkdf2_hmac("sha256", key, half_key3, 10000) - return binascii.hexlify(hash_key)[13:45] - - @staticmethod - def _generate_encrypted_work_key(half_key1: str) -> tuple[str, str]: - bin_root_key = Security._get_root_key(half_key1) - bin_work_key = secrets.token_bytes(32) - bin_encrypted_work_key_iv = secrets.token_bytes(16) - bin_encrypted_work_key = Security._root_encrypt(bin_root_key, bin_encrypted_work_key_iv, bin_work_key) - encrypted_work_key = base64.b64encode(bin_encrypted_work_key).decode("ascii") - encrypted_work_key_iv = base64.b64encode(bin_encrypted_work_key_iv).decode("ascii") - return encrypted_work_key, encrypted_work_key_iv - - @staticmethod - def _get_work_key(half_key1: str, encrypted_work_key: str, encrypted_work_key_iv: str) -> bytes: - bin_root_key = Security._get_root_key(half_key1) - bin_encrypted_work_key = base64.b64decode(encrypted_work_key.encode("ascii")) - bin_encrypted_work_key_iv = base64.b64decode(encrypted_work_key_iv.encode("ascii")) - return Security._root_decrypt(bin_root_key, bin_encrypted_work_key_iv, bin_encrypted_work_key) - - @staticmethod - def _root_encrypt(key: bytes, encrypted_iv: bytes, plaintext: bytes) -> bytes: - encryptor = Cipher(algorithms.AES(key), modes.GCM(encrypted_iv), default_backend()).encryptor() - encrypted = encryptor.update(plaintext) + encryptor.finalize() - return encrypted - - @staticmethod - def _root_decrypt(key: bytes, encrypted_iv: bytes, encrypted: bytes) -> bytes: - encryptor = Cipher(algorithms.AES(key), 
modes.GCM(encrypted_iv), default_backend()).encryptor() - plaintext = encryptor.update(encrypted) - return plaintext - - @staticmethod - def _encrypt_plaintext(half_key1: str, encrypted_work_key: str, encrypted_work_key_iv: str, - plaintext: str) -> tuple[str, str]: - bin_work_key = Security._get_work_key(half_key1, encrypted_work_key, encrypted_work_key_iv) - salt = f"{half_key1}{plaintext}" - plaintext_temp = salt.encode("utf-8") - del plaintext - del salt - bin_encrypted_iv = secrets.token_bytes(16) - bin_encrypted_plaintext = Security._root_encrypt(bin_work_key, bin_encrypted_iv, plaintext_temp) - encrypted_plaintext = base64.b64encode(bin_encrypted_plaintext).decode("ascii") - encrypted_iv = base64.b64encode(bin_encrypted_iv).decode("ascii") - return encrypted_plaintext, encrypted_iv - - @staticmethod - def _decrypt_plaintext(half_key1: str, encrypted_work_key: str, encrypted_work_key_iv: str, - encrypted_plaintext: str, encrypted_iv) -> str: - bin_work_key = Security._get_work_key(half_key1, encrypted_work_key, encrypted_work_key_iv) - bin_encrypted_plaintext = base64.b64decode(encrypted_plaintext.encode("ascii")) - bin_encrypted_iv = base64.b64decode(encrypted_iv.encode("ascii")) - plaintext_temp = Security._root_decrypt(bin_work_key, bin_encrypted_iv, bin_encrypted_plaintext) - plaintext_salt = plaintext_temp.decode("utf-8") - plaintext = plaintext_salt[len(half_key1):] - return plaintext \ No newline at end of file diff --git a/chat2db/templetes/change_txt_to_yaml.py b/chat2db/templetes/change_txt_to_yaml.py deleted file mode 100644 index 8e673d817e146c9762c7c712ab5ecf7a8689bf3b..0000000000000000000000000000000000000000 --- a/chat2db/templetes/change_txt_to_yaml.py +++ /dev/null @@ -1,92 +0,0 @@ -import yaml -text = { - 'sql_generate_base_on_example_prompt': '''你是一个数据库专家,你的任务是参考给出的表结构以及表注释和示例,基于给出的问题生成一条在{database_url}连接下可进行查询的sql语句。 -注意: -#01 sql语句中,特殊字段需要带上双引号。 -#02 sql语句中,如果要使用 as,请用双引号把别名包裹起来。 -#03 sql语句中,查询字段必须使用`distinct`关键字去重。 -#04 sql语句中,只返回生成的sql语句, 不要返回其他任何无关的内容 -#05 sql语句中,参考问题,对查询字段进行冗余。 -#06 sql语句中,需要以分号结尾。 - -以下是表结构以及表注释: -{note} -以下是{k}个示例: -{sql_example} -以下是问题: -{question} -''', - 'question_generate_base_on_data_prompt': '''你是一个postgres数据库专家,你的任务是根据给出的表结构和表内数据,输出一个用户可能针对这张表内的信息提出的问题。 -注意: -#01 问题内容和形式需要多样化,例如要用到统计、排序、模糊匹配等相关问题。 -#02 要以口语化的方式输出问题,不要机械的使用表内字段输出问题。 -#03 不要输出问题之外多余的内容! -#04 要基于用户的角度取提出问题,问题内容需要口语化、拟人化。 -#05 优先生成有注释的字段相关的sql语句。 - -以下是表结构和注释: -{note} -以下是表内数据 -{data_frame} -''', - 'sql_generate_base_on_data_prompt': '''你是一个postgres数据库专家,你的任务是参考给出的表结构以及表注释和表内数据,基于给出的问题生成一条查询{database_type}的sql语句。 -注意: -#01 sql语句中,特殊字段需要带上双引号。 -#02 sql语句中,如果要使用 as,请用双引号把别名包裹起来。 -#03 sql语句中,查询字段必须使用`distinct`关键字去重。 -#04 sql语句中,只返回生成的sql语句, 不要返回其他任何无关的内容 -#05 sql语句中,参考问题,对查询字段进行冗余。 -#06 sql语句中,需要以分号结尾。 - -以下是表结构以及表注释: -{note} -以下是表内的数据: -{data_frame} -以下是问题: -{question} -''', - 'sql_expand_prompt': '''你是一个数据库专家,你的任务是参考给出的表结构以及表注释、执行失败的sql和执行失败的报错,基于给出的问题修改执行失败的sql生成一条在{database_type}连接下可进行查询的sql语句。 - - 注意: - - #01 假设sql中有特殊字符干扰了sql的执行,请优先替换这些特殊字符保证sql可执行。 - - #02 假设sql用于检索或者过滤的字段导致了sql执行的失败,请尝试替换这些字段保证sql可执行。 - - #03 假设sql检索结果为空,请尝试将 = 的匹配方式替换为 ilike \'\%\%\' 保证sql执行给出结果。 - - #04 假设sql检索结果为空,可以使用问题中的关键字的子集作为sql的过滤条件保证sql执行给出结果。 - - 以下是表结构以及表注释: - - {note} - - 以下是执行失败的sql: - - {sql_failed} - - 以下是执行失败的报错: - - {sql_failed_message} - - 以下是问题: - - {question} -''', - 'table_choose_prompt': '''你是一个数据库专家,你的任务是参考给出的表名以及表的条目(主键,表名、表注释),输出最适配于问题回答检索的{table_cnt}张表,并返回表对应的主键。 -注意: -#01 输出的表名用python的list格式返回,下面是list的一个示例: -[\"prime_key1\",\"prime_key2\"]。 -#02 只输出包含主键的list即可不要输出其他内容!!! 
-#03 list重主键的顺序,按表与问题的适配程度从高到底排列。 -#04 若无任何一张表适用于问题的回答,请返回空列表。 - -以下是表的条目: -{table_entries} -以下是问题: -{question} -''' -} -print(text) -with open('./prompt.yaml', 'w', encoding='utf-8') as f: - yaml.dump(text, f, allow_unicode=True) diff --git a/chat2db/templetes/prompt.yaml b/chat2db/templetes/prompt.yaml deleted file mode 100644 index 0013b12f6df19007bbc0467d2dc8add497469a1d..0000000000000000000000000000000000000000 --- a/chat2db/templetes/prompt.yaml +++ /dev/null @@ -1,115 +0,0 @@ -question_generate_base_on_data_prompt: '你是一个postgres数据库专家,你的任务是根据给出的表结构和表内数据,输出一个用户可能针对这张表内的信息提出的问题。 - - 注意: - - #01 问题内容和形式需要多样化,例如要用到统计、排序、模糊匹配等相关问题。 - - #02 要以口语化的方式输出问题,不要机械的使用表内字段输出问题。 - - #03 不要输出问题之外多余的内容! - - #04 要基于用户的角度取提出问题,问题内容需要口语化、拟人化。 - - #05 优先生成有注释的字段相关的sql语句。 - - #06 不要对生成的sql进行解释。 - - 以下是表结构和注释: - - {note} - - 以下是表内数据 - - {data_frame} - - ' -sql_expand_prompt: "你是一个数据库专家,你的任务是参考给出的表结构以及表注释、执行失败的sql和执行失败的报错,基于给出的问题修改执行失败的sql生成一条在{database_type}连接下可进行查询的sql语句。\n\ - \n 注意:\n\n #01 假设sql中有特殊字符干扰了sql的执行,请优先替换这些特殊字符保证sql可执行。\n\n #02 假设sql用于检索或者过滤的字段导致了sql执行的失败,请尝试替换这些字段保证sql可执行。\n\ - \n #03 假设sql检索结果为空,请尝试将 = 的匹配方式替换为 ilike '\\%\\%' 保证sql执行给出结果。\n\n #04 假设sql检索结果为空,可以使用问题中的关键字的子集作为sql的过滤条件保证sql执行给出结果。\n\ - \n 以下是表结构以及表注释:\n\n {note}\n\n 以下是执行失败的sql:\n\n {sql_failed}\n\n 以下是执行失败的报错:\n\ - \n {sql_failed_message}\n\n 以下是问题:\n\n {question}\n" -sql_generate_base_on_data_prompt: '你是一个postgres数据库专家,你的任务是参考给出的表结构以及表注释和表内数据,基于给出的问题生成一条查询{database_type}的sql语句。 - - 注意: - - #01 sql语句中,特殊字段需要带上双引号。 - - #02 sql语句中,如果要使用 as,请用双引号把别名包裹起来。 - - #03 sql语句中,查询字段必须使用`distinct`关键字去重。 - - #04 sql语句中,只返回生成的sql语句, 不要返回其他任何无关的内容 - - #05 sql语句中,参考问题,对查询字段进行冗余。 - - #06 sql语句中,需要以分号结尾。 - - #07 不要对生成的sql进行解释。 - - 以下是表结构以及表注释: - - {note} - - 以下是表内的数据: - - {data_frame} - - 以下是问题: - - {question} - - ' -sql_generate_base_on_example_prompt: '你是一个数据库专家,你的任务是参考给出的表结构以及表注释和示例,基于给出的问题生成一条在{database_url}连接下可进行查询的sql语句。 - - 注意: - - #01 sql语句中,特殊字段需要带上双引号。 - - #02 sql语句中,如果要使用 as,请用双引号把别名包裹起来。 - - #03 sql语句中,查询字段必须使用`distinct`关键字去重。 - - #04 sql语句中,只返回生成的sql语句, 不要返回其他任何无关的内容 - - #05 sql语句中,参考问题,对查询字段进行冗余。 - - #06 sql语句中,需要以分号结尾。 - - - 以下是表结构以及表注释: - - {note} - - 以下是{k}个示例: - - {sql_example} - - 以下是问题: - - {question} - - ' -table_choose_prompt: '你是一个数据库专家,你的任务是参考给出的表名以及表的条目(主键,表名、表注释),输出最适配于问题回答检索的{table_cnt}张表,并返回表对应的主键。 - - 注意: - - #01 输出的表名用python的list格式返回,下面是list的一个示例: - - ["prime_key1","prime_key2"]。 - - #02 只输出包含主键的list即可不要输出其他内容!!! 
- - #03 list重主键的顺序,按表与问题的适配程度从高到底排列。 - - #04 若无任何一张表适用于问题的回答,请返回空列表。 - - - 以下是表的条目: - - {table_entries} - - 以下是问题: - - {question} - - ' diff --git a/data_chain/apps/app.py b/data_chain/apps/app.py index 30d7f7d7b97b0710e27ddba522efc3d1e8141235..aa852ab7518d19b2bf97fdf324670d124ade62b2 100644 --- a/data_chain/apps/app.py +++ b/data_chain/apps/app.py @@ -20,7 +20,8 @@ from data_chain.apps.router import ( other, role, usr_message, - task + task, + user ) from data_chain.apps.base.task.worker import ( base_worker, @@ -70,7 +71,7 @@ from data_chain.stores.database.database import ( from data_chain.manager.role_manager import RoleManager from data_chain.manager.knowledge_manager import KnowledgeBaseManager from data_chain.manager.document_type_manager import DocumentTypeManager -from data_chain.entities.enum import ParseMethod +from data_chain.entities.enum import ParseMethod, LanguageType from data_chain.entities.common import ( DOC_PATH_IN_OS, EXPORT_KB_PATH_IN_OS, @@ -103,7 +104,7 @@ async def add_acitons(): for action in actions: action_entity = ActionEntity( action=action['action'], - name=action['name'], + name=action['name'][LanguageType.CHINESE], type=action['type'], ) await RoleManager.add_action(action_entity) @@ -153,6 +154,7 @@ async def configure(): app.include_router(role.router) app.include_router(usr_message.router) app.include_router(task.router) + app.include_router(user.router) # 定义一个路由来获取所有路由信息 diff --git a/data_chain/apps/base/convertor.py b/data_chain/apps/base/convertor.py index 0b9692628000843fb6736286cb49fdee29d979e3..211f6504e2afcde9c42c8569b3be7f4fcd96fb4d 100644 --- a/data_chain/apps/base/convertor.py +++ b/data_chain/apps/base/convertor.py @@ -18,6 +18,10 @@ from data_chain.entities.request_data import ( from data_chain.entities.response_data import ( User, Team, + Role, + TeamUser, + TeamMsg, + UserMsg, Knowledgebase, DocumentType as DocumentTypeResponse, Document, @@ -33,6 +37,8 @@ from data_chain.entities.response_data import ( from data_chain.entities.enum import ( UserStatus, TeamStatus, + UserMessageStatus, + UserMessageType, TaskType, TaskStatus, KnowledgeBaseStatus, @@ -46,6 +52,10 @@ from data_chain.entities.common import default_roles from data_chain.stores.database.database import ( UserEntity, TeamEntity, + TeamUserEntity, + TeamMessageEntity, + UserMessageEntity, + RoleEntity, KnowledgeBaseEntity, DocumentTypeEntity, DocumentEntity, @@ -84,7 +94,8 @@ class Convertor: async def convert_user_sub_to_user_entity(user_sub: str) -> UserEntity: """将用户ID转换为用户实体""" try: - user_entity = UserEntity(id=user_sub, name=user_sub, status=UserStatus.ACTIVE) + user_entity = UserEntity( + id=user_sub, name=user_sub, status=UserStatus.ACTIVE) return user_entity except Exception as e: err = "用户ID转换为用户实体失败" @@ -95,7 +106,7 @@ class Convertor: async def convert_user_entity_to_user(user_entity: UserEntity) -> User: """将用户实体转换为用户""" try: - user = User(id=user_entity.id, name=user_entity.name) + user = User(userSub=user_entity.id, userName=user_entity.name) return user except Exception as e: err = "用户实体转换为用户失败" @@ -173,6 +184,99 @@ class Convertor: logging.exception("[Convertor] %s", err) raise e + @staticmethod + async def convert_user_entity_and_role_entity_to_team_user( + user_entity: UserEntity, role_entity: RoleEntity) -> TeamUser: + """将团队用户实体转换为团队用户""" + try: + team_user = TeamUser( + userId=user_entity.id, + userName=user_entity.name, + roleName=role_entity.name + ) + return team_user + except Exception as e: + err = "团队用户实体转换为团队用户失败" + logging.exception("[Convertor] %s", 
err) + raise e + + @staticmethod + async def convert_user_sub_team_id_and_message_to_team_message_entity( + user_sub: str, team_id: uuid.UUID, zh_message: str, en_message: str) -> TeamMessageEntity: + """将用户ID、团队ID和消息转换为团队消息实体""" + try: + team_message_entity = TeamMessageEntity( + team_id=team_id, + author_id=user_sub, + author_name=user_sub, + zh_message=zh_message, + en_message=en_message, + status=TeamStatus.EXISTED.value + ) + return team_message_entity + except Exception as e: + err = "用户ID、团队ID和消息转换为团队消息实体失败" + logging.exception("[Convertor] %s", err) + raise e + + @staticmethod + async def convert_team_message_entity_to_team_message( + team_message_entity: TeamMessageEntity) -> TeamMsg: + """将团队消息实体转换为团队消息""" + try: + team_msg = TeamMsg( + msgId=team_message_entity.id, + authorName=team_message_entity.author_name, + zhMsg=team_message_entity.zh_message, + enMsg=team_message_entity.en_message, + createdTime=team_message_entity.created_time.strftime( + '%Y-%m-%d %H:%M') + ) + return team_msg + except Exception as e: + err = "团队消息实体转换为团队消息失败" + logging.exception("[Convertor] %s", err) + raise e + + @staticmethod + async def convert_user_sub_and_user_message_entity_to_user_message( + user_sub: str, user_message_entity: UserMessageEntity) -> UserMsg: + """将用户消息实体转换为用户消息""" + try: + is_editable = False + if user_message_entity.status_to_receiver == UserMessageStatus.UNREAD.value: + if user_sub == user_message_entity.receiver_id: + is_editable = True + elif user_sub != user_message_entity.sender_id and user_message_entity.is_to_all: + is_editable = True + recviver_id = user_message_entity.receiver_id + recviver_name = user_message_entity.receiver_name + if user_message_entity.is_to_all: + recviver_id = "" + recviver_name = "all" + user_msg = UserMsg( + teamId=user_message_entity.team_id, + teamName=user_message_entity.team_name, + msgId=user_message_entity.id, + senderId=user_message_entity.sender_id, + senderName=user_message_entity.sender_name, + msgStatusToSender=UserMessageStatus( + user_message_entity.status_to_sender), + receiverId=recviver_id, + receiverName=recviver_name, + msgStatusToReceiver=UserMessageStatus( + user_message_entity.status_to_receiver), + msgType=UserMessageType(user_message_entity.type), + isEditable=is_editable, + createdTime=user_message_entity.created_time.strftime( + '%Y-%m-%d %H:%M') + ) + return user_msg + except Exception as e: + err = "用户消息实体转换为用户消息失败" + logging.exception("[Convertor] %s", err) + raise e + @staticmethod async def convert_default_role_dict_to_role_entity( team_id: uuid.UUID, default_role_dict: Dict[str, Any]) -> RoleEntity: @@ -209,6 +313,22 @@ class Convertor: logging.exception("[Convertor] %s", err) raise e + @staticmethod + async def convert_role_entity_to_role( + role_entity: RoleEntity) -> Role: + """将角色实体转换为角色""" + try: + role = Role( + roleId=role_entity.id, + roleName=role_entity.name, + typeActions=[] + ) + return role + except Exception as e: + err = "角色实体转换为角色失败" + logging.exception("[Convertor] %s", err) + raise e + @staticmethod async def convert_user_sub_role_id_and_team_id_to_user_role_entity( user_sub: str, role_id: uuid.UUID, team_id: uuid.UUID) -> UserRoleEntity: @@ -225,6 +345,29 @@ class Convertor: logging.exception("[Convertor] %s", err) raise e + @staticmethod + async def convert_user_sub_team_id_role_id_and_receiver_sub_to_user_message_entity( + user_sub: str, team_id: uuid.UUID, team_name: str, role_id: uuid.UUID, receiver_sub: str, is_to_all: bool, message: str, type: str) -> UserMessageEntity: + 
"""将用户ID、团队ID和接收者ID转换为用户消息实体""" + try: + user_message_entity = UserMessageEntity( + team_id=team_id, + team_name=team_name, + role_id=role_id, + sender_id=user_sub, + sender_name=user_sub, + receiver_id=receiver_sub, + receiver_name=receiver_sub, + is_to_all=is_to_all, + message=message, + type=type, + ) + return user_message_entity + except Exception as e: + err = "用户ID、团队ID和接收者ID转换为用户消息实体失败" + logging.exception("[Convertor] %s", err) + raise e + @staticmethod async def convert_update_knowledge_base_request_to_dict( req: UpdateKnowledgeBaseRequest) -> dict: @@ -234,6 +377,8 @@ class Convertor: 'name': req.kb_name, 'description': req.description, 'tokenizer': req.tokenizer.value, + 'rerank_model': req.rerank_model, + 'spearating_characters': req.spearating_characters, 'upload_count_limit': req.upload_count_limit, 'upload_size_limit': req.upload_size_limit, 'default_parse_method': req.default_parse_method.value, @@ -256,6 +401,8 @@ class Convertor: authorName=knowledge_base_entity.author_name, tokenizer=knowledge_base_entity.tokenizer, embeddingModel=knowledge_base_entity.embedding_model, + rerankModel=knowledge_base_entity.rerank_model, + spearatingCharacters=knowledge_base_entity.spearating_characters, description=knowledge_base_entity.description, docCnt=knowledge_base_entity.doc_cnt, docSize=knowledge_base_entity.doc_size, @@ -263,7 +410,8 @@ class Convertor: uploadSizeLimit=knowledge_base_entity.upload_size_limit, defaultParseMethod=knowledge_base_entity.default_parse_method, defaultChunkSize=knowledge_base_entity.default_chunk_size, - createdTime=knowledge_base_entity.created_time.strftime('%Y-%m-%d %H:%M'), + createdTime=knowledge_base_entity.created_time.strftime( + '%Y-%m-%d %H:%M'), docTypes=[], ) return knowledge_base @@ -300,6 +448,8 @@ class Convertor: tokenizer=req.tokenizer.value, description=req.description, embedding_model=req.embedding_model, + rerank_model=req.rerank_model, + spearating_characters=req.spearating_characters, upload_count_limit=req.upload_count_limit, upload_size_limit=req.upload_size_limit, default_parse_method=req.default_parse_method.value, @@ -340,7 +490,8 @@ class Convertor: docName=document_entity.name, docType=document_type_response, chunkSize=document_entity.chunk_size, - createdTime=document_entity.created_time.strftime('%Y-%m-%d %H:%M'), + createdTime=document_entity.created_time.strftime( + '%Y-%m-%d %H:%M'), parseMethod=document_entity.parse_method, enabled=document_entity.enabled, authorName=document_entity.author_name, @@ -362,7 +513,8 @@ class Convertor: if task_report is not None: task_completed = task_report.current_stage/task_report.stage_cnt*100 if task_entity.status == TaskStatus.SUCCESS.value: - finished_time = task_report.created_time.strftime('%Y-%m-%d %H:%M') + finished_time = task_report.created_time.strftime( + '%Y-%m-%d %H:%M') task = Task( opId=task_entity.op_id, opName=task_entity.op_name, @@ -470,7 +622,8 @@ class Convertor: 'MAX_TOKENS': config['MAX_TOKENS'], 'TEMPERATURE': config['TEMPERATURE'] } - config_json = json.dumps(config_params, sort_keys=True, ensure_ascii=False).encode('utf-8') + config_json = json.dumps( + config_params, sort_keys=True, ensure_ascii=False).encode('utf-8') hash_object = hashlib.sha256(config_json) hash_hex = hash_object.hexdigest() llm = LLM( diff --git a/data_chain/apps/base/task/worker/acc_testing_worker.py b/data_chain/apps/base/task/worker/acc_testing_worker.py index 45ae92740338379506641f2db485b2834f1aa196..b6b3b0bd514b5c0dc8c09ddb63cd5bdc79ce6814 100644 --- 
a/data_chain/apps/base/task/worker/acc_testing_worker.py +++ b/data_chain/apps/base/task/worker/acc_testing_worker.py @@ -28,7 +28,7 @@ from data_chain.manager.testing_manager import TestingManager from data_chain.manager.testcase_manager import TestCaseManager from data_chain.manager.qa_manager import QAManager from data_chain.manager.task_queue_mamanger import TaskQueueManager -from data_chain.stores.database.database import TaskEntity, QAEntity, DataSetEntity, DataSetDocEntity, TestingEntity, TestCaseEntity +from data_chain.stores.database.database import TaskEntity, QAEntity, DataSetEntity, DataSetDocEntity, TestingEntity, TestCaseEntity, TaskQueueEntity from data_chain.stores.minio.minio import MinIO from data_chain.stores.mongodb.mongodb import Task from data_chain.config.config import config @@ -132,12 +132,16 @@ class TestingWorker(BaseWorker): return tmp_path @staticmethod - async def testing(testing_entity: TestingEntity, qa_entities: list[QAEntity], llm: LLM) -> list[TestCaseEntity]: + async def testing( + testing_entity: TestingEntity, qa_entities: list[QAEntity], + llm: LLM, language: str) -> list[TestCaseEntity]: '''测试数据集''' testcase_entities = [] with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) - prompt_template = prompt_dict.get('LLM_PROMPT_TEMPLATE', '') + prompt_template = prompt_dict.get('LLM_PROMPT_TEMPLATE', {}) + prompt_template = prompt_template.get( + language, '') for qa_entity in qa_entities: question = qa_entity.question answer = qa_entity.answer @@ -145,7 +149,8 @@ class TestingWorker(BaseWorker): chunk_entities = await BaseSearcher.search(testing_entity.search_method, testing_entity.kb_id, question, top_k=testing_entity.top_k, doc_ids=None, banned_ids=[]) related_chunk_entities = [] banned_ids = [chunk_entity.id for chunk_entity in chunk_entities] - divide_tokens = llm.max_tokens // len(chunk_entities) if chunk_entities else llm.max_tokens + divide_tokens = llm.max_tokens // len( + chunk_entities) if chunk_entities else llm.max_tokens leave_tokens = 0 token_sum = 0 for chunk_entity in chunk_entities: @@ -175,36 +180,37 @@ class TestingWorker(BaseWorker): for chunk_entity in chunk_entities: sub_bac_info += chunk_entity.text bac_info += sub_bac_info+'\n' - bac_info = TokenTool.get_k_tokens_words_from_content(bac_info, llm.max_tokens//8*7) + bac_info = TokenTool.get_k_tokens_words_from_content( + bac_info, llm.max_tokens//8*7) prompt = prompt_template.format( bac_info=bac_info ) llm_answer = await llm.nostream([], prompt, question) - sub_socres = [] - pre = await TokenTool.cal_precision(question, answer, llm) + sub_scores = [] + pre = await TokenTool.cal_precision(question, llm_answer, llm, language) if pre != -1: - sub_socres.append(pre) - rec = await TokenTool.cal_recall(answer, llm_answer, llm) + sub_scores.append(pre) + rec = await TokenTool.cal_recall(answer, bac_info, llm, language) if rec != -1: - sub_socres.append(rec) - fai = await TokenTool.cal_faithfulness(question, llm_answer, bac_info, llm) + sub_scores.append(rec) + fai = await TokenTool.cal_faithfulness(question, llm_answer, bac_info, llm, language) if fai != -1: - sub_socres.append(fai) - rel = await TokenTool.cal_relevance(question, llm_answer, llm) + sub_scores.append(fai) + rel = await TokenTool.cal_relevance(question, llm_answer, llm, language) if rel != -1: - sub_socres.append(rel) + sub_scores.append(rel) lcs = TokenTool.cal_lcs(answer, llm_answer) if lcs != -1: - sub_socres.append(lcs) + sub_scores.append(lcs) leve = 
TokenTool.cal_leve(answer, llm_answer) if leve != -1: - sub_socres.append(leve) + sub_scores.append(leve) jac = TokenTool.cal_jac(answer, llm_answer) if jac != -1: - sub_socres.append(jac) + sub_scores.append(jac) score = -1 - if sub_socres: - score = sum(sub_socres) / len(sub_socres) + if sub_scores: + score = sum(sub_scores) / len(sub_scores) test_case_entity = TestCaseEntity( testing_id=testing_entity.id, question=question, @@ -325,7 +331,8 @@ class TestingWorker(BaseWorker): cleaned_value = invalid_chars.sub('', value) # 额外处理常见问题字符(如替换冒号、斜杠等) - problematic_chars = {'\\': '', '/': '', '*': '', '?': '', '"': "'", '<': '', '>': '', ':': ''} + problematic_chars = {'\\': '', '/': '', '*': '', + '?': '', '"': "'", '<': '', '>': '', ':': ''} for char, replacement in problematic_chars.items(): cleaned_value = cleaned_value.replace(char, replacement) @@ -369,12 +376,16 @@ class TestingWorker(BaseWorker): 'jac(杰卡德相似度)': [] } for test_case_entity in testcase_entities: - test_case_dict['question'].append(clean_value(test_case_entity.question)) - test_case_dict['answer'].append(clean_value(test_case_entity.answer)) + test_case_dict['question'].append( + clean_value(test_case_entity.question)) + test_case_dict['answer'].append( + clean_value(test_case_entity.answer)) test_case_dict['chunk'].append(clean_value(test_case_entity.chunk)) - test_case_dict['doc_name'].append(clean_value(test_case_entity.doc_name)) + test_case_dict['doc_name'].append( + clean_value(test_case_entity.doc_name)) test_case_dict['llm_answer'].append(test_case_entity.llm_answer) - test_case_dict['related_chunk'].append(test_case_entity.related_chunk) + test_case_dict['related_chunk'].append( + test_case_entity.related_chunk) test_case_dict['score(综合得分)'].append(test_case_entity.score) test_case_dict['pre(准确率)'].append(test_case_entity.pre) test_case_dict['rec(召回率)'].append(test_case_entity.rec) @@ -385,9 +396,12 @@ class TestingWorker(BaseWorker): test_case_dict['jac(杰卡德相似度)'].append(test_case_entity.jac) test_case_df = pd.DataFrame(test_case_dict) with pd.ExcelWriter(xlsx_path, engine='xlsxwriter') as writer: - model_config_df.to_excel(writer, sheet_name='config(配置)', index=False) - ave_result_df.to_excel(writer, sheet_name='ave_result(平均结果)', index=False) - test_case_df.to_excel(writer, sheet_name='test_case(测试结果)', index=False) + model_config_df.to_excel( + writer, sheet_name='config(配置)', index=False) + ave_result_df.to_excel( + writer, sheet_name='ave_result(平均结果)', index=False) + test_case_df.to_excel( + writer, sheet_name='test_case(测试结果)', index=False) await MinIO.put_object( TESTING_REPORT_PATH_IN_MINIO, str(testing_entity.id), @@ -421,7 +435,8 @@ class TestingWorker(BaseWorker): current_stage += 1 await TestingWorker.report(task_id, "初始化路径", current_stage, stage_cnt) qa_entities = await QAManager.list_all_qa_by_dataset_id(testing_entity.dataset_id) - testcase_entities = await TestingWorker.testing(testing_entity, qa_entities, llm) + knowledge_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(testing_entity.kb_id) + testcase_entities = await TestingWorker.testing(testing_entity, qa_entities, llm, knowledge_entity.tokenizer) current_stage += 1 await TestingWorker.report(task_id, "测试完成", current_stage, stage_cnt) testing_entity = await TestingWorker.update_testing_score(testing_entity.id, testcase_entities) @@ -431,11 +446,11 @@ class TestingWorker(BaseWorker): await TestingWorker.generate_report_and_upload_to_minio(dataset_entity, testing_entity, testcase_entities, tmp_path) current_stage += 1 await 
TestingWorker.report(task_id, "生成报告并上传到minio", current_stage, stage_cnt) - await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value)) + await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.SUCCESS.value)) except Exception as e: err = f"[TestingWorker] 任务失败,task_id: {task_id}, 错误信息: {e}" logging.exception(err) - await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value)) + await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.FAILED.value)) await TestingWorker.report(task_id, "任务失败", 0, 1) @staticmethod diff --git a/data_chain/apps/base/task/worker/export_dataset_worker.py b/data_chain/apps/base/task/worker/export_dataset_worker.py index 8d9b0812b11a3911629573f2fef7c34f21bd8465..082326a86ae6216814c7d38ca5e237e66748d948 100644 --- a/data_chain/apps/base/task/worker/export_dataset_worker.py +++ b/data_chain/apps/base/task/worker/export_dataset_worker.py @@ -23,7 +23,7 @@ from data_chain.manager.chunk_manager import ChunkManager from data_chain.manager.dataset_manager import DatasetManager from data_chain.manager.qa_manager import QAManager from data_chain.manager.task_queue_mamanger import TaskQueueManager -from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, QAEntity, DataSetEntity, DataSetDocEntity +from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, QAEntity, DataSetEntity, DataSetDocEntity, TaskQueueEntity from data_chain.stores.minio.minio import MinIO from data_chain.stores.mongodb.mongodb import Task @@ -190,11 +190,11 @@ class ExportDataSetWorker(BaseWorker): await ExportDataSetWorker.upload_file_to_minio(task_id, zip_path) current_stage += 1 await ExportDataSetWorker.report(task_id, "上传文件到minio", current_stage, stage_cnt) - await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value)) + await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.SUCCESS.value)) except Exception as e: err = f"[ExportDataSetWorker] 任务失败,task_id: {task_id}, 错误信息: {e}" logging.exception(err) - await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value)) + await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.FAILED.value)) await ExportDataSetWorker.report(task_id, "任务失败", 0, 1) @staticmethod diff --git a/data_chain/apps/base/task/worker/export_knowledge_base_worker.py b/data_chain/apps/base/task/worker/export_knowledge_base_worker.py index 5995debc839b13b94bca7c52ff13007013bce1de..2e50c55582b98ba7803f4d919193e76830feab51 100644 --- a/data_chain/apps/base/task/worker/export_knowledge_base_worker.py +++ b/data_chain/apps/base/task/worker/export_knowledge_base_worker.py @@ -13,7 +13,7 @@ from data_chain.manager.task_manager import TaskManager from data_chain.manager.knowledge_manager import KnowledgeBaseManager from data_chain.manager.document_manager import DocumentManager from data_chain.manager.task_queue_mamanger import TaskQueueManager -from data_chain.stores.database.database import TaskEntity, DocumentEntity +from data_chain.stores.database.database import TaskEntity, DocumentEntity, TaskQueueEntity from data_chain.stores.minio.minio import MinIO from data_chain.stores.mongodb.mongodb import Task @@ -197,11 +197,11 @@ class ExportKnowledgeBaseWorker(BaseWorker): await ExportKnowledgeBaseWorker.upload_zip_to_minio(zip_path, task_id) current_stage += 1 await ExportKnowledgeBaseWorker.report(task_id, "上传压缩包到minio", current_stage, 
stage_cnt) - await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value)) + await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.SUCCESS.value)) except Exception as e: err = f"[ExportKnowledgeBaseWorker] 运行任务失败,task_id: {task_id},错误信息: {e}" logging.exception(err) - await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value)) + await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.FAILED.value)) await ExportKnowledgeBaseWorker.report(task_id, err, 0, 1) @staticmethod diff --git a/data_chain/apps/base/task/worker/generate_dataset_worker.py b/data_chain/apps/base/task/worker/generate_dataset_worker.py index eb01850191b85f987584b370898b704d9e56f832..552cdac8b4ca8947e0f0e53dcbe9edafa40cceb1 100644 --- a/data_chain/apps/base/task/worker/generate_dataset_worker.py +++ b/data_chain/apps/base/task/worker/generate_dataset_worker.py @@ -15,12 +15,13 @@ from data_chain.entities.enum import TaskType, TaskStatus, KnowledgeBaseStatus, from data_chain.entities.common import DEFAULT_DOC_TYPE_ID from data_chain.parser.tools.token_tool import TokenTool from data_chain.manager.task_manager import TaskManager +from data_chain.manager.knowledge_manager import KnowledgeBaseManager from data_chain.manager.document_manager import DocumentManager from data_chain.manager.chunk_manager import ChunkManager from data_chain.manager.dataset_manager import DatasetManager from data_chain.manager.qa_manager import QAManager from data_chain.manager.task_queue_mamanger import TaskQueueManager -from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, QAEntity, DataSetEntity, DataSetDocEntity +from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, QAEntity, DataSetEntity, DataSetDocEntity, TaskQueueEntity from data_chain.stores.minio.minio import MinIO from data_chain.stores.mongodb.mongodb import Task @@ -116,7 +117,9 @@ class GenerateDataSetWorker(BaseWorker): return doc_chunks @staticmethod - async def generate_qa(dataset_entity: DataSetEntity, doc_chunks: list[DocChunk], llm: LLM) -> list[QAEntity]: + async def generate_qa( + dataset_entity: DataSetEntity, doc_chunks: list[DocChunk], + llm: LLM, language: str) -> list[QAEntity]: chunk_cnt = 0 for doc_chunk in doc_chunks: chunk_cnt += len(doc_chunk.chunks) @@ -138,9 +141,12 @@ class GenerateDataSetWorker(BaseWorker): random.shuffle(doc_chunks) with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) - q_generate_prompt_template = prompt_dict.get('GENREATE_QUESTION_FROM_CONTENT_PROMPT', '') - answer_generate_prompt_template = prompt_dict.get('GENERATE_ANSWER_FROM_QUESTION_AND_CONTENT_PROMPT', '') - cal_qa_score_prompt_template = prompt_dict.get('CAL_QA_SCORE_PROMPT', '') + q_generate_prompt_template = prompt_dict.get('GENERATE_QUESTION_FROM_CONTENT_PROMPT', {}) + q_generate_prompt_template = q_generate_prompt_template.get(language, '') + answer_generate_prompt_template = prompt_dict.get('GENERATE_ANSWER_FROM_QUESTION_AND_CONTENT_PROMPT', {}) + answer_generate_prompt_template = answer_generate_prompt_template.get(language, '') + cal_qa_score_prompt_template = prompt_dict.get('CAL_QA_SCORE_PROMPT', {}) + cal_qa_score_prompt_template = cal_qa_score_prompt_template.get(language, '') dataset_score = 0 logging.error(f"{chunk_index_list}") exist_q_set = set() @@ -301,18 +307,19 @@ class GenerateDataSetWorker(BaseWorker): doc_chunks = await 
GenerateDataSetWorker.get_chunks(dataset_entity) current_stage += 1 await GenerateDataSetWorker.report(task_id, "获取文档分块信息", current_stage, stage_cnt) + knowlege_base_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(dataset_entity.kb_id) qa_entities = await GenerateDataSetWorker.generate_qa( - dataset_entity, doc_chunks, llm) + dataset_entity, doc_chunks, llm, knowlege_base_entity.tokenizer) current_stage += 1 await GenerateDataSetWorker.report(task_id, "生成QA", current_stage, stage_cnt) await GenerateDataSetWorker.add_qa_to_db(qa_entities) current_stage += 1 await GenerateDataSetWorker.report(task_id, "添加QA到数据库", current_stage, stage_cnt) - await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value)) + await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.SUCCESS.value)) except Exception as e: err = f"[GenerateDataSetWorker] 任务失败,task_id: {task_id},错误信息: {e}" logging.exception(err) - await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value)) + await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.FAILED.value)) await GenerateDataSetWorker.report(task_id, err, 0, 1) @staticmethod diff --git a/data_chain/apps/base/task/worker/import_dataset_worker.py b/data_chain/apps/base/task/worker/import_dataset_worker.py index 02451f56bc546f0ab8430bb53e04539e956c303f..86b9809295dadb7346e55cf741697e8904ba0c63 100644 --- a/data_chain/apps/base/task/worker/import_dataset_worker.py +++ b/data_chain/apps/base/task/worker/import_dataset_worker.py @@ -21,8 +21,9 @@ from data_chain.manager.task_manager import TaskManager from data_chain.manager.chunk_manager import ChunkManager from data_chain.manager.dataset_manager import DatasetManager from data_chain.manager.qa_manager import QAManager +from data_chain.manager.knowledge_manager import KnowledgeBaseManager from data_chain.manager.task_queue_mamanger import TaskQueueManager -from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, QAEntity, DataSetEntity, DataSetDocEntity +from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, QAEntity, DataSetEntity, DataSetDocEntity, TaskQueueEntity from data_chain.stores.minio.minio import MinIO from data_chain.stores.mongodb.mongodb import Task @@ -185,14 +186,15 @@ class ImportDataSetWorker(BaseWorker): return qa_entities @staticmethod - async def update_dataset_score(dataset_id: uuid.UUID, qa_entities: list[QAEntity], llm: LLM) -> None: + async def update_dataset_score(dataset_id: uuid.UUID, qa_entities: list[QAEntity], llm: LLM, language: str) -> None: '''更新数据集分数''' if not qa_entities: return databse_score = 0 with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) - cal_qa_score_prompt_template = prompt_dict.get('CAL_QA_SCORE_PROMPT', '') + cal_qa_score_prompt_template = prompt_dict.get('CAL_QA_SCORE_PROMPT', {}) + cal_qa_score_prompt_template = cal_qa_score_prompt_template.get(language, '') for qa_entity in qa_entities: chunk = qa_entity.chunk question = qa_entity.question @@ -234,6 +236,7 @@ class ImportDataSetWorker(BaseWorker): await DatasetManager.update_dataset_by_dataset_id(dataset_entity.id, {"status": DataSetStatus.IMPORTING.value}) current_stage = 0 stage_cnt = 4 + knowlege_base_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(dataset_entity.kb_id) tmp_path = await ImportDataSetWorker.init_path(task_id) current_stage += 1 await 
ImportDataSetWorker.report(task_id, "初始化路径", current_stage, stage_cnt) @@ -243,14 +246,14 @@ class ImportDataSetWorker(BaseWorker): qa_entities = await ImportDataSetWorker.load_qa_entity_from_file(dataset_entity.id, file_path) current_stage += 1 await ImportDataSetWorker.report(task_id, "加载qa实体", current_stage, stage_cnt) - await ImportDataSetWorker.update_dataset_score(dataset_entity.id, qa_entities, llm) + await ImportDataSetWorker.update_dataset_score(dataset_entity.id, qa_entities, llm, knowlege_base_entity.tokenizer) current_stage += 1 await ImportDataSetWorker.report(task_id, "更新数据集分数", current_stage, stage_cnt) - await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value)) + await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.SUCCESS.value)) except Exception as e: err = f"[ImportDataSetWorker] 任务失败,task_id: {task_id},错误信息: {e}" logging.exception(err) - await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value)) + await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.FAILED.value)) await ImportDataSetWorker.report(task_id, "任务失败", 0, 1) @staticmethod diff --git a/data_chain/apps/base/task/worker/import_knowledge_base_worker.py b/data_chain/apps/base/task/worker/import_knowledge_base_worker.py index b7dcc5eaa49aa00de91761d8008c61d70b86651e..c85f0c98f9ef89659747019aad0c40a8e97583d7 100644 --- a/data_chain/apps/base/task/worker/import_knowledge_base_worker.py +++ b/data_chain/apps/base/task/worker/import_knowledge_base_worker.py @@ -15,7 +15,7 @@ from data_chain.manager.knowledge_manager import KnowledgeBaseManager from data_chain.manager.document_type_manager import DocumentTypeManager from data_chain.manager.document_manager import DocumentManager from data_chain.manager.task_queue_mamanger import TaskQueueManager -from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity +from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, TaskQueueEntity from data_chain.stores.minio.minio import MinIO from data_chain.stores.mongodb.mongodb import Task @@ -223,11 +223,11 @@ class ImportKnowledgeBaseWorker(BaseWorker): await ImportKnowledgeBaseWorker.init_doc_parse_tasks(kb_id) current_stage += 1 await ImportKnowledgeBaseWorker.report(task_id, "初始化文档解析任务", current_stage, stage_cnt) - await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value)) + await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.SUCCESS.value)) except Exception as e: err = f"[ImportKnowledgeBaseWorker] 任务失败,task_id: {task_id},错误信息: {e}" logging.exception(err) - await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value)) + await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.FAILED.value)) await ImportKnowledgeBaseWorker.report(task_id, err, 0, 1) @staticmethod @@ -252,7 +252,7 @@ class ImportKnowledgeBaseWorker(BaseWorker): err = f"[ExportKnowledgeBaseWorker] 任务不存在,task_id: {task_id}" logging.exception(err) return None - if task_entity.status == TaskStatus.CANCLED or TaskStatus.FAILED.value: + if task_entity.status == TaskStatus.CANCLED.value or task_entity.status == TaskStatus.FAILED.value: await KnowledgeBaseManager.update_knowledge_base_by_kb_id(task_entity.op_id, {"status": KnowledgeBaseStatus.DELETED.value}) await MinIO.delete_object(IMPORT_KB_PATH_IN_MINIO, str(task_entity.op_id)) return task_id diff --git 
a/data_chain/apps/base/task/worker/parse_document_worker.py b/data_chain/apps/base/task/worker/parse_document_worker.py index 50a5996477379dce6c034bae893cd6e3f04c088e..38be5c804068390f6f860e0929a4e940fa3e94f3 100644 --- a/data_chain/apps/base/task/worker/parse_document_worker.py +++ b/data_chain/apps/base/task/worker/parse_document_worker.py @@ -8,6 +8,7 @@ import random import io import numpy as np from PIL import Image +import asyncio from data_chain.parser.tools.ocr_tool import OcrTool from data_chain.parser.tools.token_tool import TokenTool from data_chain.parser.tools.image_tool import ImageTool @@ -29,7 +30,7 @@ from data_chain.manager.chunk_manager import ChunkManager from data_chain.manager.image_manager import ImageManager from data_chain.manager.task_report_manager import TaskReportManager from data_chain.manager.task_queue_mamanger import TaskQueueManager -from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, ChunkEntity, ImageEntity +from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, ChunkEntity, ImageEntity, TaskQueueEntity from data_chain.stores.minio.minio import MinIO from data_chain.stores.mongodb.mongodb import Task @@ -114,7 +115,8 @@ class ParseDocumentWorker(BaseWorker): err = f"[ParseDocumentWorker] 文档不存在,doc_id: {doc_id}" logging.exception(err) raise Exception(err) - file_path = os.path.join(tmp_path, str(doc_id)+'.'+doc_entity.extension) + file_path = os.path.join(tmp_path, str( + doc_id)+'.'+doc_entity.extension) await MinIO.download_object( DOC_PATH_IN_MINIO, str(doc_entity.id), @@ -153,7 +155,8 @@ class ParseDocumentWorker(BaseWorker): return str(js) @staticmethod - async def handle_parse_result(parse_result: ParseResult, doc_entity: DocumentEntity, llm: LLM = None) -> None: + async def handle_parse_result( + parse_result: ParseResult, doc_entity: DocumentEntity, llm: LLM = None, language: str = "中文") -> None: '''处理解析结果''' if doc_entity.parse_method == ParseMethod.GENERAL.value or doc_entity.parse_method == ParseMethod.QA.value: nodes = [] @@ -186,7 +189,7 @@ class ParseDocumentWorker(BaseWorker): question = qa.get('question') answer = qa.get('answer') if question is None or answer is None: - warning = f"[ParseDocumentWorker] 解析问题和答案失败,doc_id: {doc_entity.id}, error: {e}" + warning = f"[ParseDocumentWorker] 解析问题和答案失败,doc_id: {doc_entity.id}, qa: {qa}" logging.warning(warning) continue node = ParseNode( @@ -219,9 +222,10 @@ class ParseDocumentWorker(BaseWorker): node.text_feature = node.content elif node.type == ChunkType.CODE: if llm is not None: - node.text_feature = await TokenTool.get_abstract_by_llm(node.content, llm) + node.text_feature = await TokenTool.get_abstract_by_llm(node.content, llm, language) if node.text_feature is None: - node.text_feature = TokenTool.get_top_k_keywords(node.content) + node.text_feature = TokenTool.get_top_k_keywords( + node.content) elif node.type == ChunkType.TABLE: content = node.content[:] for i in range(len(content)): @@ -248,7 +252,8 @@ class ParseDocumentWorker(BaseWorker): ) image_entities.append(image_entity) image_blob = node.content - image_file_path = os.path.join(image_path, str(node.id) + '.' + extension) + image_file_path = os.path.join( + image_path, str(node.id) + '.' 
+ extension) with open(image_file_path, 'wb') as f: f.write(image_blob) await MinIO.put_object( @@ -271,30 +276,45 @@ class ParseDocumentWorker(BaseWorker): index += 1024 @staticmethod - async def ocr_from_parse_image(parse_result: ParseResult, llm: LLM = None) -> list: + async def ocr_from_parse_image( + parse_result: ParseResult, image_path: str, llm: LLM = None, language: str = '中文') -> None: '''从解析图片中获取ocr''' - for node in parse_result.nodes: + async def _ocr(node: ParseNode, language: str) -> None: + try: + image_related_text = '' + for related_node in node.link_nodes: + if related_node.type != ChunkType.IMAGE: + image_related_text += related_node.content + '\n' + extension = ImageTool.get_image_type(node.content) + image_file_path = os.path.join( + image_path, str(node.id) + '.' + extension) + ocr_result = (await OcrTool.image_to_text(image_file_path, image_related_text, llm, language)) + node.text_feature = ocr_result + node.content = ocr_result + except Exception as e: + err = f"[OCRTool] OCR识别失败: {e}" + logging.exception(err) + return None + + image_node_ids = [] + for i, node in enumerate(parse_result.nodes): if node.type == ChunkType.IMAGE: - try: - image_blob = node.content - image = Image.open(io.BytesIO(image_blob)) - img_np = np.array(image) - image_related_text = '' - for related_node in node.link_nodes: - if related_node.type != ChunkType.IMAGE: - image_related_text += related_node.content - node.content = await OcrTool.image_to_text(img_np, image_related_text, llm) - node.text_feature = node.content - except Exception as e: - err = f"[ParseDocumentWorker] OCR失败 error: {e}" - logging.exception(err) - continue + image_node_ids.append(i) + group_size = 5 + index = 0 + while index < len(image_node_ids): + sub_image_node_ids = image_node_ids[index:index + group_size] + task_list = [] + for node_id in sub_image_node_ids: + # 通过asyncio.create_task来异步执行OCR + node = parse_result.nodes[node_id] + task_list.append(asyncio.create_task(_ocr(node, language))) + await asyncio.gather(*task_list) + index += group_size @staticmethod - async def merge_and_split_text(parse_result: ParseResult, doc_entity: DocumentEntity) -> None: - '''合并和拆分内容''' - if doc_entity.parse_method == ParseMethod.QA or parse_result.parse_topology_type == DocParseRelutTopology.TREE: - return + async def merge_and_split_text_list(parse_result: ParseResult, doc_entity: DocumentEntity) -> None: + '''线性列表合并和拆分内容''' nodes = [] for node in parse_result.nodes: if node.type == ChunkType.TEXT: @@ -320,7 +340,8 @@ class ParseDocumentWorker(BaseWorker): if TokenTool.get_tokens(sentence) > doc_entity.chunk_size: tmp = sentence[:] while len(tmp) > 0: - sub_sentence = TokenTool.get_k_tokens_words_from_content(tmp, doc_entity.chunk_size) + sub_sentence = TokenTool.get_k_tokens_words_from_content( + tmp, doc_entity.chunk_size) new_sentences.append(sub_sentence) tmp = tmp[len(sub_sentence):] else: @@ -369,13 +390,128 @@ class ParseDocumentWorker(BaseWorker): parse_result.nodes = nodes @staticmethod - async def push_up_words_feature(parse_result: ParseResult, llm: LLM = None) -> None: + async def merge_and_split_text_tree(parse_result: ParseResult, doc_entity: DocumentEntity) -> None: + '''树形结构合并和拆分内容''' + async def dfs(node: ParseNode, doc_entity: DocumentEntity) -> None: + for cnode in node.link_nodes: + if cnode.parse_topology_type != ChunkParseTopology.TREELEAF: + await dfs(cnode, doc_entity) + new_nodes = [] + index = 0 + while index < len(node.link_nodes): + cnode = node.link_nodes[index] + if cnode.parse_topology_type != 
ChunkParseTopology.TREELEAF or cnode.type != ChunkType.TEXT: + new_nodes.append(cnode) + index += 1 + else: + content = '' + tmp_nodes = [] + while index < len(node.link_nodes) and node.link_nodes[index].parse_topology_type == ChunkParseTopology.TREELEAF and node.link_nodes[index].type == ChunkType.TEXT: + if node.link_nodes[index].is_need_newline: + content += '\n' + elif node.link_nodes[index].is_need_space: + content += ' ' + content += node.link_nodes[index].content + index += 1 + sentences = TokenTool.content_to_sentences(content) + new_sentences = [] + for sentence in sentences: + if TokenTool.get_tokens(sentence) > doc_entity.chunk_size: + tmp = sentence[:] + while len(tmp) > 0: + sub_sentence = TokenTool.get_k_tokens_words_from_content( + tmp, doc_entity.chunk_size) + new_sentences.append(sub_sentence) + tmp = tmp[len(sub_sentence):] + else: + new_sentences.append(sentence) + sentences = new_sentences + tmp = '' + for sentence in sentences: + if TokenTool.get_tokens(tmp+sentence) > doc_entity.chunk_size: + tmp_node = ParseNode( + id=uuid.uuid4(), + pre_id=cnode.pre_id, + lv=cnode.lv, + parse_topology_type=ChunkParseTopology.TREELEAF, + text_feature=tmp, + content=tmp, + type=ChunkType.TEXT, + link_nodes=[] + ) + tmp_nodes.append(tmp_node) + tmp = sentence + else: + tmp += sentence + if len(tmp) > 0: + tmp_node = ParseNode( + id=uuid.uuid4(), + pre_id=cnode.pre_id, + lv=cnode.lv, + parse_topology_type=ChunkParseTopology.TREELEAF, + text_feature=tmp, + content=tmp, + type=ChunkType.TEXT, + link_nodes=[] + ) + tmp_nodes.append(tmp_node) + new_nodes.extend(tmp_nodes) + node.link_nodes = new_nodes + + async def flatten(node: ParseNode, nodes: list) -> None: + nodes.append(node) + for cnode in node.link_nodes: + await flatten(cnode, nodes) + await dfs(parse_result.nodes[0], doc_entity) + nodes = [] + await flatten(parse_result.nodes[0], nodes) + parse_result.nodes = nodes + @staticmethod + async def merge_and_split_text_by_spearating_characters(parse_result: ParseResult, characters: str) -> None: + '''通过分隔符合并并拆分内容''' + content='' + for node in parse_result.nodes: + content+=node.content + parts = content.split(characters) + nodes = [] + for part in parts: + tmp_node = ParseNode( + id=uuid.uuid4(), + lv=0, + parse_topology_type=ChunkParseTopology.GERNERAL, + text_feature=part, + content=part, + type=ChunkType.TEXT, + link_nodes=[] + ) + nodes.append(tmp_node) + parse_result.parse_topology_type=DocParseRelutTopology.LIST + parse_result.nodes = nodes + async def merge_and_split_text(parse_result: ParseResult, doc_entity: DocumentEntity) -> None: + '''合并和拆分内容''' + kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(doc_entity.kb_id) + if kb_entity.spearating_characters is not None: + await ParseDocumentWorker.merge_and_split_text_by_spearating_characters(parse_result, kb_entity.spearating_characters) + return + if doc_entity.parse_method == ParseMethod.QA: + return + logging.error( + f"parse_result_parse_topology_type: {parse_result.parse_topology_type}") + if parse_result.parse_topology_type == DocParseRelutTopology.TREE: + await ParseDocumentWorker.merge_and_split_text_tree( + parse_result, doc_entity) + elif parse_result.parse_topology_type == DocParseRelutTopology.LIST or parse_result.parse_topology_type == DocParseRelutTopology.GRAPH: + await ParseDocumentWorker.merge_and_split_text_list( + parse_result, doc_entity) + + @staticmethod + async def push_up_words_feature(parse_result: ParseResult, llm: LLM = None, language: str = '中文') -> None: '''推送上层词特征''' - async def 
dfs(node: ParseNode, parent_node: ParseNode, llm: LLM = None) -> None: + async def dfs(node: ParseNode, parent_node: ParseNode, llm: LLM = None, language: str = '中文') -> None: if parent_node is not None: node.pre_id = parent_node.id for child_node in node.link_nodes: - await dfs(child_node, node, llm) + await dfs(child_node, node, llm, language) if node.title is not None: if len(node.title) == 0: if llm is not None: @@ -384,17 +520,19 @@ class ParseDocumentWorker(BaseWorker): if cnode.title: content += cnode.title + '\n' else: - sentences = TokenTool.get_top_k_keysentence(cnode.content, 1) + sentences = TokenTool.get_top_k_keysentence( + cnode.content, 1) if sentences: content += sentences[0] + '\n' if content: - title = await TokenTool.get_title_by_llm(content, llm) + title = await TokenTool.get_title_by_llm(content, llm, language) if "无法生成标题" in title: title = '' else: title = '' if not title: - sentences = TokenTool.get_top_k_keysentence(content, 1) + sentences = TokenTool.get_top_k_keysentence( + content, 1) if sentences: title = sentences[0] node.text_feature = title @@ -406,34 +544,49 @@ class ParseDocumentWorker(BaseWorker): node.text_feature = node.title node.content = node.text_feature if parse_result.parse_topology_type == DocParseRelutTopology.TREE: - await dfs(parse_result.nodes[0], None, llm) + await dfs(parse_result.nodes[0], None, llm, language) @staticmethod - async def update_doc_abstract(doc_id: uuid.UUID, parse_result: ParseResult, llm: LLM = None) -> str: - '''获取文档摘要''' - abstract = "" + async def update_doc_abstract_and_full_text( + doc_id: uuid.UUID, parse_result: ParseResult, llm: LLM = None, language: str = "中文") -> str: + '''获取文档摘要和全文''' + full_text = "" for node in parse_result.nodes: - abstract += node.content + full_text += node.content if llm is not None: - abstract = await TokenTool.get_abstract_by_llm(abstract, llm) + abstract = await TokenTool.get_abstract_by_llm(full_text, llm, language) else: - abstract = abstract[:128] + abstract = full_text[:128] abstract_vector = await Embedding.vectorize_embedding(abstract) await DocumentManager.update_document_by_doc_id( doc_id, { + "full_text": full_text, "abstract": abstract, "abstract_vector": abstract_vector } ) + await DocumentManager.update_document_abstract_ts_vector_by_doc_ids([doc_id]) return abstract @staticmethod async def embedding_chunk(parse_result: ParseResult) -> None: '''嵌入chunk''' - for node in parse_result.nodes: + async def _embedding(node: ParseNode) -> None: node.vector = await Embedding.vectorize_embedding(node.text_feature) + group_size = 32 + index = 0 + while index < len(parse_result.nodes): + sub_nodes = parse_result.nodes[index:index + group_size] + task_list = [] + for node in sub_nodes: + # 与OCR代码风格保持一致 + task_list.append(asyncio.create_task(_embedding(node))) + # 直接await任务集合 + await asyncio.gather(*task_list) + index += group_size + @staticmethod async def add_parse_result_to_db(parse_result: ParseResult, doc_entity: DocumentEntity) -> None: '''添加解析结果到数据库''' @@ -475,6 +628,9 @@ class ParseDocumentWorker(BaseWorker): while index < len(chunk_entities): try: await ChunkManager.add_chunks(chunk_entities[index:index+1024]) + sub_chunk_ids = [ + chunk.id for chunk in chunk_entities[index:index+1024]] + await ChunkManager.update_chunk_text_ts_vector_by_chunk_ids(sub_chunk_ids) except Exception as e: err = f"[ParseDocumentWorker] 添加解析结果到数据库失败,doc_id: {doc_entity.id}, error: {e}" logging.exception(err) @@ -507,38 +663,40 @@ class ParseDocumentWorker(BaseWorker): tmp_path, image_path = await 
ParseDocumentWorker.init_path(task_id) current_stage = 0 stage_cnt = 10 + knowledge_base_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(doc_entity.kb_id) await ParseDocumentWorker.download_doc_from_minio(task_entity.op_id, tmp_path) current_stage += 1 await ParseDocumentWorker.report(task_id, '下载文档', current_stage, stage_cnt) - file_path = os.path.join(tmp_path, str(task_entity.op_id)+'.'+doc_entity.extension) + file_path = os.path.join(tmp_path, str( + task_entity.op_id)+'.'+doc_entity.extension) parse_result = await ParseDocumentWorker.parse_doc(doc_entity, file_path) current_stage += 1 await ParseDocumentWorker.report(task_id, '解析文档', current_stage, stage_cnt) - await ParseDocumentWorker.handle_parse_result(parse_result, doc_entity, llm) + await ParseDocumentWorker.handle_parse_result(parse_result, doc_entity, llm, knowledge_base_entity.tokenizer) current_stage += 1 await ParseDocumentWorker.report(task_id, '处理解析结果', current_stage, stage_cnt) await ParseDocumentWorker.upload_parse_image_to_minio_and_postgres(parse_result, doc_entity, image_path) current_stage += 1 await ParseDocumentWorker.report(task_id, '上传解析图片', current_stage, stage_cnt) - await ParseDocumentWorker.ocr_from_parse_image(parse_result, llm) + await ParseDocumentWorker.ocr_from_parse_image(parse_result, image_path, llm, knowledge_base_entity.tokenizer) current_stage += 1 await ParseDocumentWorker.report(task_id, 'OCR图片', current_stage, stage_cnt) await ParseDocumentWorker.merge_and_split_text(parse_result, doc_entity) current_stage += 1 await ParseDocumentWorker.report(task_id, '合并和拆分文本', current_stage, stage_cnt) - await ParseDocumentWorker.push_up_words_feature(parse_result, llm) + await ParseDocumentWorker.push_up_words_feature(parse_result, llm, knowledge_base_entity.tokenizer) current_stage += 1 await ParseDocumentWorker.report(task_id, '推送上层词特征', current_stage, stage_cnt) await ParseDocumentWorker.embedding_chunk(parse_result) current_stage += 1 await ParseDocumentWorker.report(task_id, '嵌入chunk', current_stage, stage_cnt) - await ParseDocumentWorker.update_doc_abstract(doc_entity.id, parse_result, llm) + await ParseDocumentWorker.update_doc_abstract_and_full_text(doc_entity.id, parse_result, llm, knowledge_base_entity.tokenizer) current_stage += 1 - await ParseDocumentWorker.report(task_id, '更新文档摘要', current_stage, stage_cnt) + await ParseDocumentWorker.report(task_id, '更新文档摘要和全文', current_stage, stage_cnt) await ParseDocumentWorker.add_parse_result_to_db(parse_result, doc_entity) current_stage += 1 await ParseDocumentWorker.report(task_id, '添加解析结果到数据库', current_stage, stage_cnt) - await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value)) + await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.SUCCESS.value)) task_report = await ParseDocumentWorker.assemble_task_report(task_id) report_path = os.path.join(tmp_path, 'task_report.txt') with open(report_path, 'w') as f: @@ -551,7 +709,7 @@ class ParseDocumentWorker(BaseWorker): except Exception as e: err = f"[DocParseWorker] 任务失败,task_id: {task_id},错误信息: {e}" logging.exception(err) - await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value)) + await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.FAILED.value)) await ParseDocumentWorker.report(task_id, err, 0, 1) task_report = await ParseDocumentWorker.assemble_task_report(task_id) report_path = os.path.join(tmp_path, 'task_report.txt') diff --git a/data_chain/apps/router/acc_testing.py 
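The run method above is a linear pipeline that increments current_stage and calls report after every step. A tiny sketch of that reporting pattern, with placeholder stage names and a print-based callback standing in for ParseDocumentWorker.report:

```python
import asyncio


async def run_pipeline(task_id, stages, report):
    """Run named async stages in order, reporting progress after each one."""
    total = len(stages)
    for done, (name, stage) in enumerate(stages, start=1):
        await stage()
        await report(task_id, name, done, total)


async def demo():
    async def noop():
        await asyncio.sleep(0)

    async def report(task_id, name, done, total):  # stand-in for the worker's report
        print(f"[{task_id}] {name}: {done}/{total}")

    await run_pipeline("t-1", [("下载文档", noop), ("解析文档", noop), ("嵌入chunk", noop)], report)


asyncio.run(demo())
```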
b/data_chain/apps/router/acc_testing.py index e5dc7b953f0856e5bf4ad0b4c9320ca441c66f48..1381bb7a20b0a3fc3c4b84c36df955452d1a706e 100644 --- a/data_chain/apps/router/acc_testing.py +++ b/data_chain/apps/router/acc_testing.py @@ -6,6 +6,7 @@ import urllib from uuid import UUID from httpx import AsyncClient from typing import Annotated +from data_chain.entities.enum import IdType from data_chain.entities.request_data import ( ListTestingRequest, ListTestCaseRequest, @@ -20,6 +21,7 @@ from data_chain.entities.response_data import ( UpdateTestingResponse, DeleteTestingResponse ) +from data_chain.apps.service.team_service import TeamService from data_chain.apps.service.knwoledge_base_service import KnowledgeBaseService from data_chain.apps.service.dataset_service import DataSetService from data_chain.apps.service.acc_testing_service import TestingService @@ -37,6 +39,7 @@ async def list_testing_by_kb_id( if not (await KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, req.kb_id, action)): raise Exception("用户没有权限访问该知识库的测试") list_testing_msg = await TestingService.list_testing_by_kb_id(req) + await TeamService.add_team_msg(user_sub, req.kb_id, IdType.KNOWLEDGE_BASE, '查看了知识库{kbName}的测试集列表', 'knowledge base {kbName} Testing list viewed') return ListTestingResponse(result=list_testing_msg) @@ -49,6 +52,7 @@ async def list_testcase_by_testing_id( if not (await TestingService.validate_user_action_to_testing(user_sub, req.testing_id, action)): raise Exception("用户没有权限访问该测试的测试用例") testing_testcase = await TestingService.list_testcase_by_testing_id(req) + await TeamService.add_team_msg(user_sub, req.testing_id, IdType.TESTING, '知识库{kbName}的测试集{testingName}的测试用例', 'knowledge base {kbName} Testing {testingName} test case viewed') return ListTestCaseResponse(result=testing_testcase) @@ -76,6 +80,7 @@ async def download_testing_report_by_testing_id( }, media_type="application/" + extension) else: raise Exception(f"下载测试报告失败,状态码: {response.status_code}") + await TeamService.add_team_msg(user_sub, testing_id, IdType.TESTING, '下载了知识库{kbName}的测试{testingName}的测试报告', 'knowledge base {kbName} Testing {testingName} report downloaded') @router.post( @@ -87,6 +92,7 @@ async def create_testing( if not (await DataSetService.validate_user_action_to_dataset(user_sub, req.dataset_id, action)): raise Exception("用户没有权限访问该数据集的测试") task_id = await TestingService.create_testing(user_sub, req) + await TeamService.add_team_msg(user_sub, req.kb_id, IdType.KNOWLEDGE_BASE, '创建了知识库{kbName}的测试', 'knowledge base {kbName} Testing created') return CreateTestingResponsing(result=task_id) @@ -100,6 +106,7 @@ async def run_testing_by_testing_id( if not (await TestingService.validate_user_action_to_testing(user_sub, testing_id, action)): raise Exception("用户没有权限访问该测试的测试用例") task_id = await TestingService.run_testing_by_testing_id(testing_id, run) + await TeamService.add_team_msg(user_sub, testing_id, IdType.TESTING, '运行了知识库{kbName}的测试{testingName}', 'knowledge base {kbName} Testing {testingName} run') return RunTestingResponse(result=task_id) @@ -113,6 +120,7 @@ async def update_testing_by_testing_id( if not (await TestingService.validate_user_action_to_testing(user_sub, testing_id, action)): raise Exception("用户没有权限访问该测试的测试用例") testing_id = await TestingService.update_testing_by_testing_id(testing_id, req) + await TeamService.add_team_msg(user_sub, testing_id, IdType.TESTING, '更新了{kbName}的测试{testingName}', 'knowledge base {kbName} Testing {testingName} updated') return UpdateTestingResponse(result=testing_id) @@ -125,5 +133,6 
@@ async def delete_testing_by_testing_ids( for testing_id in testing_ids: if not (await TestingService.validate_user_action_to_testing(user_sub, testing_id, action)): raise Exception("用户没有权限访问该测试的测试用例") + await TeamService.add_team_msg(user_sub, testing_ids[0], IdType.TESTING, '删除了知识库{kbName}的测试{testingName}', 'knowledge base {kbName} Testing {testingName} deleted') testing_ids = await TestingService.delete_testing_by_testing_ids(testing_ids) return DeleteTestingResponse(result=testing_ids) diff --git a/data_chain/apps/router/chunk.py b/data_chain/apps/router/chunk.py index 8b0970c801f2120c01c2888937249b3f112133cb..4c7eb60e778d7157174ea72f571bbd06d115a369 100644 --- a/data_chain/apps/router/chunk.py +++ b/data_chain/apps/router/chunk.py @@ -16,7 +16,9 @@ from data_chain.entities.response_data import ( UpdateChunkResponse, UpdateChunkEnabledResponse ) +from data_chain.entities.enum import IdType from data_chain.apps.service.router_service import get_route_info +from data_chain.apps.service.team_service import TeamService from data_chain.apps.service.document_service import DocumentService from data_chain.apps.service.chunk_service import ChunkService router = APIRouter(prefix='/chunk', tags=['Chunk']) @@ -31,6 +33,7 @@ async def list_chunks_by_document_id( if not (await DocumentService.validate_user_action_to_document(user_sub, req.doc_id, action)): raise Exception("用户没有权限访问该文档的分片") list_chunk_msg = await ChunkService.list_chunks_by_document_id(req) + await TeamService.add_team_msg(user_sub, req.doc_id, IdType.DOCUMENT, '查看了知识库{kbName}的文档《{docName}》的分片列表', 'knowledge base {kbName} Document {docName} chunk list viewed') return ListChunkResponse(result=list_chunk_msg) @@ -52,6 +55,7 @@ async def update_chunk_by_id(user_sub: Annotated[str, Depends(get_user_sub)], if not (await ChunkService.validate_user_action_to_chunk(user_sub, chunk_id, action)): raise Exception("用户没有权限访问该文档的分片") chunk_id = await ChunkService.update_chunk_by_id(chunk_id, req) + await TeamService.add_team_msg(user_sub, chunk_id, IdType.DOCUMENT, '更新了知识库{kbName}的文档《{docName}》的分片', 'knowledge base {kbName} Document {docName} chunk updated') return UpdateChunkResponse(result=chunk_id) @@ -64,4 +68,6 @@ async def update_chunk_enabled_by_id(user_sub: Annotated[str, Depends(get_user_s if not (await ChunkService.validate_user_action_to_chunk(user_sub, chunk_id, action)): raise Exception("用户没有权限访问该文档的分片") chunk_ids = await ChunkService.update_chunks_enabled_by_id(chunk_ids, enabled) + for chunk_id in chunk_ids: + await TeamService.add_team_msg(user_sub, chunk_id, IdType.DOCUMENT, '更新了知识库{kbName}的文档《{docName}》的分片启用状态', 'knowledge base {kbName} Document {docName} chunk enabled status updated') return UpdateChunkEnabledResponse(result=chunk_ids) diff --git a/data_chain/apps/router/dataset.py b/data_chain/apps/router/dataset.py index a16b678a30f3c78e63e097f657de0e24a193893f..e10a35264630a7493c3a882f5e69160fbc6983d0 100644 --- a/data_chain/apps/router/dataset.py +++ b/data_chain/apps/router/dataset.py @@ -26,6 +26,8 @@ from data_chain.entities.response_data import ( DeleteDatasetResponse, DeleteDataResponse ) +from data_chain.entities.enum import IdType +from data_chain.apps.service.team_service import TeamService from data_chain.apps.service.knwoledge_base_service import KnowledgeBaseService from data_chain.apps.service.dataset_service import DataSetService from data_chain.apps.service.task_service import TaskService @@ -43,6 +45,7 @@ async def list_dataset_by_kb_id( if not (await 
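The add_team_msg calls added throughout these routers follow one pattern: store a Chinese and an English message template keyed to an entity id plus its IdType, and fill placeholders such as {kbName} later. A hypothetical sketch of how such a record could be rendered per viewer language; the TeamMsg shape and render method below are assumptions for illustration, not the project's actual TeamService API:

```python
from dataclasses import dataclass, field
from enum import Enum


class LanguageType(Enum):
    CHINESE = "zh"
    ENGLISH = "en"


@dataclass
class TeamMsg:
    # Hypothetical shape: raw templates are stored and rendered per viewer language.
    msg_zh: str
    msg_en: str
    params: dict = field(default_factory=dict)

    def render(self, language: LanguageType) -> str:
        template = self.msg_zh if language is LanguageType.CHINESE else self.msg_en
        return template.format(**self.params)


msg = TeamMsg('下载了知识库{kbName}的文档《{documentName}》',
              'knowledge base {kbName} Document <{documentName}> downloaded',
              {"kbName": "demo-kb", "documentName": "guide.pdf"})
print(msg.render(LanguageType.ENGLISH))
```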
KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, req.kb_id, action)): raise Exception("用户没有权限访问该知识库的数据集") list_dataset_msg = await DataSetService.list_dataset_by_kb_id(req) + await TeamService.add_team_msg(user_sub, req.kb_id, IdType.KNOWLEDGE_BASE, '查看了知识库{kbName}的数据集列表', 'knowledge base {kbName} Dataset list viewed') return ListDatasetResponse(result=list_dataset_msg) @@ -54,6 +57,7 @@ async def list_data_in_dataset( if not (await DataSetService.validate_user_action_to_dataset(user_sub, req.dataset_id, action)): raise Exception("用户没有权限访问该数据集的数据") list_data_in_dataset_msg = await DataSetService.list_data_in_dataset(req) + await TeamService.add_team_msg(user_sub, req.dataset_id, IdType.DATASET, '查看了知识库{kbName}的数据集{datasetName}的数据列表', 'knowledge base {kbName} Dataset {datasetName} data list viewed') return ListDataInDatasetResponse(result=list_data_in_dataset_msg) @@ -93,6 +97,7 @@ async def download_dataset_by_task_id( }, media_type="application/" + extension) else: raise Exception(f"下载数据集失败,状态码: {response.status_code}") + await TeamService.add_team_msg(user_sub, task_id, IdType.TASK, '下载了知识库{kbName}的数据集{datasetName}', 'knowledge base {kbName} Dataset {datasetName} downloaded') @router.post('', response_model=CreateDatasetResponse, dependencies=[Depends(verify_user)]) @@ -104,6 +109,7 @@ async def create_dataset( if not (await KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, req.kb_id, action)): raise Exception("用户没有权限访问该知识库的数据集") task_id = await DataSetService.create_dataset(user_sub, req) + await TeamService.add_team_msg(user_sub, req.kb_id, IdType.KNOWLEDGE_BASE, '创建了知识库{kbName}的数据集{datasetName}', 'knowledge base {kbName} Dataset {datasetName} created') return CreateDatasetResponse(result=task_id) @@ -115,6 +121,7 @@ async def import_dataset(user_sub: Annotated[str, Depends(get_user_sub)], if not (await KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, kb_id, action)): raise Exception("用户没有权限在该知识库导入数据集") dataset_import_task_ids = await DataSetService.import_dataset(user_sub, kb_id, dataset_packages) + await TeamService.add_team_msg(user_sub, kb_id, IdType.KNOWLEDGE_BASE, '导入了知识库{kbName}的数据集', 'knowledge base {kbName} Dataset imported') return ImportDatasetResponse(result=dataset_import_task_ids) @@ -127,6 +134,8 @@ async def export_dataset_by_dataset_ids( if not (await DataSetService.validate_user_action_to_dataset(user_sub, dataset_id, action)): raise Exception("用户没有权限访问该数据集的数据") dataset_export_task_ids = await DataSetService.export_dataset(dataset_ids) + for dataset_id in dataset_ids: + await TeamService.add_team_msg(user_sub, dataset_id, IdType.DATASET, '导出了知识库{kbName}的数据集{datasetName}', 'knowledge base {kbName} Dataset {datasetName} exported') return ExportDatasetResponse(result=dataset_export_task_ids) @@ -139,6 +148,7 @@ async def generate_dataset_by_id( if not (await DataSetService.validate_user_action_to_dataset(user_sub, dataset_id, action)): raise Exception("用户没有权限访问该数据集") dataset_id = await DataSetService.generate_dataset_by_id(dataset_id, generate) + await TeamService.add_team_msg(user_sub, dataset_id, IdType.DATASET, '生成了知识库{kbName}的数据集{datasetName}', 'knowledge base {kbName} Dataset {datasetName} generated') return GenerateDatasetResponse(result=dataset_id) @@ -151,6 +161,7 @@ async def update_dataset_by_dataset_id( if not (await DataSetService.validate_user_action_to_dataset(user_sub, database_id, action)): raise Exception("用户没有权限访问该数据集") database_id = await 
DataSetService.update_dataset_by_dataset_id(database_id, req) + await TeamService.add_team_msg(user_sub, database_id, IdType.DATASET, '更新了知识库{kbName}的数据集{datasetName}', 'knowledge base {kbName} Dataset {datasetName} updated') return UpdateDatasetResponse(result=database_id) @@ -163,6 +174,7 @@ async def update_data_by_dataset_id( if not (await DataSetService.validate_user_action_to_data(user_sub, data_id, action)): raise Exception("用户没有权限访问该数据集的数据") data_id = await DataSetService.update_data(data_id, req) + await TeamService.add_team_msg(user_sub, data_id, IdType.DATASET, '更新了知识库{kbName}的数据集{datasetName}的数据', 'knowledge base {kbName} Dataset {datasetName} data updated') return UpdateDataResponse() @@ -175,6 +187,8 @@ async def delete_dataset_by_dataset_ids( if not (await DataSetService.validate_user_action_to_dataset(user_sub, database_id, action)): raise Exception("用户没有权限访问该数据集") dataset_ids = await DataSetService.delete_dataset_by_dataset_ids(database_ids) + for dataset_id in dataset_ids: + await TeamService.add_team_msg(user_sub, dataset_id, IdType.DATASET, '删除了知识库{kbName}的数据集{datasetName}', 'knowledge base {kbName} Dataset {datasetName} deleted') return DeleteDatasetResponse(result=dataset_ids) @@ -186,5 +200,7 @@ async def delete_data_by_data_ids( for data_id in data_ids: if not (await DataSetService.validate_user_action_to_data(user_sub, data_id, action)): raise Exception("用户没有权限访问该数据集的数据") + for data_id in data_ids: + await TeamService.add_team_msg(user_sub, data_id, IdType.DATASET_DATA, '删除了知识库{kbName}的数据集{datasetName}的数据', 'knowledge base {kbName} Dataset {datasetName} data deleted') await DataSetService.delete_data_by_data_ids(data_ids) return DeleteDataResponse() diff --git a/data_chain/apps/router/document.py b/data_chain/apps/router/document.py index ee5921f8f9e2f9b5380ffe9c4512b86da338b1b0..e25c28a536b5b873abe3d72e3f8b8ded84826d52 100644 --- a/data_chain/apps/router/document.py +++ b/data_chain/apps/router/document.py @@ -8,6 +8,7 @@ from uuid import UUID from httpx import AsyncClient from typing import Annotated from uuid import UUID +from data_chain.entities.enum import IdType from data_chain.entities.request_data import ( ListDocumentRequest, UpdateDocumentRequest, @@ -22,14 +23,17 @@ from data_chain.entities.response_data import ( GetDocumentReportResponse, UploadDocumentResponse, ParseDocumentResponse, + ParseDocumentRealTimeResponse, UpdateDocumentResponse, DeleteDocumentResponse, GetTemporaryDocumentStatusResponse, UploadTemporaryDocumentResponse, + GetTemporaryDocumentTextResponse, DeleteTemporaryDocumentResponse ) from data_chain.apps.service.session_service import get_user_sub, verify_user from data_chain.apps.service.router_service import get_route_info +from data_chain.apps.service.team_service import TeamService from data_chain.apps.service.knwoledge_base_service import KnowledgeBaseService from data_chain.apps.service.document_service import DocumentService router = APIRouter(prefix='/doc', tags=['Document']) @@ -44,6 +48,7 @@ async def list_doc( if not (await KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, req.kb_id, action)): raise Exception("用户没有权限访问该知识库的文档") list_document_msg = await DocumentService.list_doc(req) + await TeamService.add_team_msg(user_sub, req.kb_id, IdType.KNOWLEDGE_BASE, '查看了知识库{kbName}的文档列表', 'knowledge base {kbName} Document list viewed') return ListDocumentResponse(result=list_document_msg) @@ -71,6 +76,7 @@ async def download_doc_by_id( }, media_type="application/" + extension) else: raise Exception(f"下载文档失败,状态码: 
{response.status_code}") + await TeamService.add_team_msg(user_sub, doc_id, IdType.DOCUMENT, '下载了知识库{kbName}的文档《{documentName}》', 'knowledge base {kbName} Document <{documentName}> downloaded') @router.get('/report', response_model=GetDocumentReportResponse, dependencies=[Depends(verify_user)]) @@ -81,6 +87,7 @@ async def get_doc_report( if not (await DocumentService.validate_user_action_to_document(user_sub, doc_id, action)): raise Exception("用户没有权限访问该文档") task_report = await DocumentService.get_doc_report(doc_id) + await TeamService.add_team_msg(user_sub, doc_id, IdType.DOCUMENT, '查看了知识库{kbName}的文档《{documentName}》的解析报告', 'knowledge base {kbName} Document <{documentName}> report viewed') return GetDocumentReportResponse(result=task_report) @@ -109,6 +116,7 @@ async def download_doc_report( }, media_type="application/" + extension) else: raise Exception(f"下载文档报告失败,状态码: {response.status_code}") + await TeamService.add_team_msg(user_sub, doc_id, IdType.DOCUMENT, '下载了知识库{kbName}的文档《{documentName}》的解析报告', 'knowledge base {kbName} Document <{documentName}> report downloaded') @router.post('', response_model=UploadDocumentResponse, dependencies=[Depends(verify_user)]) @@ -120,6 +128,7 @@ async def upload_docs( if not (await KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, kb_id, action)): raise Exception("用户没有权限上传文档到该知识库") doc_ids = await DocumentService.upload_docs(user_sub, kb_id, docs) + await TeamService.add_team_msg(user_sub, kb_id, IdType.KNOWLEDGE_BASE, '往{kbName}上传了文档') return UploadDocumentResponse(result=doc_ids) @@ -133,9 +142,20 @@ async def parse_docuement_by_doc_ids( if not (await DocumentService.validate_user_action_to_document(user_sub, doc_id, action)): raise Exception("用户没有权限解析该文档") doc_ids = await DocumentService.parse_docs(doc_ids, parse) + for doc_id in doc_ids: + await TeamService.add_team_msg(user_sub, doc_id, IdType.DOCUMENT, '解析了知识库{kbName}的文档《{documentName}》', 'knowledge base {kbName} Document <{documentName}> parsed') return ParseDocumentResponse(result=doc_ids) +@router.post('/metadata', response_model=ParseDocumentRealTimeResponse, dependencies=[Depends(verify_user)]) +async def parse_docuement_realtime( + user_sub: Annotated[str, Depends(get_user_sub)], + docs: list[UploadFile] = File(...) 
+): + doc_contents = await DocumentService.parse_docs_realtime(docs) + return ParseDocumentRealTimeResponse(result=doc_contents) + + @router.put('', response_model=UpdateDocumentResponse, dependencies=[Depends(verify_user)]) async def update_doc_by_doc_id( user_sub: Annotated[str, Depends(get_user_sub)], @@ -145,6 +165,7 @@ async def update_doc_by_doc_id( if not (await DocumentService.validate_user_action_to_document(user_sub, doc_id, action)): raise Exception("用户没有权限更新该文档") doc_id = await DocumentService.update_doc(doc_id, req) + await TeamService.add_team_msg(user_sub, doc_id, IdType.DOCUMENT, '更新了知识库{kbName}的文档《{documentName}》', 'knowledge base {kbName} Document <{documentName}> updated') return UpdateDocumentResponse(result=doc_id) @@ -156,6 +177,8 @@ async def delete_docs_by_ids( for doc_id in doc_ids: if not (await DocumentService.validate_user_action_to_document(user_sub, doc_id, action)): raise Exception("用户没有权限删除该文档") + for doc_id in doc_ids: + await TeamService.add_team_msg(user_sub, doc_id, IdType.DOCUMENT, '删除了{kbName}的文档{documentName}') await DocumentService.delete_docs_by_ids(doc_ids) return DeleteDocumentResponse(result=doc_ids) @@ -177,6 +200,15 @@ async def upload_temporary_docs( return UploadTemporaryDocumentResponse(result=doc_ids) +@router.get('/temporary/text', response_model=GetTemporaryDocumentTextResponse, + dependencies=[Depends(verify_user)]) +async def get_temporary_docs_text( + user_sub: Annotated[str, Depends(get_user_sub)], + id: Annotated[UUID, Query()]): + doc_text = await DocumentService.get_temporary_doc_text(user_sub, id) + return GetTemporaryDocumentTextResponse(result=doc_text) + + @router.post('/temporary/delete', response_model=DeleteTemporaryDocumentResponse, dependencies=[Depends(verify_user)]) async def delete_temporary_docs( user_sub: Annotated[str, Depends(get_user_sub)], diff --git a/data_chain/apps/router/knowledge_base.py b/data_chain/apps/router/knowledge_base.py index 65476b1801ca6a09730ee5344f497e7ea9079d71..91ab9c483b414231618abee448623b9a9d3e3817 100644 --- a/data_chain/apps/router/knowledge_base.py +++ b/data_chain/apps/router/knowledge_base.py @@ -11,7 +11,7 @@ from data_chain.entities.request_data import ( CreateKnowledgeBaseRequest, UpdateKnowledgeBaseRequest, ) - +from data_chain.entities.enum import IdType from data_chain.entities.response_data import ( ListAllKnowledgeBaseMsg, ListAllKnowledgeBaseResponse, @@ -51,6 +51,7 @@ async def list_kb_by_team_id( if not await TeamService.validate_user_action_in_team(user_sub, req.team_id, action): raise Exception("用户没有权限访问该团队的知识库") list_kb_msg = await KnowledgeBaseService.list_kb_by_team_id(req) + await TeamService.add_team_msg(user_sub, req.team_id, IdType.TEAM, '查看了知识库列表', 'knowledge base list viewed') return ListKnowledgeBaseResponse(result=list_kb_msg) @@ -63,6 +64,7 @@ async def list_doc_types_by_kb_id( if not await KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, kb_id, action): raise Exception("用户没有权限访问该知识库的文档类型") doc_types = await KnowledgeBaseService.list_doc_types_by_kb_id(kb_id) + await TeamService.add_team_msg(user_sub, kb_id, IdType.KNOWLEDGE_BASE, '查看了知识库{kbName}的文档类型', 'knowledge base {kbName} Document types viewed') return ListDocumentTypesResponse(result=doc_types) @@ -108,6 +110,7 @@ async def create_kb(user_sub: Annotated[str, Depends(get_user_sub)], if not await TeamService.validate_user_action_in_team(user_sub, team_id, action): raise Exception("用户没有权限在该团队创建知识库") kb_id = await KnowledgeBaseService.create_kb(user_sub, team_id, req) + await 
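The new POST /doc/metadata route accepts a multipart list of files and returns their realtime parse results. A hedged client-side sketch using httpx, which the routers already depend on; the base URL and bearer-token header are assumptions for illustration, since the service's real authentication goes through verify_user:

```python
import asyncio
import os
from httpx import AsyncClient


async def parse_files_realtime(base_url: str, paths: list[str], token: str) -> dict:
    """POST local files to /doc/metadata and return the JSON body (illustrative only)."""
    files = [("docs", (os.path.basename(p), open(p, "rb"))) for p in paths]
    async with AsyncClient(base_url=base_url) as client:
        resp = await client.post("/doc/metadata", files=files,
                                 headers={"Authorization": f"Bearer {token}"})
        resp.raise_for_status()
        return resp.json()


# Example (assumed host and token, not run here):
# asyncio.run(parse_files_realtime("https://rag.example.com", ["./demo.pdf"], "<token>"))
```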
TeamService.add_team_msg(user_sub, kb_id, IdType.KNOWLEDGE_BASE, '创建了知识库{kbName}', 'knowledge base {kbName} created') return CreateKnowledgeBaseResponse(result=kb_id) @@ -119,6 +122,7 @@ async def import_kbs(user_sub: Annotated[str, Depends(get_user_sub)], if not await TeamService.validate_user_action_in_team(user_sub, team_id, action): raise Exception("用户没有权限在该团队导入知识库") kb_import_task_ids = await KnowledgeBaseService.import_kbs(user_sub, team_id, kb_packages) + await TeamService.add_team_msg(user_sub, team_id, IdType.TEAM, '导入了知识库', 'knowledge base imported') return ImportKnowledgeBaseResponse(result=kb_import_task_ids) @@ -131,6 +135,8 @@ async def export_kb_by_kb_ids( if not await KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, kb_id, action): raise Exception("用户没有权限在该知识库导出知识库") kb_export_task_ids = await KnowledgeBaseService.export_kb_by_kb_ids(kb_ids) + for kb_id in kb_ids: + await TeamService.add_team_msg(user_sub, kb_id, IdType.KNOWLEDGE_BASE, '导出了知识库{kbName}', 'knowledge base {kbName} exported') return ExportKnowledgeBaseResponse(result=kb_export_task_ids) @@ -143,6 +149,7 @@ async def update_kb_by_kb_id( if not await KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, kb_id, action): raise Exception("用户没有权限在该知识库更新知识库") kb_id = await KnowledgeBaseService.update_kb_by_kb_id(kb_id, req) + await TeamService.add_team_msg(user_sub, kb_id, IdType.KNOWLEDGE_BASE, '更新了知识库{kbName}', 'knowledge base {kbName} updated') return UpdateKnowledgeBaseResponse(result=kb_id) @@ -154,5 +161,7 @@ async def delete_kb_by_kb_ids( for kb_id in kb_ids: if not await KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, kb_id, action): raise Exception("用户没有权限在该知识库删除知识库") + for kb_id in kb_ids: + await TeamService.add_team_msg(user_sub, kb_id, IdType.KNOWLEDGE_BASE, '删除了知识库{kbName}', 'knowledge base {kbName} deleted') kb_ids_deleted = await KnowledgeBaseService.delete_kb_by_kb_ids(kb_ids) return DeleteKnowledgeBaseResponse(result=kb_ids_deleted) diff --git a/data_chain/apps/router/other.py b/data_chain/apps/router/other.py index 24f9d001a069d4bc8fe4059392c7a5f78e21d1bd..2bcf71ad240ef69df96161f1d48c31a6917394b9 100644 --- a/data_chain/apps/router/other.py +++ b/data_chain/apps/router/other.py @@ -12,6 +12,7 @@ from data_chain.entities.response_data import ( ListLLMMsg, ListLLMResponse, ListEmbeddingResponse, + ListRerankResponse, ListTokenizerResponse, ListParseMethodResponse, ListSearchMethodResponse @@ -55,6 +56,10 @@ async def list_embeddings(): embeddings = [config['EMBEDDING_MODEL_NAME']] return ListEmbeddingResponse(result=embeddings) +@router.get('/rerank', response_model=ListRerankResponse, dependencies=[Depends(verify_user)]) +async def list_reranks(): + reranks = [config['RERANK_MODEL_NAME']] + return ListRerankResponse(result=reranks) @router.get('/tokenizer', response_model=ListTokenizerResponse, dependencies=[Depends(verify_user)]) async def list_tokenizers(): diff --git a/data_chain/apps/router/role.py b/data_chain/apps/router/role.py index 52711563253241e98f42cd5445a01ae86495ccca..1b3f31ced64cf2853ed1976424d5a10f44879085 100644 --- a/data_chain/apps/router/role.py +++ b/data_chain/apps/router/role.py @@ -3,6 +3,7 @@ from fastapi import APIRouter, Depends, Query, Body from typing import Annotated from uuid import UUID +from data_chain.entities.enum import LanguageType, IdType from data_chain.entities.request_data import ( ListRoleRequest, CreateRoleRequest, @@ -11,11 +12,14 @@ from data_chain.entities.request_data import ( from 
data_chain.entities.response_data import ( ListActionResponse, + GetUserRoleResponse, ListRoleResponse, CreateRoleResponse, UpdateRoleResponse, DeleteRoleResponse ) +from data_chain.apps.service.role_service import RoleService +from data_chain.apps.service.team_service import TeamService from data_chain.apps.service.session_service import get_user_sub, verify_user from data_chain.apps.service.router_service import get_route_info router = APIRouter(prefix='/role', tags=['Role']) @@ -24,35 +28,67 @@ router = APIRouter(prefix='/role', tags=['Role']) @router.get('/action', response_model=ListActionResponse, dependencies=[Depends(verify_user)]) async def list_actions( user_sub: Annotated[str, Depends(get_user_sub)], + language: Annotated[LanguageType, Query( + alias="language")] = LanguageType.CHINESE ): - return ListActionResponse() + list_action_msg = await RoleService.list_actions(language) + return ListActionResponse(message='操作列表获取成功', result=list_action_msg) + + +@router.get('', response_model=GetUserRoleResponse, dependencies=[Depends(verify_user)]) +async def get_user_role( + user_sub: Annotated[str, Depends(get_user_sub)], + team_id: Annotated[UUID, Query(alias="teamId")] +): + user_role_msg = await RoleService.get_user_role_in_team(user_sub, team_id) + return GetUserRoleResponse(message='用户角色获取成功', result=user_role_msg) @router.post('/list', response_model=ListRoleResponse, dependencies=[Depends(verify_user)]) -async def list_role_by_team_id( +async def list_roles( user_sub: Annotated[str, Depends(get_user_sub)], + action: Annotated[str, Depends(get_route_info)], req: Annotated[ListRoleRequest, Body()], ): - return ListRoleResponse() + if not (await TeamService.validate_user_action_in_team(user_sub, req.team_id, action)): + raise Exception('用户没有权限查看该团队角色') + list_role_msg = await RoleService.list_roles(req) + await TeamService.add_team_msg(user_sub, req.team_id, IdType.TEAM, '查看了角色列表') + return ListRoleResponse(message='角色列表获取成功', result=list_role_msg) @router.post('', response_model=CreateRoleResponse, dependencies=[Depends(verify_user)]) async def create_role(user_sub: Annotated[str, Depends(get_user_sub)], - team_id: Annotated[UUID, Query(alias="TeamId")], + action: Annotated[str, Depends(get_route_info)], + team_id: Annotated[UUID, Query(alias="teamId")], req: Annotated[CreateRoleRequest, Body()]): - return CreateRoleResponse() + if not (await TeamService.validate_user_action_in_team(user_sub, team_id, action)): + raise Exception('用户没有权限创建该团队角色') + role_id = await RoleService.create_role(team_id, req) + await TeamService.add_team_msg(user_sub, role_id, IdType.ROLE, '创建了{teamName}的角色{roleName}') + return CreateRoleResponse(message='角色创建成功', result=role_id) @router.put('', response_model=UpdateRoleResponse, dependencies=[Depends(verify_user)]) async def update_role_by_role_id( user_sub: Annotated[str, Depends(get_user_sub)], + action: Annotated[str, Depends(get_route_info)], role_id: Annotated[UUID, Query(alias="roleId")], req: Annotated[UpdateRoleRequest, Body()]): - return UpdateRoleResponse() + if not (await RoleService.validate_user_action_to_role(user_sub, role_id, action)): + raise Exception('用户没有权限修改该团队角色') + role_id = await RoleService.update_role(role_id, req) + await TeamService.add_team_msg(user_sub, role_id, IdType.ROLE, '更新了{teamName}的角色{roleName}') + return UpdateRoleResponse(message='角色更新成功', result=role_id) @router.delete('', response_model=DeleteRoleResponse, dependencies=[Depends(verify_user)]) async def delete_role_by_role_ids( user_sub: Annotated[str, 
Depends(get_user_sub)], - role_ids: Annotated[list[UUID], Body(alias="roleId")]): - return DeleteRoleResponse() + action: Annotated[str, Depends(get_route_info)], + role_id: Annotated[UUID, Query(alias="roleId")]): + if not (await RoleService.validate_user_action_to_role(user_sub, role_id, action)): + raise Exception('用户没有权限删除该团队角色') + role_id = await RoleService.delete_role(role_id) + await TeamService.add_team_msg(user_sub, role_id, IdType.ROLE, '删除了{teamName}的角色{roleName}') + return DeleteRoleResponse(message='角色删除成功', result=role_id) diff --git a/data_chain/apps/router/task.py b/data_chain/apps/router/task.py index f24d9595dc780539a0f355901276c2b88a73d1f4..181bde3bf3940dd6ad913c7428cfba3997ee899d 100644 --- a/data_chain/apps/router/task.py +++ b/data_chain/apps/router/task.py @@ -40,8 +40,8 @@ async def delete_task_by_task_id( ): if not (await TaskService.validate_user_action_to_task(user_sub, task_id, action)): raise Exception("用户没有权限访问该团队的任务") - task_ids = await TaskService.delete_task_by_task_id(task_id) - return DeleteTaskByIdResponse() + task_id = await TaskService.delete_task_by_task_id(task_id) + return DeleteTaskByIdResponse(message='任务删除成功', result=task_id) @router.delete('/all', response_model=DeleteTaskByTypeResponse, dependencies=[Depends(verify_user)]) @@ -54,4 +54,4 @@ async def delete_task_by_task_type( if not (await TeamService.validate_user_action_in_team(user_sub, team_id, action)): raise Exception("用户没有权限访问该团队的任务") task_ids = await TaskService.delete_task_by_type(user_sub, team_id, task_type) - return DeleteTaskByTypeResponse() + return DeleteTaskByTypeResponse(message='任务删除成功', result=task_ids) diff --git a/data_chain/apps/router/team.py b/data_chain/apps/router/team.py index 9a171d7a0237febaffd1520b9a4f61d3cc4c7586..ab9546e6e65853dd24ed9e96942ff741a3173d10 100644 --- a/data_chain/apps/router/team.py +++ b/data_chain/apps/router/team.py @@ -9,8 +9,9 @@ from data_chain.entities.request_data import ( ListTeamUserRequest, CreateTeamRequest, UpdateTeamRequest, + DetleteTeamUserRequest ) - +from data_chain.entities.enum import IdType from data_chain.entities.response_data import ( ListTeamMsg, ListTeamResponse, @@ -46,7 +47,10 @@ async def list_team_user_by_team_id( user_sub: Annotated[str, Depends(get_user_sub)], action: Annotated[str, Depends(get_route_info)], req: Annotated[ListTeamUserRequest, Body()]): - return ListTeamUserResponse() + if not (await TeamService.validate_user_action_in_team(user_sub, req.team_id, action)): + raise Exception('用户没有权限查看该团队成员') + list_team_user_msg = await TeamService.list_team_users(req) + return ListTeamUserResponse(message='团队成员列表获取成功', result=list_team_user_msg) @router.post('/msg', response_model=ListTeamMsgResponse, dependencies=[Depends(verify_user)]) @@ -54,7 +58,10 @@ async def list_team_msg_by_team_id( user_sub: Annotated[str, Depends(get_user_sub)], action: Annotated[str, Depends(get_route_info)], req: Annotated[ListTeamMsgRequest, Body()]): - return ListTeamMsgResponse() + if not (await TeamService.validate_user_action_in_team(user_sub, req.team_id, action)): + raise Exception('用户没有权限查看该团队消息') + list_team_msg = await TeamService.list_team_msg_by_team_id(req) + return ListTeamMsgResponse(message='团队消息列表获取成功', result=list_team_msg) @router.post('', response_model=CreateTeamResponse, dependencies=[Depends(verify_user)]) @@ -70,16 +77,20 @@ async def invite_team_user_by_user_sub( user_sub: Annotated[str, Depends(get_user_sub)], action: Annotated[str, Depends(get_route_info)], team_id: Annotated[UUID, Query(alias="teamId")], + 
role_id: Annotated[UUID, Query(alias="roleId")], user_sub_invite: Annotated[str, Query(alias="userSubInvite")]): - return InviteTeamUserResponse() + if not (await TeamService.validate_user_action_in_team(user_sub, team_id, action)): + raise Exception('用户没有权限邀请该团队成员') + user_sub_invite = await TeamService.invite_team_user(user_sub, team_id, role_id, user_sub_invite) + return InviteTeamUserResponse(message='团队成员邀请成功', result=user_sub_invite) @router.post('/application', response_model=JoinTeamResponse, dependencies=[Depends(verify_user)]) -async def join_team( +async def apply_to_join_team( user_sub: Annotated[str, Depends(get_user_sub)], - action: Annotated[str, Depends(get_route_info)], team_id: Annotated[UUID, Query(alias="teamId")]): - return JoinTeamResponse() + user_sub = await TeamService.apply_to_join_team(user_sub, team_id) + return JoinTeamResponse(message='团队加入申请发送成功', result=user_sub) @router.put('', response_model=UpdateTeamResponse, dependencies=[Depends(verify_user)]) @@ -90,26 +101,37 @@ async def update_team_by_team_id( req: Annotated[UpdateTeamRequest, Body()]): if not (await TeamService.validate_user_action_in_team(user_sub, team_id, action)): raise Exception('用户没有权限修改该团队') - team_id = await TeamService.update_team_by_team_id(user_sub, team_id, req) + team_id = await TeamService.update_team_by_team_id(team_id, req) + await TeamService.add_team_msg(user_sub, team_id, IdType.TEAM, '更新了团队信息', 'team info updated') return UpdateTeamResponse(message='团队更新成功', result=team_id) @router.put('/usr', response_model=UpdateTeamUserRoleResponse, dependencies=[Depends(verify_user)]) -async def update_team_by_team_id( - user_sub: Annotated[str, Depends(get_user_sub)], - action: Annotated[str, Depends(get_route_info)], - team_id: Annotated[UUID, Query(alias="teamId")], - role_id: Annotated[UUID, Query(alias="roleId")]): - return UpdateTeamUserRoleResponse() +async def update_usr_role_by_team_id_and_user_sub( + user_sub: Annotated[str, Depends(get_user_sub)], + action: Annotated[str, Depends(get_route_info)], + team_id: Annotated[UUID, Query(alias="teamId")], + target_user_sub: Annotated[str, Query(alias="targetUserSub")], + role_id: Annotated[UUID, Query(alias="roleId")], +): + if not (await TeamService.validate_user_action_in_team(user_sub, team_id, action)): + raise Exception('用户没有权限修改该团队成员角色') + target_user_sub = await TeamService.update_team_user_role_by_team_id_and_user_sub(user_sub, team_id, target_user_sub, role_id) + await TeamService.add_team_msg(user_sub, team_id, IdType.USER, '更新了成员{targetUserName}的角色', 'user {targetUserName} role updated', targetUserName=target_user_sub) + return UpdateTeamUserRoleResponse(message='团队成员角色更新成功', result=target_user_sub) @router.put('/author', response_model=UpdateTeamAuthorResponse, dependencies=[Depends(verify_user)]) async def update_team_author_by_team_id( user_sub: Annotated[str, Depends(get_user_sub)], action: Annotated[str, Depends(get_route_info)], - recriver_sub: Annotated[str, Query(alias="recriverSub")], + target_user_sub: Annotated[str, Query(alias="targetUserSub")], team_id: Annotated[UUID, Query(alias="teamId")]): - return UpdateTeamAuthorResponse() + if not (await TeamService.validate_user_action_in_team(user_sub, team_id, action)): + raise Exception('用户没有权限转让该团队') + team_id = await TeamService.update_team_author_by_team_id(user_sub, team_id, target_user_sub) + await TeamService.add_team_msg(user_sub, team_id, IdType.USER, '将团队转让给了{targetUserName}', 'team transferred to {targetUserName}', targetUserName=target_user_sub) + return 
UpdateTeamAuthorResponse(message='团队转让成功', result=team_id) @router.delete('', response_model=DeleteTeamResponse, dependencies=[Depends(verify_user)]) @@ -127,6 +149,16 @@ async def delete_team_by_team_id( async def delete_team_user_by_team_id_and_user_subs( user_sub: Annotated[str, Depends(get_user_sub)], action: Annotated[str, Depends(get_route_info)], - team_id: Annotated[UUID, Query(alias="teamId")], - user_subs: Annotated[list[str], Body(alias="userSub")]): - return DeleteTeamUserResponse() + req: Annotated[DetleteTeamUserRequest, Body()]): + flag = await TeamService.validate_user_action_in_team(user_sub, req.team_id, action) + if len(req.user_subs) == 1 and req.user_subs[0] == user_sub: + flag = True + if not flag: + raise Exception('用户没有权限删除该团队成员') + user_subs = await TeamService.delete_team_user_by_team_id_and_user_subs(req.team_id, req.user_subs) + if len(req.user_subs) == 1 and req.user_subs[0] == user_sub: + await TeamService.add_team_msg(user_sub, req.team_id, IdType.TEAM, '退出了团队', 'left the team') + else: + for target_user_sub in req.user_subs: + await TeamService.add_team_msg(user_sub, req.team_id, IdType.USER, '将成员{targetUserName}移出了团队', 'user {targetUserName} removed from team', targetUserName=target_user_sub) + return DeleteTeamUserResponse(message='团队成员删除成功', result=user_subs) diff --git a/data_chain/apps/router/user.py b/data_chain/apps/router/user.py index 96cb6023b928fed87b9b503b4988544880dfcaf7..e754aba5367e13f3d61e196f0f2029a57599da43 100644 --- a/data_chain/apps/router/user.py +++ b/data_chain/apps/router/user.py @@ -11,7 +11,7 @@ from data_chain.entities.response_data import ( ) from data_chain.apps.service.session_service import get_user_sub, verify_user from data_chain.apps.service.router_service import get_route_info - +from data_chain.apps.service.user_service import UserService router = APIRouter( prefix="/user", tags=["User"] @@ -20,7 +20,8 @@ router = APIRouter( @router.post("/list", response_model=ListUserResponse, dependencies=[Depends(verify_user)]) async def list_users( - user_sub: Annotated[str, Query(default=None, alias="userSub")], + user_sub: Annotated[str, Depends(get_user_sub)], req: Annotated[ListUserRequest, Body()] ): - return ListUserResponse() + list_user_msg = await UserService.list_users(req) + return ListUserResponse(message="用户列表获取成功", result=list_user_msg) diff --git a/data_chain/apps/router/usr_message.py b/data_chain/apps/router/usr_message.py index 5018442dc2bac5d951d46966cb75a4829b8aa452..46ad20995dce4db7810b00ae565eb48484d96731 100644 --- a/data_chain/apps/router/usr_message.py +++ b/data_chain/apps/router/usr_message.py @@ -3,12 +3,16 @@ from fastapi import APIRouter, Depends, Query, Body from typing import Annotated from uuid import UUID -from data_chain.entities.enum import UserMessageType, UserStatus +from data_chain.entities.enum import UserMessageType, UserStatus, UserMessageStatus +from data_chain.entities.request_data import ( + ListUserMessageRequest +) from data_chain.entities.response_data import ( ListUserMessageResponse, UpdateUserMessageResponse, DeleteUserMessageResponse ) +from data_chain.apps.service.user_message_service import UserMessageService from data_chain.apps.service.session_service import get_user_sub, verify_user from data_chain.apps.service.router_service import get_route_info router = APIRouter(prefix='/usr_msg', tags=['User Message']) @@ -17,21 +21,25 @@ router = APIRouter(prefix='/usr_msg', tags=['User Message']) @router.post('/list', response_model=ListUserMessageResponse, 
dependencies=[Depends(verify_user)]) async def list_user_msgs_by_user_sub( user_sub: Annotated[str, Depends(get_user_sub)], - msg_type: Annotated[UserMessageType, Query(alias="msgType")], + req: Annotated[ListUserMessageRequest, Body()] ): - return ListUserMessageResponse() + list_user_message = await UserMessageService.list_user_messages(user_sub, req) + return ListUserMessageResponse(message='用户消息列表获取成功', result=list_user_message) @router.put('', response_model=UpdateUserMessageResponse, dependencies=[Depends(verify_user)]) async def update_user_msg_by_msg_id( user_sub: Annotated[str, Depends(get_user_sub)], + action: Annotated[str, Depends(get_route_info)], msg_id: Annotated[UUID, Query(alias="msgId")], - msg_status: Annotated[UserStatus, Query(alias="msgStatus")]): - return UpdateUserMessageResponse() + msg_status: Annotated[UserMessageStatus, Query(alias="msgStatus")]): + msg_id = await UserMessageService.update_user_message(user_sub, msg_id, msg_status) + return UpdateUserMessageResponse(message='用户消息更新成功', result=msg_id) @router.delete('', response_model=DeleteUserMessageResponse, dependencies=[Depends(verify_user)]) async def delete_user_msg_by_msg_ids( user_sub: Annotated[str, Depends(get_user_sub)], - msg_ids: Annotated[list[UUID], Body(alias="msgIds")]): - return DeleteUserMessageResponse() + msg_id: Annotated[UUID, Query(alias="msgId")]): + msg_id = await UserMessageService.delete_user_messages(user_sub, msg_id) + return DeleteUserMessageResponse(message='用户消息删除成功', result=msg_id) diff --git a/data_chain/apps/service/acc_testing_service.py b/data_chain/apps/service/acc_testing_service.py index 184f25fda8d595b9a37a0379962a34cf40711313..a92149678c6ca1588b4e1008a9547d1db4a5d9f9 100644 --- a/data_chain/apps/service/acc_testing_service.py +++ b/data_chain/apps/service/acc_testing_service.py @@ -26,6 +26,7 @@ from data_chain.manager.testing_manager import TestingManager from data_chain.manager.testcase_manager import TestCaseManager from data_chain.manager.team_manager import TeamManager from data_chain.manager.role_manager import RoleManager +from data_chain.manager.team_message_manager import TeamMessageManager from data_chain.stores.minio.minio import MinIO from data_chain.entities.enum import TestingStatus, TaskType, TaskStatus from data_chain.entities.common import TESTING_REPORT_PATH_IN_MINIO @@ -56,9 +57,11 @@ class TestingService: try: total, dataset_entities = await TestingManager.list_testing_unique_datasets(req) dataset_entities.sort(key=lambda x: x.created_at, reverse=True) - dataset_ids = [dataset_entity.id for dataset_entity in dataset_entities] + dataset_ids = [ + dataset_entity.id for dataset_entity in dataset_entities] dataset_entities = await DatasetManager.list_datasets_by_dataset_ids(dataset_ids) - dataset_dict = {dataset_entity.id: dataset_entity for dataset_entity in dataset_entities} + dataset_dict = { + dataset_entity.id: dataset_entity for dataset_entity in dataset_entities} dataset_testings = [] llm = await Convertor.convert_llm_config_to_llm() testing_ids = [] @@ -67,7 +70,8 @@ class TestingService: for testing_entity in testing_entities: if testing_entity.dataset_id not in dataset_testing_dict: dataset_testing_dict[testing_entity.dataset_id] = [] - dataset_testing_dict[testing_entity.dataset_id].append(testing_entity) + dataset_testing_dict[testing_entity.dataset_id].append( + testing_entity) for dataset_id in dataset_ids: dataset_entity = dataset_dict.get(dataset_id) testing_entities = await TestingManager.list_testing_by_dataset_id(dataset_id) @@ -87,12 
+91,14 @@ class TestingService: task_report_entities = await TaskReportManager.list_current_task_report_by_task_ids( [task.id for task in task_entities] ) - task_report_dict = {task_report.task_id: task_report for task_report in task_report_entities} + task_report_dict = { + task_report.task_id: task_report for task_report in task_report_entities} for dataset_testing in dataset_testings: for testing in dataset_testing.testings: task_entity = task_dict.get(testing.testing_id, None) if task_entity: - task_report_entity = task_report_dict.get(task_entity.id, None) + task_report_entity = task_report_dict.get( + task_entity.id, None) task = await Convertor.convert_task_entity_to_task(task_entity, task_report_entity) testing.testing_task = task @@ -204,7 +210,8 @@ class TestingService: for task_entity in task_entities: await TaskQueueService.stop_task(task_entity.id) testing_entities = await TestingManager.update_testing_by_testing_ids(testing_ids, {"status": TestingStatus.DELETED.value}) - testing_ids = [testing_entity.id for testing_entity in testing_entities] + testing_ids = [ + testing_entity.id for testing_entity in testing_entities] return testing_ids except Exception as e: err = "删除测试失败" diff --git a/data_chain/apps/service/chunk_service.py b/data_chain/apps/service/chunk_service.py index d1b706461a3e191cdfdb9b9403ef0fdc7ba5cecd..a533a03cd8c9d0784ffc35ace9a26b7039b6bc82 100644 --- a/data_chain/apps/service/chunk_service.py +++ b/data_chain/apps/service/chunk_service.py @@ -81,7 +81,24 @@ class ChunkService: chunk_entities = [] for kb_id in req.kb_ids: try: - chunk_entities += await BaseSearcher.search(req.search_method.value, kb_id, req.query, req.top_k, req.doc_ids, req.banned_ids) + kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id) + if kb_entity is None: + err = f"知识库不存在,知识库ID: {kb_id}" + logging.warning("[ChunkService] %s", err) + continue + if kb_id!=DEFAULT_KNOWLEDGE_BASE_ID and not await KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, kb_id, action): + err = f"用户没有权限访问该知识库,知识库ID: {kb_id}" + logging.warning("[ChunkService] %s", err) + continue + top_k=req.top_k + if req.is_rerank: + top_k=req.top_k*3 + sub_chunk_entities= await BaseSearcher.search(req.search_method.value, kb_id, req.query, top_k, req.doc_ids, req.banned_ids) + if req.is_rerank: + sub_chunk_indexs = await BaseSearcher.rerank(sub_chunk_entities,kb_entity.rerank_model, req.query) + sub_chunk_entities = [sub_chunk_entities[i] for i in sub_chunk_indexs] + sub_chunk_entities = sub_chunk_entities[:req.top_k] + chunk_entities += sub_chunk_entities except Exception as e: err = f"[ChunkService] 搜索分片失败,error: {e}" logging.exception(err) @@ -89,14 +106,16 @@ class ChunkService: if len(chunk_entities) == 0: return SearchChunkMsg(docChunks=[]) if req.is_rerank: - chunk_entities = await BaseSearcher.rerank(chunk_entities, req.query) + chunk_indexs = await BaseSearcher.rerank(chunk_entities,None, req.query) + chunk_entities = [chunk_entities[i] for i in chunk_indexs] chunk_entities = chunk_entities[:req.top_k] chunk_ids = [chunk_entity.id for chunk_entity in chunk_entities] logging.error("[ChunkService] 搜索分片,查询结果数量: %s", len(chunk_entities)) if req.is_related_surrounding: # 关联上下文 tokens_limit = req.tokens_limit - tokens_limit_every_chunk = tokens_limit // len(chunk_entities) if len(chunk_entities) > 0 else tokens_limit + tokens_limit_every_chunk = tokens_limit // len( + chunk_entities) if len(chunk_entities) > 0 else tokens_limit leave_tokens = 0 related_chunk_entities = [] token_sum = 
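The search path above over-fetches roughly three times the requested top_k when reranking is enabled, reorders candidates by the indexes the reranker returns, then truncates back to top_k. A compact sketch of that flow, with dummy stand-ins for BaseSearcher.search and BaseSearcher.rerank:

```python
import asyncio


async def search_with_rerank(query, search, rerank, top_k, is_rerank=True):
    """Over-fetch, rerank, then truncate back to top_k (sketch, not the project API)."""
    fetch_k = top_k * 3 if is_rerank else top_k
    candidates = await search(query, fetch_k)
    if is_rerank and candidates:
        order = await rerank(candidates, query)      # indexes, best match first
        candidates = [candidates[i] for i in order]
    return candidates[:top_k]


async def demo():
    async def search(query, k):                      # stand-in for BaseSearcher.search
        return [f"chunk-{i}" for i in range(k)]

    async def rerank(chunks, query):                 # stand-in for BaseSearcher.rerank
        return list(range(len(chunks)))[::-1]        # pretend reverse order scores higher

    print(await search_with_rerank("q", search, rerank, top_k=2))


asyncio.run(demo())
```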
0 @@ -121,7 +140,6 @@ class ChunkService: if token_sum >= tokens_limit: break chunk_entities += related_chunk_entities - logging.error(len(chunk_entities)) search_chunk_msg = SearchChunkMsg(docChunks=[]) if req.is_classify_by_doc: doc_chunks = await BaseSearcher.classify_by_doc_id(chunk_entities) @@ -135,13 +153,17 @@ class ChunkService: chunk = await Convertor.convert_chunk_entity_to_chunk(chunk_entity) if req.is_compress: chunk.text = TokenTool.compress_tokens(chunk.text) - dc = DocChunk(docId=chunk_entity.doc_id, docName=chunk_entity.doc_name, chunks=[chunk]) + dc = DocChunk(docId=chunk_entity.doc_id, + docName=chunk_entity.doc_name, chunks=[chunk]) search_chunk_msg.doc_chunks.append(dc) doc_entities = await DocumentManager.list_document_by_doc_ids( [doc_chunk.doc_id for doc_chunk in search_chunk_msg.doc_chunks]) doc_map = {doc_entity.id: doc_entity for doc_entity in doc_entities} for doc_chunk in search_chunk_msg.doc_chunks: doc_entity = doc_map.get(doc_chunk.doc_id) + doc_chunk.doc_author = doc_entity.author_name if doc_entity else "" + doc_chunk.doc_created_at = doc_entity.created_time.strftime( + '%Y-%m-%d %H:%M') if doc_entity else "" doc_chunk.doc_abstract = doc_entity.abstract if doc_entity else "" doc_chunk.doc_extension = doc_entity.extension if doc_entity else "" doc_chunk.doc_size = doc_entity.size if doc_entity else 0 @@ -153,6 +175,7 @@ class ChunkService: if req.text: vector = await Embedding.vectorize_embedding(req.text) chunk_dict["text_vector"] = vector + await ChunkManager.update_chunk_text_ts_vector_by_chunk_ids([chunk_id]) chunk_entity = await ChunkManager.update_chunk_by_chunk_id(chunk_id, chunk_dict) return chunk_entity.id except Exception as e: diff --git a/data_chain/apps/service/document_service.py b/data_chain/apps/service/document_service.py index dfe65b9d1ad5a972f7da38ec682e2ff86f70534d..7ec05556bb5381615813ff6bc6829f8dc35ea9ec 100644 --- a/data_chain/apps/service/document_service.py +++ b/data_chain/apps/service/document_service.py @@ -4,7 +4,9 @@ from fastapi import APIRouter, Depends, Query, Body, File, UploadFile import uuid import traceback import shutil +from typing import Union import os +import hashlib from data_chain.entities.request_data import ( ListDocumentRequest, UploadTemporaryRequest, @@ -30,6 +32,8 @@ from data_chain.stores.minio.minio import MinIO from data_chain.entities.enum import ParseMethod, DataSetStatus, DocumentStatus, TaskType, TaskStatus from data_chain.entities.common import DOC_PATH_IN_OS, DOC_PATH_IN_MINIO, REPORT_PATH_IN_MINIO, DEFAULT_KNOWLEDGE_BASE_ID, DEFAULT_DOC_TYPE_ID from data_chain.logger.logger import logger as logging +from data_chain.parser.parse_result import ParseResult +from data_chain.parser.handler.base_parser import BaseParser class DocumentService: @@ -218,7 +222,7 @@ class DocumentService: name=file_name, extension=extension, size=os.path.getsize(document_file_path), - parse_method=ParseMethod.OCR.value, + parse_method=doc.parse_method.value, parse_relut_topology=None, chunk_size=1024, type_id=DEFAULT_DOC_TYPE_ID, @@ -255,6 +259,20 @@ class DocumentService: await KnowledgeBaseManager.update_doc_cnt_and_doc_size(kb_id=DEFAULT_KNOWLEDGE_BASE_ID) return doc_ids + @staticmethod + async def get_temporary_doc_text(user_sub: str, doc_id: uuid.UUID): + """获取临时文档解析结果文本""" + doc_entity = await DocumentManager.get_document_by_doc_id(doc_id) + if doc_entity is None: + err = f"获取临时文档失败, 文档ID: {doc_id}" + logging.error("[DocumentService] %s", err) + raise ValueError(err) + if doc_entity.author_id != user_sub: + err = 
f"用户没有权限访问临时文档, 文档ID: {doc_entity.id}, 用户ID: {user_sub}" + logging.error("[DocumentService] %s", err) + raise PermissionError(err) + return doc_entity.full_text + @staticmethod async def delete_temporary_docs(user_sub: str, doc_ids: list[uuid.UUID]) -> list[uuid.UUID]: """删除临时文档""" @@ -379,6 +397,42 @@ class DocumentService: logging.exception("[DocumentService] %s", err) raise e + @staticmethod + async def parse_docs_realtime(docs: list[UploadFile]) -> list[Union[ParseResult, None]]: + """实时解析文档""" + parse_results = [] + tmp_path = os.path.join(DOC_PATH_IN_OS, str(uuid.uuid4())) + for doc in docs: + try: + if os.path.exists(tmp_path): + shutil.rmtree(tmp_path) + os.makedirs(tmp_path) + doc_path = os.path.join(tmp_path, doc.filename) + doc_hash = None + async with aiofiles.open(doc_path, "wb") as f: + content = await doc.read() + doc_hash = await hashlib.sha256(content).hexdigest() + await f.write(content) + # 获取文件扩展名 + extension = doc.filename.split('.')[-1] + if not extension: + parse_results.append(None) + if os.path.exists(tmp_path): + shutil.rmtree(tmp_path) + continue + parse_result = await BaseParser.parser(extension, tmp_path) + parse_result.doc_hash = doc_hash[64:] + parse_results.append(parse_result) + if os.path.exists(tmp_path): + shutil.rmtree(tmp_path) + except Exception as e: + err = f"实时解析文档失败, 文档名: {doc.filename}, 错误信息: {e}" + logging.error("[DocumentService] %s", err) + parse_results.append(None) + if os.path.exists(tmp_path): + shutil.rmtree(tmp_path) + return parse_results + @staticmethod async def update_doc(doc_id: uuid.UUID, req: UpdateDocumentRequest) -> uuid.UUID: """更新文档""" diff --git a/data_chain/apps/service/knwoledge_base_service.py b/data_chain/apps/service/knwoledge_base_service.py index 17b445dc2a603bdeb287ee48e370a57da9a66a0e..25236a08603bffa2a5bee26e551ab04a8aadc099 100644 --- a/data_chain/apps/service/knwoledge_base_service.py +++ b/data_chain/apps/service/knwoledge_base_service.py @@ -78,10 +78,12 @@ class KnowledgeBaseService: team_id = knowledge_base_entity.team_id if team_id not in team_knowledge_bases_dict: team_knowledge_bases_dict[team_id] = [] - team_knowledge_bases_dict[team_id].append(knowledge_base_entity) + team_knowledge_bases_dict[team_id].append( + knowledge_base_entity) team_knowledge_bases = [] for team_entity in team_entities: - knowledge_base_entities = team_knowledge_bases_dict.get(team_entity.id, []) + knowledge_base_entities = team_knowledge_bases_dict.get( + team_entity.id, []) team_knowledge_base = TeamKnowledgebase( teamId=team_entity.id, teamName=team_entity.name, @@ -174,7 +176,7 @@ class KnowledgeBaseService: if knowledge_base_entity is None: err = "创建知识库失败" logging.exception("[KnowledgeBaseService] %s", err) - raise e + raise Exception(err) doc_types = req.doc_types doc_type_entities = [] for doc_type in doc_types: @@ -205,7 +207,8 @@ class KnowledgeBaseService: doc_size=0, upload_count_limit=kb_config.get("upload_count_limit", 128), upload_size_limit=kb_config.get("upload_size_limit", 512), - default_parse_method=kb_config.get("default_parse_method", ParseMethod.GENERAL.value), + default_parse_method=kb_config.get( + "default_parse_method", ParseMethod.GENERAL.value), default_chunk_size=kb_config.get("default_chunk_size", 1024), status=kb_config.get("status", KnowledgeBaseStatus.IDLE.value), ) @@ -295,10 +298,12 @@ class KnowledgeBaseService: @staticmethod async def update_doc_types(kb_id: uuid.UUID, doc_types: list[DocumentTypeRequest]) -> None: - new_doc_type_map = {doc_type.doc_type_id: doc_type.doc_type_name for doc_type 
in doc_types} + new_doc_type_map = { + doc_type.doc_type_id: doc_type.doc_type_name for doc_type in doc_types} new_doc_type_ids = {doc_type.doc_type_id for doc_type in doc_types} old_doc_type_entities = await KnowledgeBaseManager.list_doc_types_by_kb_id(kb_id) - old_doc_type_ids = {doc_type_entity.id for doc_type_entity in old_doc_type_entities} + old_doc_type_ids = { + doc_type_entity.id for doc_type_entity in old_doc_type_entities} delete_doc_type_ids = old_doc_type_ids - new_doc_type_ids add_doc_type_ids = new_doc_type_ids - old_doc_type_ids update_doc_type_ids = old_doc_type_ids & new_doc_type_ids @@ -329,7 +334,7 @@ class KnowledgeBaseService: if knowledge_base_entity is None: err = "更新知识库失败" logging.exception("[KnowledgeBaseService] %s", err) - raise e + raise Exception(err) await KnowledgeBaseService.update_doc_types(kb_id, req.doc_types) return knowledge_base_entity.id except Exception as e: @@ -348,9 +353,11 @@ class KnowledgeBaseService: testing_entities = await TestingManager.list_testing_by_kb_id(kb_id) doc_ids = [doc_entity.id for doc_entity in document_entities] await DocumentService.delete_docs_by_ids(doc_ids) - dataset_ids = [dataset_entity.id for dataset_entity in dataset_entities] + dataset_ids = [ + dataset_entity.id for dataset_entity in dataset_entities] await DataSetService.delete_data_by_data_ids(dataset_ids) - testing_ids = [testing_entity.id for testing_entity in testing_entities] + testing_ids = [ + testing_entity.id for testing_entity in testing_entities] await TestingService.delete_testing_by_testing_ids(testing_ids) task_entity = await TaskManager.get_current_task_by_op_id(kb_id) if task_entity is not None: diff --git a/data_chain/apps/service/llm_service.py b/data_chain/apps/service/llm_service.py deleted file mode 100644 index 9aee1781fa6751733994e7fa62dce554ab0d9cb7..0000000000000000000000000000000000000000 --- a/data_chain/apps/service/llm_service.py +++ /dev/null @@ -1,145 +0,0 @@ -from typing import List -import time -import yaml -import json -import jieba -from data_chain.models.service import ModelDTO -from data_chain.logger.logger import logger as logging -from data_chain.config.config import config -from data_chain.apps.base.model.llm import LLM -from data_chain.parser.tools.split import split_tools -from data_chain.apps.base.security.security import Security -def load_stopwords(file_path): - with open(file_path, 'r', encoding='utf-8') as f: - stopwords = set(line.strip() for line in f) - return stopwords - - -def filter_stopwords(text): - words = jieba.lcut(text) - stop_words = load_stopwords(config['STOP_WORDS_PATH']) - filtered_words = [word for word in words if word not in stop_words] - return filtered_words - - -async def question_rewrite(history: List[dict], question: str,model_dto:ModelDTO=None) -> str: - if not history: - return question - try: - st = time.time() - with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: - prompt_template_dict = yaml.load(f, Loader=yaml.SafeLoader) - prompt = prompt_template_dict['INTENT_DETECT_PROMPT_TEMPLATE'] - history_prompt = "" - q_cnt = 0 - a_cnt = 0 - history_abstract_list = [] - sum_tokens = 0 - for item in history: - history_abstract_list.append(item['content']) - sum_tokens += split_tools.get_tokens(item['content']) - used_tokens = split_tools.get_tokens(prompt + question) - maxtokens=config['MODELS'][0]['MAX_TOKENS'] - if model_dto is not None: - maxtokens=model_dto.max_tokens - # 计算 history_prompt 的长度 - if sum_tokens > maxtokens - used_tokens: - filtered_history = [] - # 使用 jieba 分词并去除停用词 - 
for item in history_abstract_list: - filtered_words = filter_stopwords(item) - filtered_history_prompt = ''.join(filtered_words) - filtered_history.append(filtered_history_prompt) - history_abstract_list = filtered_history - - character = 'user' - for item in history_abstract_list: - if character == 'user': - history_prompt += "用户历史问题" + str(q_cnt) + ':' + item + "\n" - character = 'assistant' - q_cnt += 1 - elif character == 'assistant': - history_prompt += "模型历史回答" + str(a_cnt) + ':' + item + "\n" - a_cnt += 1 - character = 'user' - if split_tools.get_tokens(history_prompt) > maxtokens - used_tokens: - splited_prompt = split_tools.split_words(history_prompt) - splited_prompt = splited_prompt[-(maxtokens - used_tokens):] - history_prompt = ''.join(splited_prompt) - prompt = prompt.format(history=history_prompt, question=question) - user_call = "请输出改写后的问题" - default_llm = LLM(model_name=config['MODELS'][0]['MODEL_NAME'], - openai_api_base=config['MODELS'][0]['OPENAI_API_BASE'], - openai_api_key=config['MODELS'][0]['OPENAI_API_KEY'], - max_tokens=config['MODELS'][0]['MAX_TOKENS'], - request_timeout=60, - temperature=0.35) - if model_dto is not None: - default_llm = LLM(model_name=model_dto.model_name, - openai_api_base=model_dto.openai_api_base, - openai_api_key=model_dto.openai_api_key, - max_tokens=model_dto.max_tokens, - request_timeout=60, - temperature=0.35) - rewrite_question = await default_llm.nostream([], prompt, user_call) - logging.info(f'改写后的问题为:{rewrite_question}') - logging.info(f'问题改写耗时:{time.time() - st}') - return rewrite_question - except Exception as e: - logging.error(f"Rewrite question failed due to: {e}") - return question - - -async def question_split(question: str) -> List[str]: - # TODO: 问题拆分 - return [question] - - -async def get_llm_answer(history, bac_info, question, is_stream=True,model_dto:ModelDTO=None): - try: - with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: - prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) - prompt = prompt_dict['LLM_PROMPT_TEMPLATE'] - prompt = prompt.format(bac_info=bac_info) - except Exception as e: - logging.error(f'Get prompt failed : {e}') - raise e - llm = LLM( - openai_api_key=config['MODELS'][0]['OPENAI_API_KEY'], - openai_api_base=config['MODELS'][0]['OPENAI_API_BASE'], - model_name=config['MODELS'][0]['MODEL_NAME'], - max_tokens=config['MODELS'][0]['MAX_TOKENS']) - if model_dto is not None: - llm = LLM(model_name=model_dto.model_name, - openai_api_base=model_dto.openai_api_base, - openai_api_key=model_dto.openai_api_key, - max_tokens=model_dto.max_tokens - ) - if is_stream: - return llm.stream(history, prompt, question) - res = await llm.nostream(history, prompt, question) - return res - - -async def get_question_chunk_relation(question, chunk,model_dto:ModelDTO=None): - with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: - prompt_template_dict = yaml.load(f, Loader=yaml.SafeLoader) - - prompt = prompt_template_dict['DETERMINE_ANSWER_AND_QUESTION'] - prompt = prompt.format(chunk=chunk, question=question) - user_call = "判断,并输出关联性编号" - default_llm = LLM(model_name=config['MODELS'][0]['MODEL_NAME'], - openai_api_base=config['MODELS'][0]['OPENAI_API_BASE'], - openai_api_key=config['MODELS'][0]['OPENAI_API_KEY'], - max_tokens=config['MODELS'][0]['MAX_TOKENS'], - request_timeout=60, - temperature=0.35) - if model_dto is not None: - default_llm = LLM(model_name=model_dto.model_name, - openai_api_base=model_dto.openai_api_base, - openai_api_key=model_dto.openai_api_key, - max_tokens=model_dto.max_tokens, - 
request_timeout=60, - temperature=0.35) - ans = await default_llm.nostream([], prompt, user_call) - return ans diff --git a/data_chain/apps/service/role_service.py b/data_chain/apps/service/role_service.py new file mode 100644 index 0000000000000000000000000000000000000000..4441c37ef9892cd8f7272aac91a477d3468788fc --- /dev/null +++ b/data_chain/apps/service/role_service.py @@ -0,0 +1,214 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +import copy +import uuid +from typing import Union +from data_chain.logger.logger import logger as logging +from data_chain.entities.request_data import ( + ListRoleRequest, + CreateRoleRequest, + UpdateRoleRequest +) +from data_chain.entities.response_data import ( + Action, + TypeAction, + ListActionMsg, + GetUserRoleMsg, + Role, + ListRoleMsg +) +from data_chain.entities.enum import ( + ActionType, + RoleActionStatus, + DeafaultRole, + LanguageType +) +from data_chain.entities.common import actions +from data_chain.stores.database.database import RoleEntity, RoleActionEntity +from data_chain.apps.base.convertor import Convertor +from data_chain.manager.team_manager import TeamManager +from data_chain.manager.role_manager import RoleManager + + +class RoleService: + """团队服务""" + @staticmethod + async def validate_user_action_to_role( + user_sub: str, role_id: uuid.UUID, action: str) -> bool: + """验证用户对角色的操作权限""" + try: + role_entity = await RoleManager.get_role_by_id(role_id) + if not role_entity: + raise Exception('角色不存在') + + action_entity = await RoleManager.get_action_by_team_id_user_sub_and_action( + user_sub, role_entity.team_id, action) + if action_entity is None: + return False + return True + except Exception as e: + err = "验证用户对角色的操作权限失败" + logging.exception("[RoleService] %s", err) + raise e + + @staticmethod + async def get_all_actions() -> list[str]: + """获取所有操作列表""" + tmp_actions = [] + for action in actions: + tmp_actions.append(action['action']) + return tmp_actions + + @staticmethod + async def get_type_actions(language: LanguageType) -> list[TypeAction]: + """获取所有操作列表""" + action_dict = {} + for action in actions: + if action['type'] not in action_dict: + action_dict[action['type']] = [] + action_dict[action['type']].append( + Action(action=action['action'], + actionDescription=action['name'][language]) + ) + action_strings = [member.value for member in ActionType] + type_actions = [] + for action_string in action_strings: + if action_string in action_dict: + type_actions.append( + TypeAction(actionType=ActionType(action_string), + actions=action_dict[action_string]) + ) + return type_actions + + @staticmethod + async def list_actions(language: LanguageType) -> ListActionMsg: + """获取所有操作列表""" + try: + type_actions = await RoleService.get_type_actions(language) + return ListActionMsg(TypeActions=type_actions) + except Exception as e: + err = "获取所有操作列表失败" + logging.exception("[RoleService] %s", err) + raise e + + @staticmethod + async def get_user_role_in_team(user_sub: str, team_id: uuid.UUID) -> GetUserRoleMsg: + """获取用户在团队中的角色""" + try: + team_user_entity = await TeamManager.get_team_user_by_user_sub_and_team_id( + user_sub, team_id) + if not team_user_entity: + raise Exception('用户不在该团队中') + user_role_entity = await RoleManager.get_user_role_by_user_sub_and_team_id( + user_sub, team_id) + role_entity = await RoleManager.get_role_by_id(user_role_entity.role_id) + return GetUserRoleMsg( + roleId=role_entity.id, + roleName=role_entity.name, + isOwner=role_entity.name == DeafaultRole.CREATOR.value + ) + 
except Exception as e: + err = "获取用户在团队中的角色失败" + logging.exception("[RoleService] %s", err) + raise e + + @staticmethod + async def list_roles(req: ListRoleRequest) -> ListRoleMsg: + """根据团队标识获取角色列表""" + try: + team_entity = await TeamManager.get_team_by_id(req.team_id) + if not team_entity: + raise Exception('团队不存在') + total, role_entities = await RoleManager.list_roles(req) + roles = [] + type_actions = await RoleService.get_type_actions(req.language) + role_action_entities = await RoleManager.list_role_actions_by_role_ids( + [role_entity.id for role_entity in role_entities]) + role_action_dict = {} + for role_action_entity in role_action_entities: + if role_action_entity.role_id not in role_action_dict: + role_action_dict[role_action_entity.role_id] = set() + role_action_dict[role_action_entity.role_id].add( + role_action_entity.action) + for role_entity in role_entities: + if role_entity.name == DeafaultRole.CREATOR.value: + continue + role = await Convertor.convert_role_entity_to_role(role_entity) + if req.is_editable: + type_actions_cp = copy.deepcopy(type_actions) + for type_action in type_actions_cp: + for action in type_action.actions: + if (role_entity.id in role_action_dict and + action.action in role_action_dict[role_entity.id]): + action.is_used = True + role.type_actions = type_actions_cp + roles.append(role) + return ListRoleMsg(total=total, roles=roles) + except Exception as e: + err = "根据团队标识获取角色列表失败" + logging.exception("[RoleService] %s", err) + raise e + + @staticmethod + async def create_role(team_id: uuid.UUID, req: CreateRoleRequest) -> uuid.UUID: + """创建角色""" + try: + existing_role_entity = await RoleManager.get_role_by_role_name_and_team_id( + req.role_name, team_id) + if existing_role_entity: + req.role_name = f"{req.role_name}_{str(uuid.uuid4())[:16]}" + role_entity = RoleEntity( + team_id=team_id, + name=req.role_name, + ) + role_entity = await RoleManager.add_role(role_entity) + role_id = role_entity.id + role_action_entities = [] + actions = await RoleService.get_all_actions() + actions_set = set(actions) + for action in req.actions: + if action not in actions_set: + continue + role_action_entities.append( + RoleActionEntity(role_id=role_id, action=action) + ) + await RoleManager.add_role_actions(role_action_entities) + return role_id + except Exception as e: + err = "创建角色失败" + logging.exception("[RoleService] %s", err) + raise e + + async def update_role(role_id: uuid.UUID, req: UpdateRoleRequest) -> uuid.UUID: + role_entity = await RoleManager.get_role_by_id(role_id) + logging.error(req) + if req.role_name is not None: + existing_role_entity = await RoleManager.get_role_by_role_name_and_team_id( + req.role_name, role_entity.team_id) + if existing_role_entity and existing_role_entity.id != role_id: + raise Exception('角色名称已存在') + await RoleManager.update_role_by_id(role_id, {'name': req.role_name}) + if req.actions is not None: + await RoleManager.update_role_actions_by_role_id( + role_id, {'status': RoleActionStatus.DELETED.value}) + role_action_entities = [] + actions = await RoleService.get_all_actions() + actions_set = set(actions) + for action in req.actions: + if action not in actions_set: + continue + role_action_entities.append( + RoleActionEntity(role_id=role_id, action=action) + ) + await RoleManager.add_role_actions(role_action_entities) + return role_id + + async def delete_role(role_id: uuid.UUID) -> uuid.UUID: + try: + await RoleManager.update_role_by_id(role_id, {'status': RoleActionStatus.DELETED.value}) + await 
RoleManager.update_role_actions_by_role_id( + role_id, {'status': RoleActionStatus.DELETED.value}) + return role_id + except Exception as e: + err = "删除角色失败" + logging.exception("[RoleService] %s", err) + raise e diff --git a/data_chain/apps/service/session_service.py b/data_chain/apps/service/session_service.py index 1f62060aaebb9cb457533c2996bfc0450df87788..dec209e77eecd92fd31e879f0f426390cbd195f3 100644 --- a/data_chain/apps/service/session_service.py +++ b/data_chain/apps/service/session_service.py @@ -22,8 +22,11 @@ class UserHTTPException(HTTPException): async def verify_user(request: HTTPConnection): """验证用户是否在Session中""" + import os if config["DEBUG"]: - return + user_sub = config["DEBUG_USER"] + return user_sub + try: session_id = None auth_header = request.headers.get("Authorization") @@ -45,8 +48,9 @@ async def verify_user(request: HTTPConnection): async def get_user_sub(request: HTTPConnection) -> uuid: """从Session中获取用户""" if config["DEBUG"]: - await UserManager.add_user((await Convertor.convert_user_sub_to_user_entity('admin'))) - return "admin" + user_sub = config["DEBUG_USER"] + await UserManager.add_user((await Convertor.convert_user_sub_to_user_entity(user_sub))) + return user_sub else: try: session_id = None diff --git a/data_chain/apps/service/task_queue_service.py b/data_chain/apps/service/task_queue_service.py index 2a16ab14b6083780dd204089f1e7dc12960de750..1a3179b396f19e134f7cafc73bb38ad35b75b6fd 100644 --- a/data_chain/apps/service/task_queue_service.py +++ b/data_chain/apps/service/task_queue_service.py @@ -4,7 +4,8 @@ import uuid from typing import Optional from data_chain.entities.enum import TaskType, TaskStatus from data_chain.apps.base.task.worker.base_worker import BaseWorker -from data_chain.stores.mongodb.mongodb import MongoDB, Task +# from data_chain.stores.mongodb.mongodb import MongoDB, Task +from data_chain.stores.database.database import TaskQueueEntity from data_chain.manager.task_manager import TaskManager from data_chain.manager.task_queue_mamanger import TaskQueueManager from data_chain.logger.logger import logger as logging @@ -22,7 +23,7 @@ class TaskQueueService: if task_entity.status == TaskStatus.RUNNING.value: flag = await BaseWorker.reinit(task_entity.id) if flag: - task = Task(_id=task_entity.id, status=TaskStatus.PENDING.value) + task = TaskQueueEntity(id=task_entity.id, status=TaskStatus.PENDING.value) await TaskQueueManager.update_task_by_id(task_entity.id, task) else: await BaseWorker.stop(task_entity.id) @@ -30,11 +31,11 @@ class TaskQueueService: else: task = await TaskQueueManager.get_task_by_id(task_entity.id) if task is None: - task = Task(_id=task_entity.id, status=TaskStatus.PENDING.value) + task = TaskQueueEntity(id=task_entity.id, status=TaskStatus.PENDING.value) await TaskQueueManager.add_task(task) except Exception as e: - warining = f"[TaskQueueService] 初始化任务失败 {e}" - logging.warning(warining) + warning = f"[TaskQueueService] 初始化任务失败 {e}" + logging.warning(warning) @staticmethod async def init_task(task_type: str, op_id: uuid.UUID) -> uuid.UUID: @@ -42,7 +43,7 @@ class TaskQueueService: try: task_id = await BaseWorker.init(task_type, op_id) if task_id: - await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.PENDING.value)) + await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.PENDING.value)) return task_id except Exception as e: err = f"[TaskQueueService] 初始化任务失败 {e}" @@ -75,53 +76,53 @@ class TaskQueueService: async def handle_successed_tasks(): handle_successed_task_limit = 1024 for 
i in range(handle_successed_task_limit): - task = await TaskQueueManager.get_oldest_tasks_by_status(TaskStatus.SUCCESS.value) + task = await TaskQueueManager.get_oldest_tasks_by_status(TaskStatus.SUCCESS) if task is None: break try: - await BaseWorker.deinit(task.task_id) + await BaseWorker.deinit(task.id) except Exception as e: err = f"[TaskQueueService] 处理成功任务失败 {e}" logging.error(err) - await TaskQueueManager.delete_task_by_id(task.task_id) + await TaskQueueManager.delete_task_by_id(task.id) @staticmethod async def handle_failed_tasks(): handle_failed_task_limit = 1024 for i in range(handle_failed_task_limit): - task = await TaskQueueManager.get_oldest_tasks_by_status(TaskStatus.FAILED.value) + task = await TaskQueueManager.get_oldest_tasks_by_status(TaskStatus.FAILED) if task is None: break try: - flag = await BaseWorker.reinit(task.task_id) + flag = await BaseWorker.reinit(task.id) except Exception as e: err = f"[TaskQueueService] 处理失败任务失败 {e}" logging.error(err) - await TaskQueueManager.delete_task_by_id(task.task_id) + await TaskQueueManager.delete_task_by_id(task.id) continue if flag: - task = Task(_id=task.task_id, status=TaskStatus.PENDING.value) - await TaskQueueManager.update_task_by_id(task.task_id, task) + task.status = TaskStatus.PENDING.value + await TaskQueueManager.update_task_by_id(task.id, task) else: - await TaskQueueManager.delete_task_by_id(task.task_id) + await TaskQueueManager.delete_task_by_id(task.id) @staticmethod async def handle_pending_tasks(): handle_pending_task_limit = 128 for i in range(handle_pending_task_limit): - task = await TaskQueueManager.get_oldest_tasks_by_status(TaskStatus.PENDING.value) + task = await TaskQueueManager.get_oldest_tasks_by_status(TaskStatus.PENDING) if task is None: break try: - flag = await BaseWorker.run(task.task_id) + flag = await BaseWorker.run(task.id) except Exception as e: err = f"[TaskQueueService] 处理待处理任务失败 {e}" logging.error(err) - await TaskQueueManager.delete_task_by_id(task.task_id) + await TaskQueueManager.delete_task_by_id(task.id) continue if not flag: break - await TaskQueueManager.delete_task_by_id(task.task_id) + await TaskQueueManager.delete_task_by_id(task.id) @staticmethod async def handle_tasks(): diff --git a/data_chain/apps/service/team_service.py b/data_chain/apps/service/team_service.py index dd36ebaaa53c2345ab267e4f114d1916c6c145da..e540f051fd9056fc350084ba1cae2783f4c9ba84 100644 --- a/data_chain/apps/service/team_service.py +++ b/data_chain/apps/service/team_service.py @@ -1,15 +1,45 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
import uuid +from typing import Union, Any from data_chain.logger.logger import logger as logging -from data_chain.entities.request_data import ListTeamRequest, CreateTeamRequest -from data_chain.entities.response_data import ListTeamMsg -from data_chain.entities.enum import TeamType, TeamStatus +from data_chain.entities.request_data import ( + ListTeamUserRequest, + ListTeamMsgRequest, + ListTeamRequest, + CreateTeamRequest +) +from data_chain.entities.response_data import ( + TeamUser, + ListTeamMsg, + ListTeamMsgMsg, + ListTeamUserMsg +) +from data_chain.entities.enum import ( + TeamType, + TeamStatus, + TeamUserStaus, + UserRoleStatus, + UserMessageStatus, + UserMessageType, + DeafaultRole, + IdType +) +from data_chain.entities.enum import IdType from data_chain.entities.common import default_roles from data_chain.stores.database.database import TeamEntity from data_chain.apps.base.convertor import Convertor from data_chain.manager.team_manager import TeamManager +from data_chain.manager.team_message_manager import TeamMessageManager +from data_chain.manager.user_message_manager import UserMessageManager from data_chain.manager.role_manager import RoleManager from data_chain.manager.knowledge_manager import KnowledgeBaseManager +from data_chain.manager.document_manager import DocumentManager +from data_chain.manager.role_manager import RoleManager +from data_chain.manager.chunk_manager import ChunkManager +from data_chain.manager.dataset_manager import DatasetManager +from data_chain.manager.testing_manager import TestingManager +from data_chain.manager.testcase_manager import TestCaseManager +from data_chain.manager.task_manager import TaskManager from data_chain.apps.service.knwoledge_base_service import KnowledgeBaseService @@ -50,6 +80,214 @@ class TeamService: teams.append(team) return ListTeamMsg(total=total, teams=teams) + @staticmethod + async def list_team_users(req: ListTeamUserRequest) -> ListTeamUserMsg: + """列出团队成员""" + total, user_entities = await TeamManager.list_team_user_by_team_id(req) + user_subs = [ + user_entity.id for user_entity in user_entities] + user_role_entities = await RoleManager.list_user_roles_by_team_id_and_user_subs( + req.team_id, user_subs) + role_ids = list( + set([user_role_entity.role_id for user_role_entity in user_role_entities])) + role_entities = await RoleManager.list_roles_by_role_ids(role_ids) + user_role_dict = { + user_role_entity.user_id: user_role_entity for user_role_entity in user_role_entities} + role_dict = { + role_entity.id: role_entity for role_entity in role_entities} + team_users = [] + # 先加入创建者 + for user_entity in user_entities: + user_role = user_role_dict.get(user_entity.id) + role = role_dict.get(user_role.role_id) if user_role else None + logging.error(role.name) + if role and role.name == DeafaultRole.CREATOR.value: + team_user = await Convertor.convert_user_entity_and_role_entity_to_team_user(user_entity, role) + team_user.is_editable = False + team_users.append(team_user) + break + # 再加入管理员 + admin_team_users = [] + for user_entity in user_entities: + user_role = user_role_dict.get(user_entity.id) + role = role_dict.get(user_role.role_id) if user_role else None + if role and role.name == DeafaultRole.ADMINISTRATOR.value: + team_user = await Convertor.convert_user_entity_and_role_entity_to_team_user(user_entity, role) + admin_team_users.append((team_user, user_role.created_time)) + admin_team_users.sort(key=lambda x: x[1], reverse=False) + team_users.extend([item[0] for item in admin_team_users]) + # 最后加入普通成员 + 
normal_team_users = [] + for user_entity in user_entities: + user_role = user_role_dict.get(user_entity.id) + role = role_dict.get(user_role.role_id) if user_role else None + if role and role.name != DeafaultRole.CREATOR.value and role.name != DeafaultRole.ADMINISTRATOR.value: + team_user = await Convertor.convert_user_entity_and_role_entity_to_team_user(user_entity, role) + normal_team_users.append((team_user, user_role.created_time)) + normal_team_users.sort(key=lambda x: x[1], reverse=False) + team_users.extend([item[0] for item in normal_team_users]) + return ListTeamUserMsg(total=total, teamUsers=team_users) + + @staticmethod + async def add_team_msg(user_sub: str, id: uuid.UUID, id_type: IdType, zh_message: str, en_message: str, ** kwargs: dict[str, Any]) -> uuid.UUID: + """添加团队消息""" + try: + if id_type == IdType.TEAM: + team_entity = await TeamManager.get_team_by_id(id) + if team_entity is None: + err = f"团队不存在,团队ID: {id}" + logging.warning("[TeamService] %s", err) + return None + team_id = team_entity.id + elif id_type == IdType.ROLE: + role_entity = await RoleManager.get_role_by_id(id) + if role_entity is None: + err = f"角色不存在,角色ID: {id}" + logging.warning("[TeamService] %s", err) + return None + team_entity = await TeamManager.get_team_by_id(role_entity.team_id) + team_id = team_entity.id + zh_message = zh_message.format( + teamName=team_entity.name, roleName=role_entity.name) + en_message = en_message.format( + teamName=team_entity.name, roleName=role_entity.name) + elif id_type == IdType.USER: + team_entity = await TeamManager.get_team_by_id(id) + if team_entity is None: + err = f"团队不存在,团队ID: {id}" + logging.warning("[TeamService] %s", err) + return None + team_id = team_entity.id + zh_message = zh_message.format( + targetUserName=kwargs.get('targetUserName', '')) + en_message = en_message.format( + targetUserName=kwargs.get('targetUserName', '')) + elif id_type == IdType.MSG: + team_msg_entity = await UserMessageManager.get_user_message_by_msg_id(id) + if team_msg_entity is None: + err = f"消息不存在,消息ID: {id}" + logging.warning("[TeamService] %s", err) + return None + team_entity = await TeamManager.get_team_by_id(team_msg_entity.team_id) + team_id = team_entity.id + zh_message = zh_message.format( + targetUserName=team_msg_entity.receiver_id) + en_message = en_message.format( + targetUserName=team_msg_entity.receiver_id) + elif id_type == IdType.KNOWLEDGE_BASE: + kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(id) + if kb_entity is None: + err = f"知识库不存在,知识库ID: {id}" + logging.warning("[TeamService] %s", err) + return None + team_id = kb_entity.team_id + zh_message = zh_message.format(kbName=kb_entity.name) + en_message = en_message.format(kbName=kb_entity.name) + elif id_type == IdType.DOCUMENT: + doc_entity = await DocumentManager.get_document_by_doc_id(id) + if doc_entity is None: + err = f"文档不存在,文档ID: {id}" + logging.warning("[TeamService] %s", err) + return None + kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(doc_entity.kb_id) + team_id = kb_entity.team_id + zh_message = zh_message.format( + kbName=kb_entity.name, docName=doc_entity.name) + en_message = en_message.format( + kbName=kb_entity.name, docName=doc_entity.name) + elif id_type == IdType.CHUNK: + chunk_entity = await ChunkManager.get_chunk_by_chunk_id(id) + if chunk_entity is None: + err = f"分片不存在,分片ID: {id}" + logging.warning("[TeamService] %s", err) + return None + doc_entity = await DocumentManager.get_document_by_doc_id(chunk_entity.doc_id) + if doc_entity is None: + err = 
f"文档不存在,文档ID: {chunk_entity.doc_id}" + logging.warning("[TeamService] %s", err) + return None + kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(doc_entity.kb_id) + if kb_entity is None: + err = f"知识库不存在,知识库ID: {doc_entity.kb_id}" + logging.warning("[TeamService] %s", err) + return None + team_id = kb_entity.team_id + zh_message = zh_message.format( + kbName=kb_entity.name, docName=doc_entity.name) + en_message = en_message.format( + kbName=kb_entity.name, docName=doc_entity.name) + elif id_type == IdType.DATASET: + dataset_entity = await DatasetManager.get_dataset_by_dataset_id(id) + if dataset_entity is None: + err = f"数据集不存在,数据集ID: {id}" + logging.warning("[TeamService] %s", err) + return None + kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(dataset_entity.kb_id) + team_id = kb_entity.team_id + zh_message = zh_message.format( + kbName=kb_entity.name, datasetName=dataset_entity.name) + en_message = en_message.format( + kbName=kb_entity.name, datasetName=dataset_entity.name) + elif id_type == IdType.DATASET_DATA: + dataset_entity = await DatasetManager.get_dataset_by_data_id(id) + if dataset_entity is None: + err = f"数据集不存在,数据集ID: {id}" + logging.warning("[TeamService] %s", err) + return None + kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(dataset_entity.kb_id) + team_id = kb_entity.team_id + zh_message = zh_message.format( + kbName=kb_entity.name, datasetName=dataset_entity.name) + en_message = en_message.format( + kbName=kb_entity.name, datasetName=dataset_entity.name) + elif id_type == IdType.TESTING: + testing_entity = await TestingManager.get_testing_by_testing_id(id) + if testing_entity is None: + err = f"测试不存在,测试ID: {id}" + logging.warning("[TeamService] %s", err) + return None + kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(testing_entity.kb_id) + team_id = kb_entity.team_id + zh_message = zh_message.format( + kbName=kb_entity.name, testingName=testing_entity.name) + en_message = en_message.format( + kbName=kb_entity.name, testingName=testing_entity.name) + elif id_type == IdType.TEST_CASE: + testcase_entity = await TestCaseManager.get_test_case_by_id(id) + if testcase_entity is None: + err = f"测试用例不存在,测试用例ID: {id}" + logging.warning("[TeamService] %s", err) + return None + testing_entity = await TestingManager.get_testing_by_testing_id(testcase_entity.testing_id) + if testing_entity is None: + err = f"测试不存在,测试ID: {testcase_entity.testing_id}" + logging.warning("[TeamService] %s", err) + return None + kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(testing_entity.kb_id) + team_id = kb_entity.team_id + zh_message = zh_message.format( + kbName=kb_entity.name, testingName=testing_entity.name) + en_message = en_message.format( + kbName=kb_entity.name, testingName=testing_entity.name) + team_msg_entity = await Convertor.convert_user_sub_team_id_and_message_to_team_message_entity( + user_sub, team_id, zh_message, en_message) + team_msg_entity = await TeamMessageManager.add_team_msg(team_msg_entity) + return team_msg_entity.id + except Exception as e: + err = "添加团队消息失败" + logging.exception("[TeamService] %s", err) + raise e + + @staticmethod + async def list_team_msg_by_team_id(req: ListTeamMsgRequest) -> ListTeamMsgMsg: + """列出团队消息""" + total, team_msg_entities = await TeamMessageManager.list_team_msg_by_team_id(req) + team_msgs = [] + for team_msg_entity in team_msg_entities: + team_msg = await Convertor.convert_team_message_entity_to_team_message(team_msg_entity) + 
team_msgs.append(team_msg) + return ListTeamMsgMsg(total=total, teamMsgs=team_msgs) + @staticmethod async def create_team(user_sub: str, req: CreateTeamRequest) -> uuid.UUID: """创建团队""" @@ -58,17 +296,17 @@ class TeamService: team_entity = await TeamManager.add_team(team_entity) team_user_entity = await Convertor.convert_user_sub_and_team_id_to_team_user_entity(user_sub, team_entity.id) await TeamManager.add_team_user(team_user_entity) - become_creator_flag = False + creator_role_id = '' for role_dict in default_roles: role_entity = await Convertor.convert_default_role_dict_to_role_entity(team_entity.id, role_dict) role_entity = await RoleManager.add_role(role_entity) - if not become_creator_flag: - user_role_entity = await Convertor.convert_user_sub_role_id_and_team_id_to_user_role_entity( - user_sub, role_entity.id, team_entity.id) - await RoleManager.add_user_role(user_role_entity) - become_creator_flag = True + if role_entity.name == DeafaultRole.CREATOR.value: + creator_role_id = role_entity.id role_action_entities = await Convertor.convert_default_role_action_dicts_to_role_action_entities(role_entity.id, role_dict['actions']) await RoleManager.add_role_actions(role_action_entities) + user_role_entity = await Convertor.convert_user_sub_role_id_and_team_id_to_user_role_entity( + user_sub, creator_role_id, team_entity.id) + await RoleManager.add_user_role(user_role_entity) return team_entity.id except Exception as e: err = "创建团队失败" @@ -76,8 +314,87 @@ class TeamService: raise e @staticmethod - async def update_team_by_team_id( - user_sub: str, team_id: uuid.UUID, req: CreateTeamRequest) -> bool: + async def invite_team_user(user_sub: str, team_id: uuid.UUID, role_id: uuid.UUID, user_sub_invite: str) -> Union[None, uuid.UUID]: + """增加一条用户邀请的信息""" + team_entity = await TeamManager.get_team_by_id(team_id) + if team_entity is None or not team_entity.is_public: + err = "邀请团队成员失败, 团队不存在或不是公开团队" + logging.exception("[TeamService] %s", err) + raise "邀请团队成员失败, 团队不存在或不是公开团队" + team_user_entity = await TeamManager.get_team_user_by_user_sub_and_team_id(user_sub_invite, team_id) + if team_user_entity: + return None + try: + user_message_entity = await Convertor.convert_user_sub_team_id_role_id_and_receiver_sub_to_user_message_entity( + user_sub, team_id, team_entity.name, role_id, user_sub_invite, False, "", UserMessageType.INVITATION.value) + user_message_entity = await UserMessageManager.add_user_message(user_message_entity) + if not user_message_entity: + err = "邀请团队成员失败" + logging.exception("[TeamService] %s", err) + raise "邀请团队成员失败" + return user_sub_invite + except Exception as e: + err = "邀请团队成员失败" + logging.exception("[TeamService] %s", err) + raise e + + @staticmethod + async def apply_to_join_team(user_sub: str, team_id: uuid.UUID) -> str: + """用户申请加入团队""" + team_entity = await TeamManager.get_team_by_id(team_id) + if team_entity is None or not team_entity.is_public: + err = "用户申请加入团队失败, 团队不存在或不是公开团队" + logging.exception("[TeamService] %s", err) + raise "用户申请加入团队失败, 团队不存在或不是公开团队" + try: + member_role_entity = await RoleManager.get_role_by_role_name_and_team_id( + DeafaultRole.MEMBER.value, team_id) + if member_role_entity is None: + err = "用户申请加入团队失败, 角色不存在" + logging.exception("[TeamService] %s", err) + raise "用户申请加入团队失败, 角色不存在" + user_message_entity = await Convertor.convert_user_sub_team_id_role_id_and_receiver_sub_to_user_message_entity( + user_sub, team_id, team_entity.name, member_role_entity.id, "", True, "", UserMessageType.APPLICATION.value) + user_message_entity = await 
UserMessageManager.add_user_message(user_message_entity) + if not user_message_entity: + err = "用户申请加入团队失败" + logging.exception("[TeamService] %s", err) + raise "用户申请加入团队失败" + return user_sub + except Exception as e: + err = "用户申请加入团队失败" + logging.exception("[TeamService] %s", err) + raise e + + @staticmethod + async def add_team_user(team_id: uuid.UUID, role_id: uuid.UUID, user_sub_invite: str) -> Union[None, uuid.UUID]: + # 判断用户是否已经是团队成员 + team_user_entity = await TeamManager.get_team_user_by_user_sub_and_team_id(user_sub_invite, team_id) + if team_user_entity: + return None + try: + team_user_entity = await Convertor.convert_user_sub_and_team_id_to_team_user_entity(user_sub_invite, team_id) + team_user_entity = await TeamManager.add_team_user(team_user_entity) + role_entity = await RoleManager.get_role_by_id(role_id) + if role_entity is None or role_entity.team_id != team_id or role_entity.is_unique: + member_role_entity = await RoleManager.get_role_by_role_name_and_team_id( + DeafaultRole.MEMBER.value, team_id) + if member_role_entity is None: + err = "邀请团队成员失败, 角色不存在" + logging.exception("[TeamService] %s", err) + raise "邀请团队成员失败, 角色不存在" + role_id = member_role_entity.id + user_role_entity = await Convertor.convert_user_sub_role_id_and_team_id_to_user_role_entity( + user_sub_invite, role_id, team_id) + user_role_entity = await RoleManager.add_user_role(user_role_entity) + return user_sub_invite + except Exception as e: + err = "邀请团队成员失败" + logging.exception("[TeamService] %s", err) + raise e + + @staticmethod + async def update_team_by_team_id(team_id: uuid.UUID, req: CreateTeamRequest) -> bool: """更新团队""" try: team_dict = await Convertor.convert_update_team_request_to_dict(req) @@ -92,6 +409,107 @@ class TeamService: logging.exception("[TeamService] %s", err) raise e + @staticmethod + async def update_team_user_role_by_team_id_and_user_sub( + user_sub: uuid.UUID, team_id: uuid.UUID, target_user_sub: str, role_id: uuid.UUID) -> uuid.UUID: + """更新团队成员角色""" + if user_sub == target_user_sub: + err = "更新团队成员角色失败, 不能修改自己的角色" + logging.exception("[TeamService] %s", err) + raise "更新团队成员角色失败, 不能修改自己的角色" + team_entity = await TeamManager.get_team_by_id(team_id) + if team_entity is None: + err = "更新团队成员角色失败, 团队不存在" + logging.exception("[TeamService] %s", err) + raise "更新团队成员角色失败, 团队不存在" + team_user_entity = await TeamManager.get_team_user_by_user_sub_and_team_id(target_user_sub, team_id) + if team_user_entity is None: + err = "更新团队成员角色失败, 团队成员不存在" + logging.exception("[TeamService] %s", err) + raise "更新团队成员角色失败, 团队成员不存在" + current_user_role_entity = await RoleManager.get_user_role_by_user_sub_and_team_id(target_user_sub, team_id) + current_role_entity = await RoleManager.get_role_by_id(current_user_role_entity.role_id) + if current_role_entity.name == DeafaultRole.CREATOR.value: + err = "更新团队成员角色失败, 不能修改创建者的角色" + logging.exception("[TeamService] %s", err) + raise "更新团队成员角色失败, 不能修改创建者的角色" + role_entity = await RoleManager.get_role_by_id(role_id) + if role_entity is None or role_entity.team_id != team_id: + err = "更新团队成员角色失败, 角色不存在" + logging.exception("[TeamService] %s", err) + raise "更新团队成员角色失败, 角色不存在" + if role_entity.is_unique: + err = "更新团队成员角色失败, 该角色为唯一角色" + logging.exception("[TeamService] %s", err) + raise "更新团队成员角色失败, 该角色为唯一角色" + try: + user_role_entity = await RoleManager.get_user_role_by_user_sub_and_team_id(target_user_sub, team_id) + if user_role_entity is None: + err = "更新团队成员角色失败, 团队成员角色不存在" + logging.exception("[TeamService] %s", err) + raise "更新团队成员角色失败, 团队成员角色不存在" + 
user_role_entity = await RoleManager.update_user_role_by_id( + user_role_entity.id, {"role_id": role_id}) + if user_role_entity is None: + err = "更新团队成员角色失败" + logging.exception("[TeamService] %s", err) + raise "更新团队成员角色失败" + except Exception as e: + err = "更新团队成员角色失败" + logging.exception("[TeamService] %s", err) + raise e + return target_user_sub + + @staticmethod + async def update_team_author_by_team_id(user_sub: str, team_id: uuid.UUID, target_user_sub: str) -> uuid.UUID: + """转让团队""" + team_entity = await TeamManager.get_team_by_id(team_id) + if team_entity is None: + err = "转让团队失败, 团队不存在" + logging.exception("[TeamService] %s", err) + raise "转让团队失败, 团队不存在" + creator_role_entity = await RoleManager.get_role_by_role_name_and_team_id( + DeafaultRole.CREATOR.value, team_id) + if creator_role_entity is None: + err = "转让团队失败, 创建者角色不存在" + logging.exception("[TeamService] %s", err) + raise "转让团队失败, 创建者角色不存在" + admin_role_entity = await RoleManager.get_role_by_role_name_and_team_id( + DeafaultRole.ADMINISTRATOR.value, team_id) + if admin_role_entity is None: + err = "转让团队失败, 管理员角色不存在" + logging.exception("[TeamService] %s", err) + raise "转让团队失败, 管理员角色不存在" + team_user_entity = await TeamManager.get_team_user_by_user_sub_and_team_id(target_user_sub, team_id) + if team_user_entity is None: + err = "转让团队失败, 团队成员不存在" + logging.exception("[TeamService] %s", err) + raise "转让团队失败, 团队成员不存在" + try: + # 将当前创建者角色转为管理员角色 + current_creator_user_role_entity = await RoleManager.get_user_role_by_user_sub_and_team_id(user_sub, team_id) + if current_creator_user_role_entity is None: + err = "转让团队失败, 当前创建者角色不存在" + logging.exception("[TeamService] %s", err) + raise "转让团队失败, 当前创建者角色不存在" + # 将目标成员角色转为创建者角色 + target_user_role_entity = await RoleManager.get_user_role_by_user_sub_and_team_id(target_user_sub, team_id) + if target_user_role_entity is None: + err = "转让团队失败, 目标成员角色不存在" + logging.exception("[TeamService] %s", err) + raise "转让团队失败, 目标成员角色不存在" + await RoleManager.update_user_role_by_id( + current_creator_user_role_entity.id, {"role_id": admin_role_entity.id}) + await RoleManager.update_user_role_by_id( + target_user_role_entity.id, {"role_id": creator_role_entity.id}) + # 更新团队的创建者 + team_entity = await TeamManager.update_team_by_id(team_id, {"author_id": target_user_sub, "author_name": target_user_sub}) + return team_entity.id + except Exception as e: + err = "转让团队失败" + logging.exception("[TeamService] %s", err) + raise e + @staticmethod async def soft_delete_team_by_team_id( team_id: uuid.UUID) -> bool: @@ -111,3 +529,44 @@ class TeamService: err = "软删除团队失败" logging.exception("[TeamService] %s", err) raise e + + @staticmethod + async def delete_team_user_by_team_id_and_user_subs( + team_id: uuid.UUID, user_subs: list[str]) -> list[uuid.UUID]: + """删除团队成员""" + team_entity = await TeamManager.get_team_by_id(team_id) + if team_entity is None: + err = "删除团队成员失败, 团队不存在" + logging.exception("[TeamService] %s", err) + raise "删除团队成员失败, 团队不存在" + team_user_entities = await TeamManager.list_team_user_by_team_id_and_user_subs(team_id, user_subs) + user_subs = [team_user.user_id for team_user in team_user_entities] + try: + user_role_entities = await RoleManager.list_user_roles_by_team_id_and_user_subs( + team_id, user_subs) + unique_role_ids = set( + [user_role.role_id for user_role in user_role_entities]) + role_entities = await RoleManager.list_roles_by_role_ids( + list(unique_role_ids)) + usr_role_dict = { + user_role.user_id: user_role for user_role in user_role_entities} + role_dict = { + role_entity.id: 
role_entity for role_entity in role_entities} + user_subs_deleted = [] + for user_sub in user_subs: + user_role = usr_role_dict.get(user_sub) + role = role_dict.get(user_role.role_id) if user_role else None + if role and role.name == DeafaultRole.CREATOR.value: + warning = f"删除团队成员失败, 不能删除创建者 {user_sub}" + logging.warning("[TeamService] %s", warning) + continue + user_subs_deleted.append(user_sub) + await TeamManager.update_team_users_by_team_id_and_user_subs( + team_id, user_subs, {"status": TeamUserStaus.DELETED.value}) + await RoleManager.update_user_roles_by_team_id_and_user_subs( + team_id, user_subs, {"status": UserRoleStatus.DELETED.value}) + return user_subs_deleted + except Exception as e: + err = "删除团队成员失败" + logging.exception("[TeamService] %s", err) + raise e diff --git a/data_chain/apps/service/user_message_service.py b/data_chain/apps/service/user_message_service.py new file mode 100644 index 0000000000000000000000000000000000000000..625f41f3d539747039620616b4c5bdb7bc85b879 --- /dev/null +++ b/data_chain/apps/service/user_message_service.py @@ -0,0 +1,121 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +import uuid +from typing import Union +from data_chain.logger.logger import logger as logging +from data_chain.entities.request_data import ( + ListUserMessageRequest +) +from data_chain.entities.response_data import ( + ListUserMessageMsg +) +from data_chain.entities.enum import ( + IdType, + TeamType, + TeamStatus, + TeamUserStaus, + UserRoleStatus, + UserMessageStatus, + UserMessageType, + DeafaultRole +) +from data_chain.entities.common import default_roles +from data_chain.stores.database.database import TeamEntity +from data_chain.apps.base.convertor import Convertor +from data_chain.apps.service.team_service import TeamService +from data_chain.manager.team_manager import TeamManager +from data_chain.manager.team_message_manager import TeamMessageManager +from data_chain.manager.user_message_manager import UserMessageManager +from data_chain.manager.role_manager import RoleManager +from data_chain.manager.knowledge_manager import KnowledgeBaseManager +from data_chain.apps.service.knwoledge_base_service import KnowledgeBaseService + + +class UserMessageService: + """用户消息服务""" + @staticmethod + async def list_user_messages( + user_sub: str, req: ListUserMessageRequest) -> ListUserMessageMsg: + """根据用户标识和消息类型获取用户消息列表""" + try: + total, user_message_entities = await UserMessageManager.list_user_messages( + user_sub, req) + user_messages = [] + for user_message_entity in user_message_entities: + user_message = await Convertor.convert_user_sub_and_user_message_entity_to_user_message(user_sub, user_message_entity) + user_messages.append(user_message) + return ListUserMessageMsg(total=total, userMessages=user_messages) + except Exception as e: + err = "根据用户标识和消息类型获取用户消息列表失败" + logging.exception("[UserMessageService] %s", err) + raise e + + @staticmethod + async def update_user_message( + user_sub: str, msg_id: uuid.UUID, msg_status: UserMessageStatus) -> Union[None, str]: + """根据消息标识和消息状态更新用户消息""" + user_message_entity = await UserMessageManager.get_user_message_by_msg_id(msg_id) + if not user_message_entity: + raise Exception('用户消息不存在') + if user_message_entity.status_to_receiver != UserMessageStatus.UNREAD.value: + raise Exception('用户消息状态只能从未读更新为已读') + if msg_status == UserMessageStatus.UNREAD: + raise Exception('用户消息状态只能从未读更新为已读') + if user_sub == user_message_entity.sender_id: + raise Exception('用户不能修改自己发送的消息') + can_access = False + if 
user_sub == user_message_entity.receiver_id: + can_access = True + if user_sub != user_message_entity.sender_id and user_message_entity.is_to_all: + action_entity = await RoleManager.get_action_by_team_id_user_sub_and_action( + user_sub, user_message_entity.team_id, 'PUT /usr_msg') + if action_entity: + can_access = True + if not can_access: + raise Exception('用户没有权限修改该消息') + await UserMessageManager.update_user_message_by_msg_id( + msg_id, {'status_to_receiver': msg_status.value}) + if user_message_entity.type == UserMessageType.INVITATION: + if msg_status == UserMessageStatus.REJECTED: + await TeamService.add_team_msg(user_message_entity.sender_id, user_message_entity.team_id, IdType.USER, '{targetUserName}拒绝了你的邀请', 'user {targetUserName} rejected your invitation', targetUserName=user_message_entity.receiver_id) + return None + elif msg_status == UserMessageStatus.ACCEPTED: + await TeamService.add_team_user(user_message_entity.team_id, user_message_entity.role_id, user_message_entity.receiver_id) + await TeamService.add_team_msg(user_message_entity.receiver_id, user_message_entity.team_id, IdType.USER, '{targetUserName}加入了团队', 'user {targetUserName} joined the team', targetUserName=user_message_entity.receiver_id) + return msg_id + elif user_message_entity.type == UserMessageType.APPLICATION: + if msg_status == UserMessageStatus.REJECTED: + await TeamService.add_team_msg(user_message_entity.sender_id, user_message_entity.team_id, IdType.USER, '{targetUserName}拒绝了你的申请', 'user {targetUserName} rejected your application', targetUserName=user_message_entity.sender_id) + return None + elif msg_status == UserMessageStatus.ACCEPTED: + await TeamService.add_team_user(user_message_entity.team_id, user_message_entity.role_id, user_message_entity.sender_id) + await TeamService.add_team_msg(user_message_entity.receiver_id, user_message_entity.team_id, IdType.USER, '{targetUserName}加入了团队', 'user {targetUserName} joined the team', targetUserName=user_message_entity.sender_id) + return msg_id + return None + + @staticmethod + async def delete_user_messages( + user_sub: str, msg_id: uuid.UUID) -> Union[None, str]: + """根据消息标识删除用户消息""" + try: + user_message_entity = await UserMessageManager.get_user_message_by_msg_id(msg_id) + if not user_message_entity: + raise Exception('用户消息不存在') + msg_dict = {} + if user_sub == user_message_entity.sender_id: + msg_dict['status_to_sender'] = UserMessageStatus.DELETED.value + if user_sub == user_message_entity.receiver_id: + msg_dict['status_to_receiver'] = UserMessageStatus.DELETED.value + if user_sub != user_message_entity.sender_id and user_message_entity.is_to_all: + action_entity = await RoleManager.get_action_by_team_id_user_sub_and_action( + user_message_entity.team_id, user_sub, 'PUT /usr_msg') + if action_entity: + msg_dict['status_to_receiver'] = UserMessageStatus.DELETED.value + if not msg_dict: + raise Exception('用户没有权限删除该消息') + await UserMessageManager.update_user_message_by_msg_id( + msg_id, msg_dict) + return msg_id + except Exception as e: + err = "根据消息标识删除用户消息失败" + logging.exception("[UserMessageService] %s", err) + raise e diff --git a/data_chain/apps/service/user_service.py b/data_chain/apps/service/user_service.py new file mode 100644 index 0000000000000000000000000000000000000000..2107d3e34ffabd5ea666ee87cb0d38d03b606cd8 --- /dev/null +++ b/data_chain/apps/service/user_service.py @@ -0,0 +1,40 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
+import uuid +from typing import Union +from data_chain.logger.logger import logger as logging +from data_chain.entities.request_data import ( + ListUserRequest +) +from data_chain.entities.response_data import ( + User, + ListUserMsg +) +from data_chain.entities.enum import ( + IdType, + TeamType, + TeamStatus, + TeamUserStaus, + UserRoleStatus, + UserMessageStatus, + UserMessageType, + DeafaultRole +) +from data_chain.entities.common import default_roles +from data_chain.apps.base.convertor import Convertor +from data_chain.manager.user_manager import UserManager + + +class UserService: + @staticmethod + async def list_users(req: ListUserRequest) -> ListUserMsg: + try: + total, user_entities = await UserManager.list_user(req) + user_list = [] + for user_entity in user_entities: + user = await Convertor.convert_user_entity_to_user(user_entity) + user_list.append(user) + return ListUserMsg(total=total, users=user_list) + except Exception as e: + err = "用户列表获取失败" + logging.warning("[UserService] %s", err) + raise e diff --git a/data_chain/common/.env.example b/data_chain/common/.env.example index 063eddddebb8439b62ee5770471dadab15bd0564..f5a134658072e5b508a968a66d5d3674ae5da4ae 100644 --- a/data_chain/common/.env.example +++ b/data_chain/common/.env.example @@ -41,6 +41,11 @@ EMBEDDING_TYPE = EMBEDDING_API_KEY = EMBEDDING_ENDPOINT = EMBEDDING_MODEL_NAME = +# Rerank +RERANK_TYPE = +RERANK_API_KEY = +RERANK_ENDPOINT = +RERANK_MODEL_NAME = # Token SESSION_TTL = CSRF_KEY = diff --git a/data_chain/common/pp.py b/data_chain/common/pp.py index bfcfb50c37a24f3760f9dc7f317bd34e60532d58..dc35adf4a7426a0324ce7e682302ee583598c961 100644 --- a/data_chain/common/pp.py +++ b/data_chain/common/pp.py @@ -26,13 +26,8 @@ def save_yaml_file(yaml_data, file_path): # 示例:加载YAML文件 file_path = './data_chain/common/prompt.yaml' yaml_data = load_yaml_file(file_path) -if yaml_data: - print(yaml_data) -# yaml_data['LLM_PROMPT_TEMPLATE']='' -# yaml_data['INTENT_DETECT_PROMPT_TEMPLATE']='' -# yaml_data['OCR_ENHANCED_PROMPT']='' -# yaml_data['DETERMINE_ANSWER_AND_QUESTION']='' -# save_yaml_file(yaml_data,file_path) +print(yaml_data) +# print(config.__dict__) # llm = LLM( # model_name=config['MODEL_NAME'], # openai_api_base=config['OPENAI_API_BASE'], @@ -41,34 +36,21 @@ if yaml_data: # max_tokens=config['MAX_TOKENS'], # temperature=config['TEMPERATURE'], # ) -# prompt_template = yaml_data['CONTENT_TO_ABSTRACT_PROMPT'] -# content = '''在那遥远的山谷之中,有一片神秘而又美丽的森林。阳光透过茂密的枝叶,洒下一片片金色的光斑,仿佛是大自然精心编织的梦幻画卷。森林里,鸟儿欢快地歌唱,那清脆的歌声在林间回荡,传递着生机与活力。松鼠们在树枝间跳跃,敏捷的身影如同灵动的音符,谱写着森林的乐章。 -# 沿着蜿蜒的小径前行,脚下的落叶发出沙沙的声响,仿佛在诉说着岁月的故事。路边的野花竞相开放,红的、黄的、紫的,五彩斑斓,散发着阵阵芬芳。蝴蝶在花丛中翩翩起舞,它们那绚丽的翅膀,如同绚丽的丝绸,在微风中轻轻摇曳。 -# 不远处,一条清澈的小溪潺潺流淌。溪水从山间缓缓流下,清澈见底,能看到鱼儿在水中自由自在地游弋。溪水撞击着石头,发出叮叮咚咚的声音,宛如一首美妙的乐曲。溪边的石头上,长满了青苔,仿佛是大自然赋予的绿色绒毯。 -# 在森林的深处,隐藏着一座古老的城堡。城堡的墙壁上爬满了藤蔓,仿佛是岁月留下的痕迹。城堡的大门紧闭,似乎隐藏着无数的秘密。传说中,这座城堡里住着一位美丽的公主,她被邪恶的巫师困在了这里,等待着勇敢的骑士前来解救。 -# 有一天,一位年轻的骑士听闻了这个传说,决定踏上寻找公主的冒险之旅。他骑着一匹矫健的白马,手持长剑,穿过茂密的森林,越过湍急的河流,历经千辛万苦,终于来到了城堡的门前。 -# 骑士用力地敲打着城堡的大门,然而,大门却纹丝不动。就在他感到绝望的时候,一只小精灵出现在他的面前。小精灵告诉他,要打开城堡的大门,必须找到三把神奇的钥匙。这三把钥匙分别隐藏在森林的三个不同的地方,只有集齐了这三把钥匙,才能打开城堡的大门。 -# 骑士听了小精灵的话,毫不犹豫地踏上了寻找钥匙的旅程。他在森林里四处寻找,遇到了各种各样的困难和挑战。有时候,他会迷失在森林的深处,找不到方向;有时候,他会遇到凶猛的野兽,不得不与之搏斗。但是,骑士始终没有放弃,他坚信自己一定能够找到钥匙,救出公主。 -# 终于,经过一番艰苦的努力,骑士找到了三把神奇的钥匙。他拿着钥匙,来到城堡的门前,将钥匙插入锁孔。随着一阵清脆的响声,城堡的大门缓缓打开。骑士走进城堡,沿着昏暗的走廊前行,终于在一间房间里找到了公主。 -# 公主看到骑士,眼中闪烁着希望的光芒。她告诉骑士,自己被巫师困在这里已经很久了,一直在等待着有人来救她。骑士将公主带出城堡,骑着白马,离开了这片神秘的森林。 -# 从此以后,骑士和公主过上了幸福的生活。他们的故事在这片土地上流传开来,成为了人们心中的一段佳话。 -# 
在这个世界上,还有许多未知的领域等待着我们去探索。也许,在那遥远的地方,还有更多神秘的故事等待着我们去发现。无论是茂密的森林,还是古老的城堡,都充满了无限的魅力。它们吸引着我们不断地前行,去追寻那未知的美好。 -# 当夜幕降临,天空中繁星闪烁。那璀璨的星光,仿佛是大自然赋予我们的最美的礼物。在这宁静的夜晚,我们可以静静地聆听大自然的声音,感受它的神奇与美妙。 -# 有时候,我们会在生活中遇到各种各样的困难和挫折。但是,只要我们像那位勇敢的骑士一样,坚持不懈,勇往直前,就一定能够克服困难,实现自己的梦想。生活就像一场冒险,充满了未知和挑战。我们要勇敢地面对生活中的一切,用自己的智慧和勇气去创造美好的未来。 -# 在这个充满变化的世界里,我们要学会珍惜身边的一切。无论是亲人、朋友,还是那美丽的大自然,都是我们生活中不可或缺的一部分。我们要用心去感受他们的存在,用爱去呵护他们。 -# 随着时间的推移,那片神秘的森林依然静静地矗立在那里。它见证了无数的故事,承载了无数的回忆。而那座古老的城堡,也依然默默地守护着那些神秘的传说。它们就像历史的见证者,诉说着过去的辉煌与沧桑。 -# 我们生活在一个充满希望和梦想的时代。每一个人都有自己的追求和目标,都在为了实现自己的梦想而努力奋斗。无论是科学家、艺术家,还是普通的劳动者,都在各自的岗位上发光发热,为社会的发展做出自己的贡献。 -# 在科技飞速发展的今天,我们的生活发生了翻天覆地的变化。互联网的普及,让我们的信息传播更加迅速和便捷。我们可以通过网络了解到世界各地的新闻和文化,与远方的朋友进行交流和沟通。科技的进步,也让我们的生活更加舒适和便利。我们有了更加先进的交通工具、更加便捷的通讯设备,以及更加高效的生活方式。 -# 然而,科技的发展也带来了一些问题。比如,环境污染、能源危机等。这些问题不仅影响着我们的生活质量,也威胁着我们的未来。因此,我们在享受科技带来的便利的同时,也要关注环境保护和可持续发展。我们要努力寻找更加绿色、环保的生活方式,减少对自然资源的消耗和对环境的破坏。 -# 除了科技的发展,文化的传承和创新也是我们生活中重要的一部分。每一个国家和民族都有自己独特的文化传统,这些文化传统是我们的精神财富,也是我们民族的灵魂。我们要传承和弘扬自己的文化传统,让它们在新的时代焕发出新的活力。同时,我们也要积极吸收和借鉴其他国家和民族的优秀文化成果,促进文化的交流和融合。 -# 在教育方面,我们要注重培养学生的创新精神和实践能力。我们要让学生在学习知识的同时,学会思考、学会创新、学会实践。只有这样,我们才能培养出适应时代发展需要的高素质人才。 -# 在人际交往中,我们要学会尊重他人、理解他人、关心他人。我们要建立良好的人际关系,与他人和谐相处。只有这样,我们才能在生活中感受到温暖和快乐。 -# 总之,我们的生活是丰富多彩的,充满了无限的可能。我们要珍惜生活中的每一个瞬间,用积极的态度去面对生活中的一切。无论是成功还是失败,无论是欢笑还是泪水,都是我们生活中的宝贵财富。让我们一起努力,创造一个更加美好的未来!''' -# abstract = '' -# for i in range(10): -# part = TokenTool.get_k_tokens_words_from_content(content, 100) -# content = content[len(part):] -# sys_call = prompt_template.format(content=part, abstract=abstract) -# user_call = '请详细输出内容的摘要,不要输出其他内容' -# abstract = asyncio.run(llm.nostream([], sys_call, user_call)) -# print(abstract) +# print(prompt_dict) +# for key in prompt_dict: +# prompt = prompt_dict[key]['zh'] +# systemcall = f""" +# 你是一个翻译专家, 你需要将用户输入的中文内容翻译成地道的英文, 只需要返回翻译后的英文内容, 不需要任何多余的解释和说明. +# 你需要严格遵守以下规则: +# 1. 你只能翻译用户输入的内容, 不能添加任何额外的信息. +# 2. 你需要确保翻译后的内容符合英文的语法和表达习惯. +# 3. 你需要确保翻译后的内容准确传达用户输入的中文内容的意思. + +# 标签中的内容是用户输入的中文内容, 你需要将这些内容翻译成英文. +# {prompt} +# """ +# user_call = f"请将上面的内容翻译为英文" +# result = asyncio.run(llm.nostream([], systemcall, user_call)) +# print(result) +# prompt_dict[key]['en'] = result +# print(prompt_dict) diff --git a/data_chain/common/prompt.yaml b/data_chain/common/prompt.yaml index 70ba9e2b047707aac037375fef93b7d51956a42c..20429db73f5f8899feac5ad7517d490cc7a905ed 100644 --- a/data_chain/common/prompt.yaml +++ b/data_chain/common/prompt.yaml @@ -1,381 +1,735 @@ -INTENT_DETECT_PROMPT_TEMPLATE: "\n\n \n \ - \ 根据历史对话,推断用户的实际意图并补全用户的提问内容。\n 用户的提问内容将在中给出,历史对话将在中给出。\n\ - \ 要求:\n 1. 参考下面给出的样例,请直接输出补全后的提问内容;输出不要包含XML标签,不要包含任何解释说明;\n\ - \ 2. 若用户当前提问内容与对话上文不相关,或你认为用户的提问内容已足够完整,请直接输出用户的提问内容。\n \ - \ 3. 补全内容必须精准、恰当,不要编造任何内容。\n \n\n \n\ - \ openEuler是什么 \n 有什么特点\n\ - \ openEuler有什么特点?\n \n \n\ - \n \n {history}\n \n\ - \ \n {question}\n \n" -LLM_PROMPT_TEMPLATE: "\n \n 你是EulerCopilot,openEuler社区的智能助手。请结合给出的背景信息,\ - \ 回答用户的提问。\n 上下文背景信息将在中给出。\n 注意:输出不要包含任何XML标签,不要编造任何信息。若你认为用户提问与背景信息无关,请忽略背景信息直接作答。\n\ - \ \n\n \n {bac_info}\n \ - \ \n" -OCR_ENHANCED_PROMPT: '你是一个图片ocr内容总结专家,你的任务是根据我提供的上下文、相邻图片组描述、当前图片上一次的ocr内容总结、当前图片部分ocr的结果(包含文字和文字的相对坐标)给出图片描述. - - 注意: - - #01 必须使用大于200字小于500字详细详细描述这个图片的内容,可以详细列出数据. - - #02 如果这个图是流程图,请按照流程图顺序描述内容。 - - #03 如果这张图是表格,请用markdown形式输出表格内容 . 
- - #04 如果这张图是架构图,请按照架构图层次结构描述内容。 - - #05 总结的图片描述必须包含图片中的主要信息,不能只描述图片位置。 - - #6 图片识别结果中相邻的文字可能是同一段落的内容,请合并后总结 - - #7 文字可能存在错位,请修正顺序后进行总结 - - #8 请仅输出图片的总结即可,不要输出其他内容 - - #9 不要输出坐标等信息,输出每个部分相对位置的描述即可 - - #10 如果图片内容为空,请输出“图片内容为空” - - #11 如果图片本身就是一段文字,请直接输出文字内容 - - 上下文:{image_related_text} - - 当前图片上一部分的ocr内容总结:{pre_part_description} - - 当前图片部分ocr的结果:{part}' -QA_TO_STATEMENTS_PROMPT: '你是一个文本分解专家,你的任务是根据我给出的问题和答案,将答案提取为多个陈诉,陈诉使用列表形式返回 - - 注意: - #01 陈诉必须来源于答案中的重点内容 - #02 陈诉按相对顺序排列 - #03 输出的单个陈诉长度不超过50个字 - #04 输出的陈诉总数不超过20个 - #05 请仅输出陈诉列表,不要输出其他内容 - 例子: - - 输入: - 问题:openEuler是什么操作系统? - 答案:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 - 输出: - [ - \"openEuler是一个开源的操作系统\", - \"openEuler旨在为云计算和边缘计算提供支持\", - \"openEuler具有高性能、高安全性和高可靠性等特点\" - ] - - 下面是给出的问题和答案: - 问题:{question} - 答案:{answer} -' -ANSWER_TO_ANSWER_PROMPT: '你是一个文本分析专家,你的任务对比两个文本之间的相似度,并输出一个0-100之间的分数且保留两位小数: -注意: -#01 请根据文本在语义、语序和关键字上的相似度进行打分 -#02 如果两个文本在核心表达上一致,那么分数也相对高 -#03 一个文本包含另一个文本的核心内容,那么分数也相对高 -#04 两个文本间内容有重合,那么按照重合内容的比例打分 -#05 请仅输出分数,不要输出其他内容 -例子: -输入1: - 文本1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 - 文本2:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 - 输出1:100.00 -输入2: - 文本1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 - 文本2:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能和高安全性等特点。 - 输出2:90.00 -输入3: - 文本1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 - 文本2:白马非马 - 输出3:00.00 -下面是给出的文本: - 文本1:{text_1} - 文本2:{text_2} -' -CONTENT_TO_STATEMENTS_PROMPT: '你是一个文本分解专家,你的任务是根据我给出的文本,将文本提取为多个陈诉,陈诉使用列表形式返回 - - 注意: - #01 陈诉必须来源于文本中的重点内容 - #02 陈诉按相对顺序排列 - #03 输出的单个陈诉长度不少于20个字,不超过50个字 - #04 输出的陈诉总数不超过3个 - #05 请仅输出陈诉列表,不要输出其他内容 - 例子: - - 输入:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 - 输出: - [ - \"openEuler是一个开源的操作系统\", - \"openEuler旨在为云计算和边缘计算提供支持\", - \"openEuler具有高性能、高安全性和高可靠性等特点\" - ] - - 下面是给出的文本: - {content} - ' -STATEMENTS_TO_FRAGMENT_PROMPT: '你是一个文本专家,你的任务是根据给出的陈诉是否与片段强相关 - 注意: - #01 如果陈诉与片段强相关或者来自于片段,请输出YES - #02 如果陈诉中的内容与片段无关,请输出NO - #03 如果陈诉是片段中某部分的提炼,请输出YES - #05 请仅输出YES或NO,不要输出其他内容 - 例子: - 输入1: - - 陈诉:openEuler是一个开源的操作系统。 - 片段:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 - 输出1:YES - - 输入2: - 陈诉:白马非马 - 片段:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 - 输出2:NO - - 下面是给出的陈诉和片段: - 陈诉:{statement} - 片段:{fragment} - ' -STATEMENTS_TO_QUESTION_PROMPT: '你是一个文本分析专家,你的任务是根据给出的陈诉和问题判断,陈诉是否与问题相关 - 注意: - #01 如果陈诉是否与问题相关,请输出YES - #02 如果陈诉与问题不相关,请输出NO - #03 请仅输出YES或NO,不要输出其他内容 - #04 陈诉与问题相关是指,陈诉中的内容可以回答问题或者与问题在内容上有交集 - 例子: - 输入1: - 陈诉:openEuler是一个开源的操作系统。 +ACC_ANALYSIS_RESULT_MERGE_PROMPT: + en: | + You are a text analysis expert. Your task is to combine two analysis results and output a new one. Note: + #01 Please combine the content of the two analysis results to produce a new analysis result. + #02 Please analyze using the four metrics of recall, precision, faithfulness, and interpretability. + #03 The new analysis result must be no longer than 500 characters. + #04 Please output only the new analysis result; do not output any other content. + Example: + Input 1: + Analysis Result 1: + Recall: Currently, the recall is 95.00, with room for improvement. We will optimize the vectorized search algorithm to further mine information in the original fragment that is relevant to the question but not retrieved, such as some specific practical cases in the openEuler ecosystem. 
The embedding model bge-m3 will be adjusted to more comprehensively and accurately capture semantics, expand the search scope, improve recall, and make the generated answers closer to the standard answer. + Accuracy: The accuracy is 99.00, which is quite high. However, further optimization is possible, including deeper semantic analysis of the retrieved snippets. By combining the features of the large model qwen2.5-32b, this can precisely match the question semantics and avoid subtle semantic deviations. For example, this can more precisely illustrate the specific manifestations of OpenEuler's high performance in cloud computing and edge computing. + Fidelity: The fidelity value is 90.00, indicating that some answers are not fully derived from the retrieved snippets. Optimizing the rag retrieval algorithm, improving the recall rate of the embedding model, and adjusting the text chunk size to 512 may be inappropriate and require re-evaluation based on the content. This ensures that the retrieved snippets contain sufficient context to support the answer, ensuring that the generated answer content is fully derived from the retrieved snippets. For example, regarding the development of the openEuler ecosystem, relevant technical details should be obtained from the retrieved snippets. + Interpretability: The interpretability is 85.00, which is relatively low. Improve the compliance of the large model qwen2.5-32b and optimize the recall of the rag retrieval algorithm and the embedding model bge-m3. This ensures that retrieved snippets better support answer generation and clearly answer questions. For example, when answering questions related to OpenEuler, this makes the answer logic clearer and more targeted, improving overall interpretability. + + Analysis Result 2: + The recall rate is currently 95.00. Further optimization of the rag retrieval algorithm and embedding model can be used to increase the semantic similarity between the generated answers and the standard answers, approaching or achieving a higher recall rate. For example, the algorithm can be continuously optimized to better match relevant snippets. + The precision is 99.00, close to the maximum score, indicating that the generated answers are semantically similar to the questions. However, further improvement is possible. This can be achieved by refining the embedding model to better understand the question semantics, optimizing the contextual completeness of the retrieved snippets, and reducing fluctuations in precision caused by insufficient context. + The faithfulness score is currently 90.00, indicating that some content in the generated answer is not fully derived from the retrieved snippet. The rag retrieval algorithm can be optimized to improve its recall rate. The text chunk size can also be adjusted appropriately to ensure that the retrieved snippet fully answers the question, thereby improving the faithfulness score. + Regarding interpretability, it is currently 85.00, indicating that the generated answer has room for improvement in terms of answering questions. On the one hand, the large model used can be optimized to improve its compliance, making the generated answer more accurate. On the other hand, the recall rates of the rag retrieval algorithm and embedding model can be further optimized to ensure that the retrieved snippet fully supports the answer and improve interpretability. + + Output: + Recall: Currently at 95.00, there is room for improvement. 
The vectorized retrieval algorithm can be optimized to further uncover information in the original snippet that is relevant to the question but not retrieved, as demonstrated in some specific practical cases within the openEuler ecosystem. Adjustments were made to the embedding model bge-m3 to enable it to more comprehensively and accurately capture semantics, expand the search scope, improve recall, and bring the generated answers closer to the standard answer. + Accuracy: The accuracy reached 99.00, which is already high. However, further optimization is needed to conduct deeper semantic analysis of the retrieved snippets. By combining the features of the large model qwen2.5-32b, this can precisely match the question semantics and avoid subtle semantic deviations. For example, this could more accurately demonstrate the specific characteristics of OpenEuler's high performance in cloud computing and edge computing. + Fidelity: The fidelity value was 90.00, indicating that some answer content was not fully derived from the retrieved snippet. The rag retrieval algorithm was optimized to improve the recall of the embedding model. Adjusting the text chunk size to 512 may be unreasonable and requires re-evaluation based on the content to ensure that the retrieved snippets contain sufficient context to support the answer, ensuring that the generated answer content is fully derived from the retrieved snippet. For example, relevant technical details regarding the development of the OpenEuler ecosystem should be obtained from the retrieved snippet. + Interpretability: The interpretability value was 85.00, which is relatively low. Improve the compliance of the large qwen2.5-32b model and optimize the recall of the rag retrieval algorithm and the embedding model bge-m3. This ensures that retrieval fragments can better support answer generation and clearly answer questions. For example, when answering questions related to OpenEuler, this improves answer logic, makes it more targeted, and improves overall interpretability. 
+ + The following two analysis results: + Analysis Result 1: {analysis_result_1} + Analysis Result 2: {analysis_result_2} + + 中文: | + 你是一个文本分析专家,你的任务融合两条分析结果输出一份新的分析结果。注意: + #01 请根据两条分析结果中的内容融合出一条新的分析结果 + #02 请结合召回率、精确度、忠实值和可解释性四个指标进行分析 + #03 新的分析结果长度不超过500字 + #04 请仅输出新的分析结果,不要输出其他内容 + 例子: + 输入1: + 分析结果1: + 召回率:目前召回率为 95.00,有提升空间。优化向量化检索算法,进一步挖掘原始片段中与问题相关但未被检索到的信息,如 openEuler 生态中一些具体实践案例等。调整 embedding 模型 bge-m3,使其能更全面准确地捕捉语义,扩大检索范围,提高召回率,使生成答案更接近标准答案。 + 精确度:精确度达 99.00,已较高。但可进一步优化,对检索到的片段进行更深入的语义分析,结合大模型 qwen2.5-32b 的特点,精准匹配问题语义,避免细微语义偏差,例如更精确阐述 openEuler 在云计算和边缘计算中高性能等特性的具体表现。 + 忠实值:忠实值为 90.00,说明部分答案内容未完全源于检索片段。优化 rag 检索算法,提高 embedding 模型召回率,调整文本分块大小为 512 可能存在不合理,需根据内容重新评估,确保检索片段包含足够上下文以支撑答案,使生成答案内容均来自检索片段,如关于 openEuler 生态建设中相关技术细节应从检索片段获取。 + 可解释性:可解释性为 85.00,相对较低。提升大模型 qwen2.5-32b 的遵从度,优化 rag 检索算法和 embedding 模型 bge-m3 的召回率,使检索片段能更好支撑生成答案,保证答案能清晰回答问题,例如在回答 openEuler 相关问题时,使答案逻辑更清晰、针对性更强,提高整体可解释性。 + + 分析结果2: + 从召回率来看,目前为 95.00,可进一步优化 rag 检索算法和 embedding 模型,以提高生成答案与标准回答之间的语义相似程度,接近或达到更高的召回率,例如可以持续优化算法来更好地匹配相关片段。 + 从精确度来看,为 99.00,接近满分,说明生成的答案与问题语义相似程度较高,但仍可进一步提升,可通过完善 embedding 模型来更好地理解问题语义,优化检索到的片段的上下文完整性,减少因上下文不足导致的精确度波动。 + 对于忠实值,目前为 90.00,说明生成的答案中部分内容未完全来自检索到的片段。可优化 rag 检索算法,提高其召回率,同时合理调整文本分块大小,确保检索到的片段能充分回答问题,从而提高忠实值。 + 关于可解释性,当前为 85.00,说明生成的答案在用于回答问题方面有一定提升空间。一方面可以优化使用的大模型,提高其遵从度,使其生成的答案更准确地回答问题;另一方面,继续优化 rag 检索算法和 embedding 模型的召回率,保证检索到的片段能全面支撑问题的回答,提高可解释性。 + + 输出: + 召回率:目前召回率为 95.00,有提升空间。优化向量化检索算法,进一步挖掘原始片段中与问题相关但未被检索到的信息,如 openEuler 生态中一些具体实践案例等。调整 embedding 模型 bge-m3,使其能更全面准确地捕捉语义,扩大检索范围,提高召回率,使生成答案更接近标准答案。 + 精确度:精确度达 99.00,已较高。但可进一步优化,对检索到的片段进行更深入的语义分析,结合大模型 qwen2.5-32b 的特点,精准匹配问题语义,避免细微语义偏差,例如更精确阐述 openEuler 在云计算和边缘计算中高性能等特性的具体表现。 + 忠实值:忠实值为 90.00,说明部分答案内容未完全源于检索片段。优化 rag 检索算法,提高 embedding 模型召回率,调整文本分块大小为 512 可能存在不合理,需根据内容重新评估,确保检索片段包含足够上下文以支撑答案,使生成答案内容均来自检索片段,如关于 openEuler 生态建设中相关技术细节应从检索片段获取。 + 可解释性:可解释性为 85.00,相对较低。提升大模型 qwen2.5-32b 的遵从度,优化 rag 检索算法和 embedding 模型 bge-m3 的召回率,使检索片段能更好支撑生成答案,保证答案能清晰回答问题,例如在回答 openEuler 相关问题时,使答案逻辑更清晰、针对性更强,提高整体可解释性。 + + 下面两条分析结果: + 分析结果1:{analysis_result_1} + 分析结果2:{analysis_result_2} + +ACC_RESULT_ANALYSIS_PROMPT: + en: | + You are a text analysis expert. Your task is to: analyze the large model used in the test, the embedding model used in the test, the parsing method and chunk size of related documents, the snippets matched by the RAG algorithm for a single test result, and propose methods to improve the accuracy of question-answering in the current knowledge base. + + The test results include the following information: + - Question: The question used in the test + - Standard answer: The standard answer used in the test + - Generated answer: The answer output by the large model in the test results + - Original snippet: The original snippet provided in the test results + - Retrieved snippet: The snippet retrieved by the RAG algorithm in the test results + + The four evaluation metrics are defined as follows: + - Precision: Evaluates the semantic similarity between the generated answer and the question. A lower score indicates lower compliance of the large model; additionally, it may mean the snippets retrieved by the RAG algorithm lack context and are insufficient to support the answer. + - Recall: Evaluates the semantic similarity between the generated answer and the standard answer. A lower score indicates lower compliance of the large model. + - Fidelity: Evaluates whether the content of the generated answer is derived from the retrieved snippet. 
A lower score indicates lower recall of the RAG retrieval algorithm and embedding model (resulting in retrieved snippets insufficient to answer the question); additionally, it may mean the text chunk size is inappropriate. + - Interpretability: Evaluates whether the generated answer is useful for answering the question. A lower score indicates lower recall of the RAG retrieval algorithm and embedding model (resulting in retrieved snippets insufficient to answer the question); additionally, it may mean lower compliance of the used large model. + + Notes: + #01 Analyze methods to improve the accuracy of current knowledge base question-answering based on the test results. + #02 Conduct the analysis using the four metrics: Recall, Precision, Fidelity, and Interpretability. + #03 The analysis result must not exceed 500 words. + #04 Output only the analysis result; do not include any other content. + + Example: + Input: + Model name: qwen2.5-32b + Embedding model: bge-m3 + Text chunk size: 512 + Used RAG algorithm: Vectorized retrieval + Question: What is OpenEuler? + Standard answer: OpenEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Generated answer: OpenEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Original snippet: openEuler is an open source operating system incubated and operated by the Open Atom Open Source Foundation. Its mission is to build an open source operating system ecosystem for digital infrastructure and provide solid underlying support for cutting-edge fields such as cloud computing and edge computing. In cloud computing scenarios, openEuler can fully optimize resource scheduling and allocation mechanisms. Through a lightweight kernel design and efficient virtualization technology, it significantly improves the responsiveness and throughput of cloud services. In edge computing, its exceptional low resource consumption and real-time processing capabilities ensure the timeliness and accuracy of data processing at edge nodes in complex environments. openEuler boasts a series of exceptional features: In terms of performance, its independently developed intelligent scheduling algorithm dynamically adapts to different load scenarios, and combined with deep optimization of hardware resources, significantly improves system efficiency. Regarding security, its built-in multi-layered security system, including mandatory access control, vulnerability scanning, and remediation mechanisms, provides a solid defense for system data and applications. Regarding reliability, its distributed storage, automatic fault detection, and rapid recovery technologies ensure stable system operation in the face of unexpected situations such as network fluctuations and hardware failures, minimizing the risk of service interruptions. These features make openEuler a crucial technological cornerstone for promoting high-quality development of the digital economy, helping enterprises and developers seize the initiative in digital transformation. + Retrieved snippet: As a pioneer in the open source operating system field, openEuler deeply integrates the wisdom of community developers and continuously iterates and upgrades to adapt to the rapidly changing technological environment. 
In the current era of prevalent microservices architectures, openEuler Through deep optimization of containerization technology and support for mainstream orchestration tools such as Kubernetes, it makes application deployment and management more convenient and efficient, significantly enhancing the flexibility of enterprise business deployments. At the same time, it actively embraces the AI era. By adapting and optimizing machine learning frameworks, it provides powerful computing power for AI model training and inference, effectively reducing the development and operating costs of AI applications. Regarding ecosystem development, openEuler boasts a large and active open source community, bringing together technology enthusiasts and industry experts from around the world, forming a complete ecosystem from kernel development and driver adaptation to application optimization. The community regularly hosts technical exchanges and developer conferences to promote knowledge sharing and technological innovation, providing developers with a wealth of learning resources and practical opportunities. Numerous hardware and software manufacturers have joined the openEuler ecosystem, launching solutions and products based on the system across key industries such as finance, telecommunications, and energy. These efforts, validated through real-world application scenarios and feeding back into openEuler's technological development, have fostered a virtuous cycle of innovation, making openEuler not just an operating system but a powerful engine driving collaborative industry development. + Recall: 95.00 + Precision: 99.00 + Fidelity: 90.00 + Interpretability: 85.00 + + Output: + Based on the test results, methods for improving the accuracy of current knowledge base question-answering can be analyzed from the following aspects: Recall: The current recall is 95.00, with room for improvement. Optimize the vectorized retrieval algorithm to further mine question-related but unretrieved information in the original snippets, such as some specific practical cases in the openEuler ecosystem. Adjust the embedding model bge-m3 to more comprehensively and accurately capture semantics, expand the search scope, improve recall, and make the generated answers closer to the standard answer. Precision: The accuracy reached 99.00, which is already high. However, further optimization is possible, including deeper semantic analysis of retrieved snippets. By combining the features of the large model qwen2.5-32b, this can accurately match the question semantics and avoid subtle semantic deviations. For example, more precise demonstration of openEuler's high performance in cloud computing and edge computing can be achieved. Fidelity: A fidelity score of 90.00 indicates that some answers are not fully derived from the search snippet. We optimized the rag retrieval algorithm, improved the recall of the embedding model, and adjusted the text chunk size to 512. This may be inappropriate and requires reassessment based on the content. We need to ensure that the search snippet contains sufficient context to support the answer, ensuring that the generated answer content is derived from the search snippet. For example, relevant technical details regarding the development of the openEuler ecosystem should be obtained from the search snippet. Interpretability: The interpretability score is 85.00, which is relatively low. 
We improved the compliance of the large model qwen2.5-32b and optimized the recall of the rag retrieval algorithm and the embedding model bge-m3. This ensures that the search snippet better supports answer generation and clearly answers the question. For example, when answering openEuler-related questions, the answer logic is made clearer and more targeted, improving overall interpretability. + + The following is the test result content: + Used large model: {model_name} + Embedding model: {embedding_model} + Text chunk size: {chunk_size} + Used RAG parsing algorithm: {rag_algorithm} + Question: {question} + Standard answer: {standard_answer} + Generated answer: {generated_answer} + Original fragment: {original_fragment} + Retrieved fragment: {retrieved_fragment} + Recall: {recall} + Precision: {precision} + Faithfulness: {faithfulness} + Interpretability: {relevance} + + 中文: | + 你是一个文本分析专家,你的任务是:根据给出的测试使用的大模型、embedding模型、测试相关文档的解析方法和分块大小、单条测试结果分析RAG算法匹配到的片段,并分析当前知识库问答准确率的提升方法。 + + 测试结果包含以下内容: + - 问题:测试使用的问题 + - 标准答案:测试使用的标准答案 + - 生成的答案:测试结果中大模型输出的答案 + - 原始片段:测试结果中的原始片段 + - 检索的片段:测试结果中RAG算法检索到的片段 + + 四个评估指标定义如下: + - 精确率:评估生成的答案与问题之间的语义相似程度。评分越低,说明使用的大模型遵从度越低;其次可能是RAG检索到的片段缺少上下文,不足以支撑问题的回答。 + - 召回率:评估生成的答案与标准回答之间的语义相似程度。评分越低,说明使用的大模型遵从度越低。 + - 忠实值:评估生成的答案中的内容是否来自于检索到的片段。评分越低,说明RAG检索算法和embedding模型的召回率越低(导致检索到的片段不足以回答问题);其次可能是文本分块大小不合理。 + - 可解释性:评估生成的答案是否能用于回答问题。评分越低,说明RAG检索算法和embedding模型的召回率越低(导致检索到的片段不足以回答问题);其次可能是使用的大模型遵从度越低。 + + 注意: + #01 请根据测试结果中的内容分析当前知识库问答准确率的提升方法。 + #02 请结合召回率、精确率、忠实值和可解释性四个指标进行分析。 + #03 分析结果长度不超过500字。 + #04 请仅输出分析结果,不要输出其他内容。 + + 例子: + 输入: + 模型名称:qwen2.5-32b + embedding模型:bge-m3 + 文本的分块大小:512 + 使用解析的RAG算法:向量化检索 问题:openEuler是什么操作系统? - 输出1:YES - - 输入2: - 陈诉:白马非马 + 标准答案:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 生成的答案:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 原始片段:openEuler是由开放原子开源基金会孵化及运营的开源操作系统,以构建面向数字基础设施的开源操作系统生态为使命,致力于为云计算、边缘计算等前沿领域提供坚实的底层支持。在云计算场景中,openEuler能够充分优化资源调度与分配机制,通过轻量化的内核设计和高效的虚拟化技术,显著提升云服务的响应速度与吞吐量;在边缘计算领域,它凭借出色的低资源消耗特性与实时处理能力,保障了边缘节点在复杂环境下数据处理的及时性与准确性。openEuler具备一系列卓越特性:在性能方面,其自主研发的智能调度算法能够动态适配不同负载场景,结合对硬件资源的深度优化利用,大幅提升系统运行效率;安全性上,通过内置的多层次安全防护体系,包括强制访问控制、漏洞扫描与修复机制,为系统数据与应用程序构筑起坚实的安全防线;可靠性层面,基于分布式存储、故障自动检测与快速恢复技术,确保系统在面对网络波动、硬件故障等突发状况时,依然能够稳定运行,最大限度降低服务中断风险。这些特性使openEuler成为推动数字经济高质量发展的重要技术基石,助力企业与开发者在数字化转型进程中抢占先机。 + 检索的片段:openEuler作为开源操作系统领域的先锋力量,深度融合了社区开发者的智慧结晶,不断迭代升级以适应快速变化的技术环境。在微服务架构盛行的当下,openEuler通过对容器化技术的深度优化,支持Kubernetes等主流编排工具,让应用部署与管理变得更加便捷高效,极大提升了企业的业务部署灵活性。同时,它积极拥抱AI时代,通过对机器学习框架的适配与优化,为AI模型训练和推理提供强大的算力支撑,有效降低了AI应用的开发与运行成本。在生态建设方面,openEuler拥有庞大且活跃的开源社区,汇聚了来自全球的技术爱好者与行业专家,形成了从内核开发、驱动适配到应用优化的完整生态链。社区定期举办技术交流与开发者大会,推动知识共享与技术创新,为开发者提供了丰富的学习资源与实践机会。众多硬件厂商和软件企业纷纷加入openEuler生态,推出基于该系统的解决方案和产品,涵盖金融、电信、能源等关键行业,以实际应用场景验证并反哺openEuler的技术发展,形成了良性循环的创新生态,让openEuler不仅是一个操作系统,更成为推动产业协同发展的强大引擎。 + 召回率:95.00 + 精确率:99.00 + 忠实值:90.00 + 可解释性:85.00 + + 输出: + 根据测试结果中的内容,当前知识库问答准确率提升的方法可以从以下几个方面进行分析:召回率:目前召回率为95.00,有提升空间。优化向量化检索算法,进一步挖掘原始片段中与问题相关但未被检索到的信息,如openEuler生态中一些具体实践案例等。调整embedding模型bge-m3,使其能更全面准确地捕捉语义,扩大检索范围,提高召回率,使生成答案更接近标准答案。精确率:精确率达99.00,已较高。但可进一步优化,对检索到的片段进行更深入的语义分析,结合大模型qwen2.5-32b的特点,精准匹配问题语义,避免细微语义偏差,例如更精确阐述openEuler在云计算和边缘计算中高性能等特性的具体表现。忠实值:忠实值为90.00,说明部分答案内容未完全源于检索片段。优化RAG检索算法,提高embedding模型召回率,文本分块大小为512可能存在不合理,需根据内容重新评估,确保检索片段包含足够上下文以支撑答案,使生成答案内容均来自检索片段,如关于openEuler生态建设中相关技术细节应从检索片段获取。可解释性:可解释性为85.00,相对较低。提升大模型qwen2.5-32b的遵从度,优化RAG检索算法和embedding模型bge-m3的召回率,使检索片段能更好支撑生成答案,保证答案能清晰回答问题,例如在回答openEuler相关问题时,使答案逻辑更清晰、针对性更强,提高整体可解释性。 + + 下面是测试结果中的内容: + 使用的大模型:{model_name} + embedding模型:{embedding_model} + 
文本的分块大小:{chunk_size} + 使用解析的RAG算法:{rag_algorithm} + 问题:{question} + 标准答案:{standard_answer} + 生成的答案:{generated_answer} + 原始片段:{original_fragment} + 检索的片段:{retrieved_fragment} + 召回率:{recall} + 精确率:{precision} + 忠实值:{faithfulness} + 可解释性:{relevance} + +ANSWER_TO_ANSWER_PROMPT: + # 英文文本相似度评分提示词 + en: | + You are a text analysis expert. Your task is to compare the similarity between two documents and output a score between 0 and 100 with two decimal places. + + Note: + #01 Score based on text similarity in three dimensions: semantics, word order, and keywords. + #02 If the core expressions of the two documents are consistent, the score will be relatively high. + #03 If one document contains the core content of the other, the score will also be relatively high. + #04 If there is content overlap between the two documents, the score will be determined by the proportion of the overlap. + #05 Output only the score (no other content). + + Example 1: + Input - Text 1: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Text 2: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Output: 100.00 + + Example 2: + Input - Text 1: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Text 2: openEuler is an open-source operating system designed to support cloud computing and edge computing. It features high performance and high security. + Output: 90.00 + + Example 3: + Input - Text 1: openEuler is an open-source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Text 2: A white horse is not a horse + Output: 00.00 + + The following are the given texts: + Text 1: {text_1} + Text 2: {text_2} + + # 中文文本相似度评分提示词 + 中文: | + 你是一个文本分析专家,你的任务是对比两个文本之间的相似度,并输出一个 0-100 之间的分数(保留两位小数)。 + + 注意: + #01 请根据文本在语义、语序和关键字三个维度的相似度进行打分。 + #02 如果两个文本在核心表达上一致,那么分数将相对较高。 + #03 如果一个文本包含另一个文本的核心内容,那么分数也将相对较高。 + #04 如果两个文本间存在内容重合,那么将按照重合内容的比例确定分数。 + #05 仅输出分数,不要输出其他任何内容。 + + 例子 1: + 输入 - 文本 1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 文本 2:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 输出:100.00 + + 例子 2: + 输入 - 文本 1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 文本 2:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能和高安全性等特点。 + 输出:90.00 + + 例子 3: + 输入 - 文本 1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 文本 2:白马非马 + 输出:00.00 + + 下面是给出的文本: + 文本 1:{text_1} + 文本 2:{text_2} + +CAL_QA_SCORE_PROMPT: + en: >- + You are a text analysis expert. Your task is to evaluate the questions and answers generated from a given fragment, and assign a score between 0 and 100 (retaining two decimal places). Please evaluate based on the following criteria: + + ### 1. Question Evaluation + - **Relevance**: Is the question closely related to the topic of the given fragment? Is it accurately based on the fragment content? Does it deviate from or distort the core message of the fragment? + - **Plausibility**: Is the question formulated clearly and logically coherently? Does it conform to normal language and thinking habits? Is it free of semantic ambiguity, vagueness, or self-contradiction? 
+ - **Variety**: If there are multiple questions, are their angles and types sufficiently varied to avoid being overly monotonous or repetitive? Can they explore the fragment content from different perspectives? + - **Difficulty**: Is the question difficulty appropriate? Not too easy (where answers can be directly copied from the fragment), nor too difficult (where respondents cannot find clues or evidence from the fragment)? + + ### 2. Answer Evaluation + - **Accuracy**: Does the answer accurately address the question? Is it consistent with the information in the fragment? Does it contain errors or omit key points? + - **Completeness**: Is the answer comprehensive, covering all aspects of the question? For questions requiring elaboration, does it provide sufficient details and explanations? + - **Succinctness**: On the premise of ensuring completeness and accuracy, is the answer concise and clear? Does it avoid lengthy or redundant expressions, and convey key information in concise language? + - **Coherence**: Is the answer logically clear? Are transitions between content sections natural and smooth? Are there any jumps or confusion? + + ### 3. Overall Assessment + - **Consistency**: Do the question and answer match each other? Does the answer address the raised question? Are they consistent in content and logic? + - **Integration**: Does the answer effectively integrate information from the fragment? Is it not just a simple excerpt, but rather an integrated, refined presentation in a logical manner? + - **Innovation**: In some cases, evaluate whether the answer demonstrates innovation or unique insights? Does it appropriately expand or deepen the information in the fragment? + + ### Note + #01 Please output only the score (without any other content). + + ### Example + Input 1: + Question: What operating system is openEuler? + Answer: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Snippet: openEuler is an open source operating system designed to support cloud and edge computing. It features high performance, high security, and high reliability. + Output 1: 100.00 + + Below is the given question, answer, and snippet: + Question: {question} + Answer: {answer} + Snippet: {fragment} + 中文: >- + 你是文本分析专家,任务是评估由给定片段生成的问题与答案,输出 0-100 之间的分数(保留两位小数)。请根据以下标准进行评估: + + ### 1. 问题评估 + - **相关性**:问题是否与给定片段的主题紧密相关?是否准确基于片段内容提出?有无偏离或曲解片段的核心信息? + - **合理性**:问题表述是否清晰、逻辑连贯?是否符合正常的语言表达和思维习惯?不存在语义模糊、歧义或自相矛盾的情况? + - **多样性**:若存在多个问题,问题之间的角度和类型是否具有足够多样性(避免过于单一或重复)?能否从不同方面挖掘片段内容? + - **难度**:问题难度是否适中?既不过于简单(答案可直接从片段中照搬),也不过于困难(回答者难以从片段中找到线索或依据)? + + ### 2. 答案评估 + - **准确性**:答案是否准确无误地回答了问题?与片段中的信息是否一致?有无错误或遗漏关键要点? + - **完整性**:答案是否完整,涵盖问题涉及的各个方面?对于需要详细阐述的问题,是否提供了足够的细节和解释? + - **简洁性**:在保证回答完整、准确的前提下,答案是否简洁明了?是否避免冗长、啰嗦的表述,能否以简洁语言传达关键信息? + - **连贯性**:答案逻辑是否清晰?各部分内容之间的衔接是否自然流畅?有无跳跃或混乱的情况? + + ### 3. 整体评估 + - **一致性**:问题与答案之间是否相互匹配?答案是否针对所提出的问题进行回答?两者在内容和逻辑上是否保持一致? + - **融合性**:答案是否能很好地融合片段中的信息?是否并非简单摘抄,而是经过整合、提炼后以合理方式呈现? + - **创新性**:在某些情况下,评估答案是否具有一定创新性或独特见解?是否能在片段信息基础上进行适当拓展或深入思考? + + ### 注意事项 + #01 请仅输出分数,不要输出其他内容。 + + ### 示例 + 输入 1: + 问题:openEuler 是什么操作系统? + 答案:openEuler 是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 片段:openEuler 是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 输出 1:100.00 + + 下面是给出的问题、答案和片段: + 问题:{question} + 答案:{answer} + 片段:{fragment} + +CHUNK_QUERY_MATCH_PROMPT: + en: | + You are a text analysis expert. Your task is to determine whether a given fragment is relevant to a question. 
+ Note: + #01 If the fragment is relevant, output YES. + #02 If the fragment is not relevant, output NO. + #03 Only output YES or NO, and do not output anything else. + + Example: + Input 1: + Fragment: openEuler is an open source operating system. + Question: What kind of operating system is openEuler? + Output 1: YES + + Input 2: + Fragment: A white horse is not a horse. + Question: What kind of operating system is openEuler? + Output 2: NO + + Here are the given fragment and question: + Fragment: {chunk} + Question: {question} + 中文: | + 你是一个文本分析专家,你的任务是根据给出的片段和问题,判断片段是否与问题相关。 + 注意: + #01 如果片段与问题相关,请输出YES; + #02 如果片段与问题不相关,请输出NO; + #03 请仅输出YES或NO,不要输出其他内容。 + + 例子: + 输入1: + 片段:openEuler是一个开源的操作系统。 问题:openEuler是什么操作系统? - 输出2:NO - - 下面是给出的陈诉和问题: - 陈诉:{statement} - 问题:{question} - ' -GENREATE_QUESTION_FROM_CONTENT_PROMPT: '你是一个文本分析专家,你的任务是根据给出的文本生成{k}个问题并用列表返回 - 注意: - #01 问题必须来源于文本中的内容 - #02 单个问题长度不超过50个字 - #03 不要输出重复的问题 - #04 输出的问题要多样,覆盖文本中的不同方面 - #05 请仅输出问题列表,不要输出其他内容 - 例子: - 输入:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 - 输出: - [\"openEuler是什么操作系统?\",\"openEuler旨在为哪个领域提供支持?\",\"openEuler具有哪些特点?\",\"openEuler的安全性如何?\",\"openEuler的可靠性如何?\"] - 下面是给出的文本: - {content} -' -GENERATE_ANSWER_FROM_QUESTION_AND_CONTENT_PROMPT: '你是一个文本分析专家,你的任务是根据给出的问题和文本 - 生成答案 - 注意: - #01 答案必须来源于文本中的内容 - #02 答案长度不少于50字且不超过500个字 - #03 请仅输出答案,不要输出其他内容 - 例子: - 输入1: + 输出1:YES + + 输入2: + 片段:白马非马 问题:openEuler是什么操作系统? - 文本:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 - 输出1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。 - - 输入2: - 问题:openEuler的安全性如何? - 文本:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 - 输出2:openEuler具有高安全性。 - - 下面是给出的问题和文本: - 问题:{question} - 文本:{content} -' -CAL_QA_SCORE_PROMPT: '你是一个文本分析专家,你的任务是给出的问题 答案 片段 判断由片段生成的问题和答案的分数,输出一个0-100之间的数,保留两位小数 -请根据下面规则评估: -问题评估 -相关性:问题是否与给定片段的主题紧密相关,是否准确地基于片段内容提出,有无偏离或曲解片段的核心信息。 -合理性:问题的表述是否清晰、逻辑连贯,是否符合正常的语言表达和思维习惯,不存在语义模糊、歧义或自相矛盾的情况。 -多样性:如果有多个问题,问题之间的角度和类型是否具有一定的多样性,避免过于单一或重复,能否从不同方面挖掘片段的内容。 -难度:问题的难度是否适中,既不过于简单,使答案可以直接从片段中照搬,也不过于困难,让回答者难以从片段中找到线索或依据。 -答案评估 -准确性:答案是否准确无误地回答了问题,与片段中的信息是否一致,有无错误或遗漏关键要点。 -完整性:答案是否完整,涵盖了问题所涉及的各个方面,对于需要详细阐述的问题,是否提供了足够的细节和解释。 -简洁性:在保证回答完整准确的前提下,答案是否简洁明了,避免冗长、啰嗦的表述,能否以简洁的语言传达关键信息。 -连贯性:答案的逻辑是否清晰,各部分内容之间的衔接是否自然流畅,有无跳跃或混乱的情况。 -整体评估 -一致性:问题和答案之间是否相互匹配,答案是否是针对所提出的问题进行的回答,两者在内容和逻辑上是否保持一致。 -融合性:答案是否能够很好地融合片段中的信息,不仅仅是简单的摘抄,而是经过整合和提炼,以合理的方式呈现出来。 -创新性:在某些情况下,评估答案是否具有一定的创新性或独特见解,是否能够在片段信息的基础上进行适当的拓展或深入思考。 - -注意: -#01 请仅输出分数,不要输出其他内容 - -例子: -输入1: - 问题:openEuler是什么操作系统? - 答案:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 - 片段:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 - 输出1:100.00 - -下面是给出的问题、答案和片段: - 问题:{question} - 答案:{answer} - 片段:{fragment} -' -CONTENT_TO_ABSTRACT_PROMPT: '你是一个文本摘要专家,你的任务是根据给出的文本和摘要生成一个新的摘要 - 注意: - #01 请结合文本和摘要中最重要的内容生成新的摘要 - #02 新的摘要的长度必须大于200字小于500字 - #03 请仅输出新的摘要,不要输出其他内容 - 例子: - 输入1: + 输出2:NO + + 下面是给出的片段和问题: + 片段:{chunk} + 问题:{question} + +CONTENT_TO_ABSTRACT_PROMPT: + en: | + You are a text summarization expert. Your task is to generate a new English summary based on a given text and an existing summary. + Note: + #01 Please combine the most important content from the text and the existing summary to generate the new summary. + #02 The length of the new summary must be greater than 200 words and less than 500 words. + #03 Please only output the new English summary; do not output any other content. + + Example: + Input 1: + Text: openEuler features high performance, high security, and high reliability. 
+ Abstract: openEuler is an open source operating system designed to support cloud computing and edge computing. + Output 1: openEuler is an open source operating system designed to support cloud computing and edge computing. openEuler features high performance, high security, and high reliability. + + Below is the given text and summary: + Text: {content} + Abstract: {abstract} + 中文: | + 你是一个文本摘要专家,你的任务是根据给出的文本和已有摘要,生成一个新的中文摘要。 + 注意: + #01 请结合文本和已有摘要中最重要的内容,生成新的摘要; + #02 新的摘要长度必须大于200字且小于500字; + #03 请仅输出新的中文摘要,不要输出其他内容。 + + 例子: + 输入1: 文本:openEuler具有高性能、高安全性和高可靠性等特点。 摘要:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。 - 输出1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。openEuler具有高性能、高安全性和高可靠性等特点。 - - 下面是给出的文本和摘要: - 文本:{content} - 摘要:{abstract} -' - -CONTENT_TO_TITLE_PROMPT: '你是一个标题提取专家,你的任务是根据给出的文本生成一个标题 - 注意: - #01 标题必须来源于文本中的内容 - #02 标题长度不超过20个字 - #03 请仅输出标题,不要输出其他内容 - #04 如果给出的文本不够生成标题,请输出“无法生成标题” - 例子: - 输入:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 - 输出:openEuler操作系统概述 - 下面是给出的文本: - {content} -' - -ACC_RESULT_ANALYSIS_PROMPT: '你是一个文本分析专家,你的任务根据给出的 测试使用的大模型 embdding模型 测试相关文档的解析方法和分块大小 单条测试结果分析rag算法匹配到的片段分析当前知识库问答准确率提升的方法 -测试结果包含下面内容: -问题:测试使用的问题 -标准答案:测试使用的标准答案 -生成的答案:测试结果中大模型的答案 -原始片段:测试结果中原始片段 -检索的片段:测试结果中rag算法检索到的片段 -精确率:评估生成的答案与问题之间的语义相似程度,这个评分月越低说明使用的大模型遵从度越低,其次是rag检索到的片段缺少上下文,不足以支撑问题的回答 -召回率度:评估生成的答案与标准回答之间的语义相似程度,这个评分月越低说明使用的大模型遵从度越低 -忠实值:评估生成的答案中的内容是否来自于检索到的片段,这个评分越低说明rag检索算法和embedding模型的召回率越低,检索到的片段不足以回答问题,其次是文本分块大小不合理 -可解释性:评估生成的答案是否用于回答问题,这个评分越低说明rag检索算法和embedding模型的召回率越低,检索到的片段不足以回答问题,其次是使用的大模型遵从度越低 - -注意: -#01 请根据测试结果中的内容分析当前知识库问答准确率提升的方法 -#02 请结合召回率、精确度、忠实值和可解释性四个指标进行分析 -#03 分析结果长度不超过500字 -#04 请仅输出分析结果,不要输出其他内容 -例子: -输入: -模型名称:qwen2.5-32b -embedding模型:bge-m3 -文本的分块大小:512 -使用解析的rag算法:向量化检索 -问题:openEuler是什么操作系统? -标准答案:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 -生成的答案:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 -原始片段:openEuler 是由开放原子开源基金会孵化及运营的开源操作系统,以构建面向数字基础设施的开源操作系统生态为使命,致力于为云计算、边缘计算等前沿领域提供坚实的底层支持。在云计算场景中,openEuler 能够充分优化资源调度与分配机制,通过轻量化的内核设计和高效的虚拟化技术,显著提升云服务的响应速度与吞吐量;在边缘计算领域,它凭借出色的低资源消耗特性与实时处理能力,保障了边缘节点在复杂环境下数据处理的及时性与准确性。 -openEuler 具备一系列卓越特性:在性能方面,其自主研发的智能调度算法能够动态适配不同负载场景,结合对硬件资源的深度优化利用,大幅提升系统运行效率;安全性上,通过内置的多层次安全防护体系,包括强制访问控制、漏洞扫描与修复机制,为系统数据与应用程序构筑起坚实的安全防线;可靠性层面,基于分布式存储、故障自动检测与快速恢复技术,确保系统在面对网络波动、硬件故障等突发状况时,依然能够稳定运行,最大限度降低服务中断风险。这些特性使 openEuler 成为推动数字经济高质量发展的重要技术基石,助力企业与开发者在数字化转型进程中抢占先机。 -检索的片段:openEuler 作为开源操作系统领域的先锋力量,深度融合了社区开发者的智慧结晶,不断迭代升级以适应快速变化的技术环境。在微服务架构盛行的当下,openEuler 通过对容器化技术的深度优化,支持 Kubernetes 等主流编排工具,让应用部署与管理变得更加便捷高效,极大提升了企业的业务部署灵活性。同时,它积极拥抱 AI 时代,通过对机器学习框架的适配与优化,为 AI 模型训练和推理提供强大的算力支撑,有效降低了 AI 应用的开发与运行成本。 -在生态建设方面,openEuler 拥有庞大且活跃的开源社区,汇聚了来自全球的技术爱好者与行业专家,形成了从内核开发、驱动适配到应用优化的完整生态链。社区定期举办技术交流与开发者大会,推动知识共享与技术创新,为开发者提供了丰富的学习资源与实践机会。众多硬件厂商和软件企业纷纷加入 openEuler 生态,推出基于该系统的解决方案和产品,涵盖金融、电信、能源等关键行业,以实际应用场景验证并反哺 openEuler 的技术发展,形成了良性循环的创新生态,让 openEuler 不仅是一个操作系统,更成为推动产业协同发展的强大引擎 。 - -召回率:95.00 -精确度:99.00 -忠实值:90.00 -可解释性:85.00 - -输出: -根据测试结果中的内容,当前知识库问答准确率提升的方法可以从以下几个方面进行分析: -召回率:目前召回率为 95.00,有提升空间。优化向量化检索算法,进一步挖掘原始片段中与问题相关但未被检索到的信息,如 openEuler 生态中一些具体实践案例等。调整 embedding 模型 bge-m3,使其能更全面准确地捕捉语义,扩大检索范围,提高召回率,使生成答案更接近标准答案。 -精确度:精确度达 99.00,已较高。但可进一步优化,对检索到的片段进行更深入的语义分析,结合大模型 qwen2.5-32b 的特点,精准匹配问题语义,避免细微语义偏差,例如更精确阐述 openEuler 在云计算和边缘计算中高性能等特性的具体表现。 -忠实值:忠实值为 90.00,说明部分答案内容未完全源于检索片段。优化 rag 检索算法,提高 embedding 模型召回率,调整文本分块大小为 512 可能存在不合理,需根据内容重新评估,确保检索片段包含足够上下文以支撑答案,使生成答案内容均来自检索片段,如关于 openEuler 生态建设中相关技术细节应从检索片段获取。 -可解释性:可解释性为 85.00,相对较低。提升大模型 qwen2.5-32b 的遵从度,优化 rag 检索算法和 embedding 模型 bge-m3 的召回率,使检索片段能更好支撑生成答案,保证答案能清晰回答问题,例如在回答 openEuler 
相关问题时,使答案逻辑更清晰、针对性更强,提高整体可解释性。 - - -下面是测试结果中的内容: -使用的大模型:{model_name} -embedding模型:{embedding_model} -文本的分块大小:{chunk_size} -使用解析的rag算法:{rag_algorithm} -问题:{question} -标准答案:{standard_answer} -生成的答案:{generated_answer} -原始片段:{original_fragment} -检索的片段:{retrieved_fragment} -召回率:{recall} -精确度:{precision} -忠实值:{faithfulness} -可解释性:{relevance} -' -ACC_ANALYSIS_RESULT_MERGE_PROMPT: '你是一个文本分析专家,你的任务融合两条分析结果输出一份新的分析结果 -注意: -#01 请根据两条分析结果中的内容融合出一条新的分析结果 -#02 请结合召回率、精确度、忠实值和可解释性四个指标进行分析 -#03 新的分析结果长度不超过500字 -#04 请仅输出新的分析结果,不要输出其他内容 -例子: -输入1: -分析结果1: - -召回率:目前召回率为 95.00,有提升空间。优化向量化检索算法,进一步挖掘原始片段中与问题相关但未被检索到的信息,如 openEuler 生态中一些具体实践案例等。调整 embedding 模型 bge-m3,使其能更全面准确地捕捉语义,扩大检索范围,提高召回率,使生成答案更接近标准答案。 -精确度:精确度达 99.00,已较高。但可进一步优化,对检索到的片段进行更深入的语义分析,结合大模型 qwen2.5-32b 的特点,精准匹配问题语义,避免细微语义偏差,例如更精确阐述 openEuler 在云计算和边缘计算中高性能等特性的具体表现。 -忠实值:忠实值为 90.00,说明部分答案内容未完全源于检索片段。优化 rag 检索算法,提高 embedding 模型召回率,调整文本分块大小为 512 可能存在不合理,需根据内容重新评估,确保检索片段包含足够上下文以支撑答案,使生成答案内容均来自检索片段,如关于 openEuler 生态建设中相关技术细节应从检索片段获取。 -可解释性:可解释性为 85.00,相对较低。提升大模型 qwen2.5-32b 的遵从度,优化 rag 检索算法和 embedding 模型 bge-m3 的召回率,使检索片段能更好支撑生成答案,保证答案能清晰回答问题,例如在回答 openEuler 相关问题时,使答案逻辑更清晰、针对性更强,提高整体可解释性。 - -分析结果2: - -从召回率来看,目前为 95.00,可进一步优化 rag 检索算法和 embedding 模型,以提高生成答案与标准回答之间的语义相似程度,接近或达到更高的召回率,例如可以持续优化算法来更好地匹配相关片段。 -从精确度来看,为 99.00,接近满分,说明生成的答案与问题语义相似程度较高,但仍可进一步提升,可通过完善 embedding 模型来更好地理解问题语义,优化检索到的片段的上下文完整性,减少因上下文不足导致的精确度波动。 -对于忠实值,目前为 90.00,说明生成的答案中部分内容未完全来自检索到的片段。可优化 rag 检索算法,提高其召回率,同时合理调整文本分块大小,确保检索到的片段能充分回答问题,从而提高忠实值。 -关于可解释性,当前为 85.00,说明生成的答案在用于回答问题方面有一定提升空间。一方面可以优化使用的大模型,提高其遵从度,使其生成的答案更准确地回答问题;另一方面,继续优化 rag 检索算法和 embedding 模型的召回率,保证检索到的片段能全面支撑问题的回答,提高可解释性。 - -输出: -召回率:目前召回率为 95.00,有提升空间。优化向量化检索算法,进一步挖掘原始片段中与问题相关但未被检索到的信息,如 openEuler 生态中一些具体实践案例等。调整 embedding 模型 bge-m3,使其能更全面准确地捕捉语义,扩大检索范围,提高召回率,使生成答案更接近标准答案。 -精确度:精确度达 99.00,已较高。但可进一步优化,对检索到的片段进行更深入的语义分析,结合大模型 qwen2.5-32b 的特点,精准匹配问题语义,避免细微语义偏差,例如更精确阐述 openEuler 在云计算和边缘计算中高性能等特性的具体表现。 -忠实值:忠实值为 90.00,说明部分答案内容未完全源于检索片段。优化 rag 检索算法,提高 embedding 模型召回率,调整文本分块大小为 512 可能存在不合理,需根据内容重新评估,确保检索片段包含足够上下文以支撑答案,使生成答案内容均来自检索片段,如关于 openEuler 生态建设中相关技术细节应从检索片段获取。 -可解释性:可解释性为 85.00,相对较低。提升大模型 qwen2.5-32b 的遵从度,优化 rag 检索算法和 embedding 模型 bge-m3 的召回率,使检索片段能更好支撑生成答案,保证答案能清晰回答问题,例如在回答 openEuler 相关问题时,使答案逻辑更清晰、针对性更强,提高整体可解释性。 - -下面两条分析结果: -分析结果1:{analysis_result_1} -分析结果2:{analysis_result_2} -' -CHUNK_QUERY_MATCH_PROMPT: '你是一个文本分析专家,你的任务是根据给出的片段和问题判断,片段是否与问题相关 - 注意: - #01 如果片段与问题相关,请输出YES - #02 如果片段与问题不相关,请输出NO - #03 请仅输出YES或NO,不要输出其他内容 - 例子: - 输入1: - 片段:openEuler是一个开源的操作系统。 + 输出1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。openEuler具有高性能、高安全性和高可靠性等特点。 + + 下面是给出的文本和摘要: + 文本:{content} + 摘要:{abstract} + +CONTENT_TO_STATEMENTS_PROMPT: + en: | + You are a text parsing expert. Your task is to extract multiple English statements from a given text and return them as a list. + + Note: + #01 Statements must be derived from key points in the text. + #02 Statements must be arranged in relative order. + #03 Each statement must be at least 20 characters long and no more than 50 characters long. + #04 The total number of statements output must not exceed three. + #05 Please output only the list of statements, not any other content. Each statement must be in English. + Example: + + Input: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. 
+ Output: [ "openEuler is an open source operating system", "openEuler is designed to support cloud computing and edge computing", "openEuler features high performance, high security, and high reliability" ] + + The following is the given text: {content} + 中文: | + 你是一个文本分解专家,你的任务是根据我给出的文本,将文本提取为多个中文陈述,陈述使用列表形式返回 + + 注意: + #01 陈述必须来源于文本中的重点内容 + #02 陈述按相对顺序排列 + #03 输出的单个陈述长度不少于20个字,不超过50个字 + #04 输出的陈述总数不超过3个 + #05 请仅输出陈述列表,不要输出其他内容,且每一条陈述都是中文。 + 例子: + + 输入:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 输出:[ "openEuler是一个开源的操作系统", "openEuler旨在为云计算和边缘计算提供支持", "openEuler具有高性能、高安全性和高可靠性等特点" ] + + 下面是给出的文本: {content} + +CONTENT_TO_TITLE_PROMPT: + en: >- + You are a title extraction expert. Your task is to generate an English title based on the given text. + Note: + #01 The title must be derived from the content of the text. + #02 The title must be no longer than 20 characters. + #03 Please output only the English title, and do not output any other content. + #04 If the given text is insufficient to generate a title, output "Unable to generate title." + Example: + Input: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Output: Overview of the openEuler operating system. + Below is the given text: {content} + 中文: >- + 你是一个标题提取专家,你的任务是根据给出的文本生成一个中文标题。 + 注意: + #01 标题必须来源于文本中的内容 + #02 标题长度不超过20个字 + #03 请仅输出中文标题,不要输出其他内容 + #04 如果给出的文本不够生成标题,请输出“无法生成标题” + 例子: + 输入:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 输出:openEuler操作系统概述 + 下面是给出的文本:{content} + +GENERATE_ANSWER_FROM_QUESTION_AND_CONTENT_PROMPT: + en: | + You are a text analysis expert. Your task is to generate an English answer based on a given question and text. + Note: + #01 The answer must be derived from the content in the text. + #02 The answer must be at least 50 words and no more than 500 words. + #03 Please only output the English answer; do not output any other content. + Example: + Input 1: + Question: What kind of operating system is openEuler? + Text: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Output 1: openEuler is an open source operating system designed to support cloud computing and edge computing. + + Input 2: + Question: How secure is openEuler? + Text: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Output 2: openEuler is highly secure. + + Below is the given question and text: + Question: {question} + Text: {content} + 中文: | + 你是一个文本分析专家,你的任务是根据给出的问题和文本生成中文答案。 + 注意: + #01 答案必须来源于文本中的内容; + #02 答案长度不少于50字且不超过500个字; + #03 请仅输出中文答案,不要输出其他内容。 + 例子: + 输入1: + 问题:openEuler是什么操作系统? + 文本:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 输出1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。 + + 输入2: + 问题:openEuler的安全性如何? + 文本:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 输出2:openEuler具有高安全性。 + + 下面是给出的问题和文本: + 问题:{question} + 文本:{content} + +GENERATE_QUESTION_FROM_CONTENT_PROMPT: + en: | + You are a text analysis expert. Your task is to generate {k} English questions based on the given text and return them as a list. + Note: + #01 Questions must be derived from the content of the text. + #02 A single question must not exceed 50 characters. + #03 Do not output duplicate questions. 
+ #04 The output questions should be diverse, covering different aspects of the text. + #05 Please only output a list of English questions, not other content. + Example: + Input: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Output: ["What is openEuler?","What fields does openEuler support?","What are the characteristics of openEuler?","How secure is openEuler?","How reliable is openEuler?"] + The following is the given text: {content} + 中文: | + 你是一个文本分析专家,你的任务是根据给出的文本生成{k}个中文问题并用列表返回。 + 注意: + #01 问题必须来源于文本中的内容; + #02 单个问题长度不超过50个字; + #03 不要输出重复的问题; + #04 输出的问题要多样,覆盖文本中的不同方面; + #05 请仅输出中文问题列表,不要输出其他内容。 + 例子: + 输入:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 输出:["openEuler是什么操作系统?","openEuler旨在为哪个领域提供支持?","openEuler具有哪些特点?","openEuler的安全性如何?","openEuler的可靠性如何?"] + 下面是给出的文本:{content} + +OCR_ENHANCED_PROMPT: + en: | + You are an expert in image OCR content summarization. Your task is to describe the image based on the context I provide, descriptions of adjacent images, a summary of the previous OCR result for the current image, and the partial OCR results (including text and relative coordinates). + + Note: + #01 The image content must be described in detail, using at least 200 and no more than 500 words. Detailed data listing is acceptable. + #02 If this diagram is a flowchart, please describe the content in the order of the flowchart. + #03 If this diagram is a table, please output the table content in Markdown format. + #04 If this diagram is an architecture diagram, please describe the content according to the hierarchy of the architecture diagram. + #05 The summarized image description must include the key information in the image; it cannot simply describe the image's location. + #06 Adjacent text in the image recognition results may be part of the same paragraph. Please merge them before summarizing. + #07 The text may be misplaced. Please correct the order before summarizing. + #08 Please only output the image summary; do not output any other content. + #09 Do not output coordinates or other information; only output a description of the relative position of each part. + #10 If the image content is empty, output "Image content is empty." + #11 If the image itself is a paragraph of text, output the text content directly. + #12 Please use English for the output. + Context: {image_related_text} + Summary of the OCR content of the previous part of the current image: {pre_part_description} + Result of the OCR of the current part of the image: {part} + 中文: | + 你是一个图片OCR内容总结专家,你的任务是根据我提供的上下文、相邻图片组描述、当前图片上一次的OCR内容总结、当前图片部分OCR的结果(包含文字和文字的相对坐标)给出图片描述。 + + 注意: + #01 必须使用大于200字小于500字详细描述这个图片的内容,可以详细列出数据。 + #02 如果这个图是流程图,请按照流程图顺序描述内容。 + #03 如果这张图是表格,请用Markdown形式输出表格内容。 + #04 如果这张图是架构图,请按照架构图层次结构描述内容。 + #05 总结的图片描述必须包含图片中的主要信息,不能只描述图片位置。 + #06 图片识别结果中相邻的文字可能是同一段落的内容,请合并后总结。 + #07 文字可能存在错位,请修正顺序后进行总结。 + #08 请仅输出图片的总结即可,不要输出其他内容。 + #09 不要输出坐标等信息,输出每个部分相对位置的描述即可。 + #10 如果图片内容为空,请输出“图片内容为空”。 + #11 如果图片本身就是一段文字,请直接输出文字内容。 + #12 请使用中文输出。 + 上下文:{image_related_text} + 当前图片上一部分的OCR内容总结:{pre_part_description} + 当前图片部分OCR的结果:{part} + +QA_TO_STATEMENTS_PROMPT: + en: | + You are a text parsing expert. Your task is to extract the answers from the questions and answers I provide into multiple English statements, returning them as a list. + + Note: + #01 The statements must be derived from the key points of the answers. + #02 The statements must be arranged in relative order. 
+ #03 The length of each statement output must not exceed 50 characters. + #04 The total number of statements output must not exceed 20. + #05 Please only output the list of English statements; do not output any other content. + + Example: + Input: Question: What is openEuler? Answer: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Output: [ "openEuler is an open source operating system", "openEuler is designed to support cloud computing and edge computing", "openEuler features high performance, high security, and high reliability" ] + + Below are the given questions and answers: + Question: {question} + Answer: {answer} + 中文: | + 你是一个文本分解专家,你的任务是根据我给出的问题和答案,将答案提取为多个中文陈述,陈述使用列表形式返回。 + + 注意: + #01 陈述必须来源于答案中的重点内容 + #02 陈述按相对顺序排列 + #03 输出的单个陈述长度不超过50个字 + #04 输出的陈述总数不超过20个 + #05 请仅输出中文陈述列表,不要输出其他内容 + + 例子: + 输入:问题:openEuler是什么操作系统? 答案:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 输出:[ "openEuler是一个开源的操作系统", "openEuler旨在为云计算和边缘计算提供支持", "openEuler具有高性能、高安全性和高可靠性等特点" ] + + 下面是给出的问题和答案: + 问题:{question} + 答案:{answer} + +QUERY_EXTEND_PROMPT: + en: | + You are a question expansion expert. Your task is to expand {k} questions based on the given question. + + Note: + #01 The content of the expanded question must be derived from the content of the original question. + #02 The expanded question length must not exceed 50 characters. + #03 Questions can be rewritten by replacing synonyms, swapping word order within the question, changing English capitalization, etc. + #04 Please only output the expanded question list, do not output other content. + + Example: + Input: What operating system is openEuler? + Output: [ "What kind of operating system is openEuler?", "What are the characteristics of the openEuler operating system?", "What are the functions of the openEuler operating system?", "What are the advantages of the openEuler operating system?" ] + + The following is the given question: {question} + 中文: | + 你是一个问题扩写专家,你的任务是根据给出的问题扩写{k}个问题。 + + 注意: + #01 扩写的问题的内容必须来源于原问题中的内容 + #02 扩写的问题长度不超过50个字 + #03 可以通过近义词替换、问题内词序交换、修改英文大小写等方式来改写问题 + #04 请仅输出扩写的问题列表,不要输出其他内容 + + 例子: + 输入:openEuler是什么操作系统? + 输出:[ "openEuler是一个什么样的操作系统?", "openEuler操作系统的特点是什么?", "openEuler操作系统有哪些功能?", "openEuler操作系统的优势是什么?" ] + + 下面是给出的问题:{question} + +STATEMENTS_TO_FRAGMENT_PROMPT: + en: | + You are a text expert. Your task is to determine whether a given statement is strongly related to the fragment. + + Note: + #01 If the statement is strongly related to the fragment or is derived from the fragment, output YES. + #02 If the content in the statement is unrelated to the fragment, output NO. + #03 If the statement is a refinement of a portion of the fragment, output YES. + #05 Only output YES or NO, and do not output anything else. + + Example: + Input 1: + Statement: openEuler is an open source operating system. + Fragment: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Output 1: YES + + Input 2: + Statement: A white horse is not a horse. + Fragment: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. 
+ Output 2: NO + + Below is a given statement and fragment: + Statement: {statement} + Fragment: {fragment} + 中文: | + 你是一个文本专家,你的任务是判断给出的陈述是否与片段强相关。 + + 注意: + #01 如果陈述与片段强相关或者来自于片段,请输出YES + #02 如果陈述中的内容与片段无关,请输出NO + #03 如果陈述是片段中某部分的提炼,请输出YES + #05 请仅输出YES或NO,不要输出其他内容 + + 例子: + 输入1: + 陈述:openEuler是一个开源的操作系统。 + 片段:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 输出1:YES + + 输入2: + 陈述:白马非马 + 片段:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 输出2:NO + + 下面是给出的陈述和片段: + 陈述:{statement} + 片段:{fragment} + +STATEMENTS_TO_QUESTION_PROMPT: + en: | + You are a text analysis expert. Your task is to determine whether a given statement is relevant to a question. + + Note: + #01 If the statement is relevant to the question, output YES. + #02 If the statement is not relevant to the question, output NO. + #03 Only output YES or NO, and do not output anything else. + #04 A statement's relevance to the question means that the content in the statement can answer the question or overlaps with the question in terms of content. + + Example: + Input 1: + Statement: openEuler is an open source operating system. + Question: What kind of operating system is openEuler? + Output 1: YES + + Input 2: + Statement: A white horse is not a horse. + Question: What kind of operating system is openEuler? + Output 2: NO + + Below is the given statement and question: + Statement: {statement} + Question: {question} + 中文: | + 你是一个文本分析专家,你的任务是判断给出的陈述是否与问题相关。 + + 注意: + #01 如果陈述与问题相关,请输出YES + #02 如果陈述与问题不相关,请输出NO + #03 请仅输出YES或NO,不要输出其他内容 + #04 陈述与问题相关是指,陈述中的内容可以回答问题或者与问题在内容上有交集 + + 例子: + 输入1: + 陈述:openEuler是一个开源的操作系统。 问题:openEuler是什么操作系统? - 输出1:YES + 输出1:YES - 输入2: - 片段:白马非马 + 输入2: + 陈述:白马非马 问题:openEuler是什么操作系统? - 输出2:NO - - 下面是给出的片段和问题: - 片段:{chunk} - 问题:{question} - ' -QUERY_EXTEND_PROMPT: '你是一个问题扩写专家,你的任务是根据给出的问题扩写{k}个问题 - 注意: - #01 扩写的问题的内容必须来源于原问题中的内容 - #02 扩写的问题长度不超过50个字 - #03 可以通过近义词替换 问题内词序交换 修改英文大小写等方式来改写问题 - #04 请仅输出扩写的问题列表,不要输出其他内容 - 例子: - 输入:openEuler是什么操作系统? - 输出: - [ - \"openEuler是一个什么样的操作系统?\", - \"openEuler操作系统的特点是什么?\", - \"openEuler操作系统有哪些功能?\", - \"openEuler操作系统的优势是什么?\" - ] - 下面是给出的问题: - {question} - ' \ No newline at end of file + 输出2:NO + + 下面是给出的陈述和问题: + 陈述:{statement} + 问题:{question} +LLM_PROMPT_TEMPLATE: + en: | + You are an intelligent assistant. Your task is to answer the user's question based on the background information. + Note: + #01 If the answer can be derived from the valid content in the background information, please provide the answer. + #02 Please answer the user's question in English. + + Example: + Background information: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + User's question: What kind of operating system is openEuler? + Output: openEuler is an open source operating system designed to support cloud computing and edge computing. + + Background information: + {bac_info} + 中文: | + 你是一个智能助手你的任务是根据背景信息,回答用户的问题。 + 注意: + #01 如果结合背景信息中的有效内容回答用户的问题,请给出答案。 + #02 请使用中文回答用户的问题 + + 例子: + 背景信息:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 用户的问题:openEuler是什么操作系统? 
+ 输出:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。 + + 背景信息: + {bac_info} \ No newline at end of file diff --git a/data_chain/config/config.py b/data_chain/config/config.py index 4d9678d64259384ccbabace6d8d700fabf555481..1caf9db39888fcc37ebcd4bed9fac77614b9a8f9 100644 --- a/data_chain/config/config.py +++ b/data_chain/config/config.py @@ -4,6 +4,7 @@ import uuid from dotenv import dotenv_values from pydantic import BaseModel, Field from typing import List +from data_chain.entities.enum import EmbeddingType, RerankType class DictBaseModel(BaseModel): @@ -52,10 +53,17 @@ class ConfigModel(DictBaseModel): MAX_TOKENS: int = Field(None, description="最大token数") TEMPERATURE: float = Field(default=0.7, description="温度系数") # Embedding - EMBEDDING_TYPE: str = Field(default="openai", description="embedding 服务的类型") + EMBEDDING_TYPE: EmbeddingType = Field( + default=EmbeddingType.OPENAI, description="embedding 服务的类型") EMBEDDING_API_KEY: str = Field(None, description="embedding服务api key") EMBEDDING_ENDPOINT: str = Field(None, description="embedding服务url地址") EMBEDDING_MODEL_NAME: str = Field(None, description="embedding模型名称") + # Rerank + RERANK_TYPE: RerankType = Field( + default=RerankType.BAILIAN, description="rerank 服务的类型") + RERANK_API_KEY: str = Field(None, description="rerank服务api key") + RERANK_ENDPOINT: str = Field(None, description="rerank服务url地址") + RERANK_MODEL_NAME: str = Field(None, description="rerank模型名称") # Token SESSION_TTL: int = Field(None, description="用户session过期时间") CSRF_KEY: str = Field(None, description="csrf的密钥") @@ -64,13 +72,22 @@ class ConfigModel(DictBaseModel): HALF_KEY2: str = Field(None, description="两层密钥管理组件2") HALF_KEY3: str = Field(None, description="两层密钥管理组件3") # Prompt file - PROMPT_PATH: str = Field(default="./data_chain/common/prompt.yaml", description="prompt路径") + PROMPT_PATH: str = Field( + default="./data_chain/common/prompt.yaml", description="prompt路径") # Stop Words PATH - STOP_WORDS_PATH: str = Field(default="./data_chain/common/stopwords.txt", description="停用词表存放位置") + STOP_WORDS_PATH: str = Field( + default="./data_chain/common/stopwords.txt", description="停用词表存放位置") # CPU Limit USE_CPU_LIMIT: int = Field(default=64, description="文档解析器使用CPU核数") # Task Retry Time limit TASK_RETRY_TIME_LIMIT: int = Field(default=3, description="任务重试次数限制") + # Ocr Method + OCR_METHOD: str = Field( + default="offline", description="ocr识别方式,online or offline") + OCR_API_URL: str = Field( + default="", description="ocr在线识别接口地址", pattern=r'^https?://.+') + # default user sub + DEBUG_USER: str = Field(default="openEuler", description="默认用户标识") class Config: diff --git a/data_chain/embedding/embedding.py b/data_chain/embedding/embedding.py index a0c0e2b27ec91b84872ab9cc345c6e1e8aa1a4ed..6166bce03266df61feeaeb8e2d6aae50d72557c3 100644 --- a/data_chain/embedding/embedding.py +++ b/data_chain/embedding/embedding.py @@ -4,7 +4,7 @@ import json import urllib3 from data_chain.config.config import config from data_chain.logger.logger import logger as logging - +from data_chain.entities.enum import EmbeddingType urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) @@ -12,7 +12,7 @@ class Embedding(): @staticmethod async def vectorize_embedding(text): vector = None - if config['EMBEDDING_TYPE'] == 'openai': + if config['EMBEDDING_TYPE'] == EmbeddingType.OPENAI: headers = { "Authorization": f"Bearer {config['EMBEDDING_API_KEY']}" } @@ -22,7 +22,8 @@ class Embedding(): "encoding_format": "float" } try: - res = requests.post(url=config["EMBEDDING_ENDPOINT"], headers=headers, json=data, 
verify=False) + res = requests.post( + url=config["EMBEDDING_ENDPOINT"], headers=headers, json=data, verify=False) if res.status_code != 200: return None vector = res.json()['data'][0]['embedding'] @@ -30,12 +31,13 @@ class Embedding(): err = f"[Embedding] 向量化失败 ,error: {e}" logging.exception(err) return None - elif config['EMBEDDING_TYPE'] == 'mindie': + elif config['EMBEDDING_TYPE'] == EmbeddingType.MINDIE: try: data = { "inputs": text, } - res = requests.post(url=config["EMBEDDING_ENDPOINT"], json=data, verify=False) + res = requests.post( + url=config["EMBEDDING_ENDPOINT"], json=data, verify=False) if res.status_code != 200: return None vector = json.loads(res.text)[0] diff --git a/data_chain/entities/common.py b/data_chain/entities/common.py index 751184017ae08e4c6cdb3049538c155ecf106912..a9f9097370d30a56a87a7234e62024e9b1424174 100644 --- a/data_chain/entities/common.py +++ b/data_chain/entities/common.py @@ -1,379 +1,371 @@ import uuid +from data_chain.entities.enum import DeafaultRole +from data_chain.entities.enum import LanguageType DEFAULT_DOC_TYPE_ID = uuid.UUID("00000000-0000-0000-0000-000000000000") DEFAULT_KNOWLEDGE_BASE_ID = uuid.UUID("00000000-0000-0000-0000-000000000000") DEFAULt_DOC_TYPE_NAME = "default" actions = [ - {'type': 'team', - 'name': '获取团队用户列表', 'action': 'POST /team/usr'}, - {'type': 'team', - 'name': '获取团队消息列表', 'action': 'POST /team/msg'}, - {'type': 'team', - 'name': '发送团队邀请', 'action': 'POST /team/invitation'}, - {'type': 'team', - 'name': '处理用户申请', 'action': 'PUT /usr_msg'}, - {'type': 'team', - 'name': '更新团队信息', 'action': 'PUT /team'}, - {'type': 'team', - 'name': '更新团队用户角色', 'action': 'PUT /team/usr'}, - {'type': 'team', - 'name': '移交团队', 'action': 'PUT /team/author'}, - {'type': 'team', - 'name': '解散团队', 'action': 'DELETE /team'}, - {'type': 'team', - 'name': '剔除团队用户', 'action': 'DELETE /team/usr'}, - {'type': 'knowledge_base', - 'name': '获取团队下知识库列表', 'action': 'POST /kb/team'}, - {'type': 'knowledge_base', - 'name': '获取知识库文档类型', 'action': 'GET /kb/doc_type'}, - {'type': 'knowledge_base', - 'name': '下载知识库文件', 'action': 'GET /kb/download'}, - {'type': 'knowledge_base', - 'name': '创建知识库', 'action': 'POST /kb'}, - {'type': 'knowledge_base', - 'name': '导入知识库', 'action': 'POST /kb/import'}, - {'type': 'knowledge_base', - 'name': '导出知识库', 'action': 'POST /kb/export'}, - {'type': 'knowledge_base', - 'name': '更新知识库信息', 'action': 'PUT /kb'}, - {'type': 'knowledge_base', - 'name': '删除知识库', 'action': 'DELETE /kb'}, - {'type': 'chunk', - 'name': '获取文档解析结果列表', 'action': 'POST /chunk/list'}, - {'type': 'chunk', - 'name': '检索文档解析结果', 'action': 'POST /chunk/search'}, - {'type': 'chunk', - 'name': '更新文档解析结果', 'action': 'PUT /chunk'}, - {'type': 'document', - 'name': '获取文档列表', 'action': 'POST /doc/list'}, - {'type': 'document', - 'name': '下载文档', 'action': 'GET /doc/download'}, - {'type': 'document', - 'name': '获取文档报告', 'action': 'GET /doc/report'}, - {'type': 'document', - 'name': '下载文档报告', 'action': 'GET /doc/report/download'}, - {'type': 'document', - 'name': '创建文档', 'action': 'POST /doc'}, - {'type': 'document', - 'name': '解析文档', 'action': 'POST /doc/parse'}, - {'type': 'document', - 'name': '更新文档信息', 'action': 'PUT /doc'}, - {'type': 'document', - 'name': '删除文档', 'action': 'DELETE /doc'}, - {'type': 'dataset_data', - 'name': '获取数据集列表', 'action': 'POST /dataset/list'}, - {'type': 'dataset_data', - 'name': '获取测试数据列表', 'action': 'POST /dataset/data'}, - {'type': 'dataset_data', - 'name': '获取测试数据下是否有测试任务', 'action': 'GET /dataset/testing/exist'}, - {'type': 
'dataset_data', - 'name': '下载数据集', 'action': 'GET /dataset/download'}, - {'type': 'dataset_data', - 'name': '新建数据集', 'action': 'POST /dataset'}, - {'type': 'dataset_data', - 'name': '导入数据集', 'action': 'POST /dataset/import'}, - {'type': 'dataset_data', - 'name': '导出数据集', 'action': 'POST /dataset/export'}, - {'type': 'dataset_data', - 'name': '生成数据集', 'action': 'POST /dataset/generate'}, - {'type': 'dataset_data', - 'name': '修改数据集信息', 'action': 'PUT /dataset'}, - {'type': 'dataset_data', - 'name': '修改测试样例', 'action': 'PUT /dataset/data'}, - {'type': 'dataset_data', - 'name': '删除数据集', 'action': 'DELETE /dataset'}, - {'type': 'dataset_data', - 'name': '删除测试样例', 'action': 'DELETE /dataset/data'}, - {'type': 'testing', - 'name': '获取测试列表', 'action': 'POST /testing/list'}, - {'type': 'testing', - 'name': '获取测试用例列表', 'action': 'POST /testing/testcase'}, - {'type': 'testing', - 'name': '下载测试结果', 'action': 'GET /testing/download'}, - {'type': 'testing', - 'name': '创建测试', 'action': 'POST /testing'}, - {'type': 'testing', - 'name': '运行测试', 'action': 'POST /testing/run'}, - {'type': 'testing', - 'name': '更新测试信息', 'action': 'PUT /testing'}, - {'type': 'testing', - 'name': '删除测试', 'action': 'DELETE /testing'}, - {'type': 'role', - 'name': '获取角色操作列表', 'action': 'GET /role/action'}, - {'type': 'role', - 'name': '获取角色列表', 'action': 'POST /role/list'}, - {'type': 'role', - 'name': '创建角色', 'action': 'POST /role'}, - {'type': 'role', - 'name': '更新角色信息', 'action': 'PUT /role'}, - {'type': 'role', - 'name': '删除角色', 'action': 'DELETE /role'}, - {'type': 'task', - 'name': '获取任务列表', 'action': 'POST /task'}, - {'type': 'task', - 'name': '获取任务报告', 'action': 'GET /task/report'}, - {'type': 'task', - 'name': '删除单个任务', 'action': 'DELETE /task/one'}, - {'type': 'task', - 'name': '删除单个任务', 'action': 'DELETE /task/all'}] + {'type': 'team', 'name': {LanguageType.CHINESE: '获取团队用户列表', + LanguageType.ENGLISH: 'Get team user list'}, 'action': 'POST /team/usr'}, + {'type': 'team', 'name': {LanguageType.CHINESE: '获取团队消息列表', + LanguageType.ENGLISH: 'Get team message list'}, 'action': 'POST /team/msg'}, + {'type': 'team', 'name': {LanguageType.CHINESE: '发送团队邀请', + LanguageType.ENGLISH: 'Send team invitation'}, 'action': 'POST /team/invitation'}, + {"type": "team", "name": {LanguageType.CHINESE: '获取用户消息列表', + LanguageType.ENGLISH: 'Get user message list'}, "action": "POST /usr_msg/list"}, + {'type': 'team', 'name': {LanguageType.CHINESE: '更新用户消息', + LanguageType.ENGLISH: 'Update User message'}, 'action': 'PUT /usr_msg'}, + {'type': 'team', 'name': {LanguageType.CHINESE: '更新团队信息', + LanguageType.ENGLISH: 'Update team information'}, 'action': 'PUT /team'}, + {'type': 'team', 'name': {LanguageType.CHINESE: '更新团队用户角色', + LanguageType.ENGLISH: 'Update team user role'}, 'action': 'PUT /team/usr'}, + {'type': 'team', 'name': {LanguageType.CHINESE: '转让团队所有权', + LanguageType.ENGLISH: 'Transfer team ownership'}, 'action': 'PUT /team/author'}, + {'type': 'team', 'name': {LanguageType.CHINESE: '解散团队', + LanguageType.ENGLISH: 'Disband team'}, 'action': 'DELETE /team'}, + {'type': 'team', 'name': {LanguageType.CHINESE: '移除团队用户', + LanguageType.ENGLISH: 'Remove team user'}, 'action': 'DELETE /team/usr'}, + {'type': 'role', 'name': {LanguageType.CHINESE: '获取角色操作列表', + LanguageType.ENGLISH: 'Get role operation list'}, 'action': 'GET /role/action'}, + {'type': 'role', 'name': {LanguageType.CHINESE: '获取角色列表', + LanguageType.ENGLISH: 'Get role list'}, 'action': 'POST /role/list'}, + {'type': 'role', 'name': {LanguageType.CHINESE: '创建角色', + 
LanguageType.ENGLISH: 'Create role'}, 'action': 'POST /role'}, + {'type': 'role', 'name': {LanguageType.CHINESE: '更新角色信息', + LanguageType.ENGLISH: 'Update role information'}, 'action': 'PUT /role'}, + {'type': 'role', 'name': {LanguageType.CHINESE: '删除角色', + LanguageType.ENGLISH: 'Delete role'}, 'action': 'DELETE /role'}, + {'type': 'knowledge_base', 'name': {LanguageType.CHINESE: '获取团队下的知识库列表', + LanguageType.ENGLISH: 'Get knowledge base list under team'}, 'action': 'POST /kb/team'}, + {'type': 'knowledge_base', 'name': {LanguageType.CHINESE: '获取知识库文档类型', + LanguageType.ENGLISH: 'Get knowledge base document types'}, 'action': 'GET /kb/doc_type'}, + {'type': 'knowledge_base', 'name': {LanguageType.CHINESE: '下载知识库文件', + LanguageType.ENGLISH: 'Download knowledge base file'}, 'action': 'GET /kb/download'}, + {'type': 'knowledge_base', 'name': {LanguageType.CHINESE: '创建知识库', + LanguageType.ENGLISH: 'Create knowledge base'}, 'action': 'POST /kb'}, + {'type': 'knowledge_base', 'name': {LanguageType.CHINESE: '导入知识库', + LanguageType.ENGLISH: 'Import knowledge base'}, 'action': 'POST /kb/import'}, + {'type': 'knowledge_base', 'name': {LanguageType.CHINESE: '导出知识库', + LanguageType.ENGLISH: 'Export knowledge base'}, 'action': 'POST /kb/export'}, + {'type': 'knowledge_base', 'name': {LanguageType.CHINESE: '更新知识库信息', + LanguageType.ENGLISH: 'Update knowledge base information'}, 'action': 'PUT /kb'}, + {'type': 'knowledge_base', 'name': {LanguageType.CHINESE: '删除知识库', + LanguageType.ENGLISH: 'Delete knowledge base'}, 'action': 'DELETE /kb'}, + {'type': 'document', 'name': {LanguageType.CHINESE: '获取文档列表', + LanguageType.ENGLISH: 'Get document list'}, 'action': 'POST /doc/list'}, + {'type': 'document', 'name': {LanguageType.CHINESE: '下载文档', + LanguageType.ENGLISH: 'Download document'}, 'action': 'GET /doc/download'}, + {'type': 'document', 'name': {LanguageType.CHINESE: '获取文档报告', + LanguageType.ENGLISH: 'Get document report'}, 'action': 'GET /doc/report'}, + {'type': 'document', 'name': {LanguageType.CHINESE: '下载文档报告', + LanguageType.ENGLISH: 'Download document report'}, 'action': 'GET /doc/report/download'}, + {'type': 'document', 'name': {LanguageType.CHINESE: '创建文档', + LanguageType.ENGLISH: 'Create document'}, 'action': 'POST /doc'}, + {'type': 'document', 'name': {LanguageType.CHINESE: '解析文档', + LanguageType.ENGLISH: 'Parse document'}, 'action': 'POST /doc/parse'}, + {'type': 'document', 'name': {LanguageType.CHINESE: '更新文档信息', + LanguageType.ENGLISH: 'Update document information'}, 'action': 'PUT /doc'}, + {'type': 'document', 'name': {LanguageType.CHINESE: '删除文档', + LanguageType.ENGLISH: 'Delete document'}, 'action': 'DELETE /doc'}, + {'type': 'chunk', 'name': {LanguageType.CHINESE: '获取文档解析结果列表', + LanguageType.ENGLISH: 'Get document parsing result list'}, 'action': 'POST /chunk/list'}, + {'type': 'chunk', 'name': {LanguageType.CHINESE: '检索文档解析结果', + LanguageType.ENGLISH: 'Retrieve document parsing results'}, 'action': 'POST /chunk/search'}, + {'type': 'chunk', 'name': {LanguageType.CHINESE: '更新文档解析结果', + LanguageType.ENGLISH: 'Update document parsing results'}, 'action': 'PUT /chunk'}, + {'type': 'dataset', 'name': {LanguageType.CHINESE: '获取数据集列表', + LanguageType.ENGLISH: 'Get dataset list'}, 'action': 'POST /dataset/list'}, + {'type': 'dataset', 'name': {LanguageType.CHINESE: '获取测试数据列表', + LanguageType.ENGLISH: 'Get test data list'}, 'action': 'POST /dataset/data'}, + {'type': 'dataset', 'name': {LanguageType.CHINESE: '检查测试数据下是否有测试任务', LanguageType.ENGLISH: + 'Check if there are test tasks under 
test data'}, 'action': 'GET /dataset/testing/exist'}, + {'type': 'dataset', 'name': {LanguageType.CHINESE: '下载数据集', + LanguageType.ENGLISH: 'Download dataset'}, 'action': 'GET /dataset/download'}, + {'type': 'dataset', 'name': {LanguageType.CHINESE: '创建新数据集', + LanguageType.ENGLISH: 'Create new dataset'}, 'action': 'POST /dataset'}, + {'type': 'dataset', 'name': {LanguageType.CHINESE: '导入数据集', + LanguageType.ENGLISH: 'Import dataset'}, 'action': 'POST /dataset/import'}, + {'type': 'dataset', 'name': {LanguageType.CHINESE: '导出数据集', + LanguageType.ENGLISH: 'Export dataset'}, 'action': 'POST /dataset/export'}, + {'type': 'dataset', 'name': {LanguageType.CHINESE: '生成数据集', + LanguageType.ENGLISH: 'Generate dataset'}, 'action': 'POST /dataset/generate'}, + {'type': 'dataset', 'name': {LanguageType.CHINESE: '修改数据集信息', + LanguageType.ENGLISH: 'Modify dataset information'}, 'action': 'PUT /dataset'}, + {'type': 'dataset', 'name': {LanguageType.CHINESE: '修改测试用例', + LanguageType.ENGLISH: 'Modify test case'}, 'action': 'PUT /dataset/data'}, + {'type': 'dataset', 'name': {LanguageType.CHINESE: '删除数据集', + LanguageType.ENGLISH: 'Delete dataset'}, 'action': 'DELETE /dataset'}, + {'type': 'dataset', 'name': {LanguageType.CHINESE: '删除测试用例', + LanguageType.ENGLISH: 'Delete test case'}, 'action': 'DELETE /dataset/data'}, + {'type': 'testing', 'name': {LanguageType.CHINESE: '获取测试列表', + LanguageType.ENGLISH: 'Get test list'}, 'action': 'POST /testing/list'}, + {'type': 'testing', 'name': {LanguageType.CHINESE: '获取测试用例列表', + LanguageType.ENGLISH: 'Get test case list'}, 'action': 'POST /testing/testcase'}, + {'type': 'testing', 'name': {LanguageType.CHINESE: '下载测试结果', + LanguageType.ENGLISH: 'Download test results'}, 'action': 'GET /testing/download'}, + {'type': 'testing', 'name': {LanguageType.CHINESE: '创建测试', + LanguageType.ENGLISH: 'Create test'}, 'action': 'POST /testing'}, + {'type': 'testing', 'name': {LanguageType.CHINESE: '运行测试', + LanguageType.ENGLISH: 'Run test'}, 'action': 'POST /testing/run'}, + {'type': 'testing', 'name': {LanguageType.CHINESE: '更新测试信息', + LanguageType.ENGLISH: 'Update test information'}, 'action': 'PUT /testing'}, + {'type': 'testing', 'name': {LanguageType.CHINESE: '删除测试', + LanguageType.ENGLISH: 'Delete test'}, 'action': 'DELETE /testing'}, + {'type': 'task', 'name': {LanguageType.CHINESE: '获取任务列表', + LanguageType.ENGLISH: 'Get task list'}, 'action': 'POST /task'}, + {'type': 'task', 'name': {LanguageType.CHINESE: '获取任务报告', + LanguageType.ENGLISH: 'Get task report'}, 'action': 'GET /task/report'}, + {'type': 'task', 'name': {LanguageType.CHINESE: '删除单个任务', + LanguageType.ENGLISH: 'Delete single task'}, 'action': 'DELETE /task/one'}, + {'type': 'task', 'name': {LanguageType.CHINESE: '删除所有任务', + LanguageType.ENGLISH: 'Delete all tasks'}, 'action': 'DELETE /task/all'} +] default_roles = [ { "id": uuid.UUID("00000000-0000-0000-0000-000000000001"), - "name": "创建者", + "name": DeafaultRole.CREATOR.value, # 角色名称:创建者 → Creator(更符合行业通用表述) "is_unique": True, "actions": [ - {'type': 'team', - 'name': '获取团队用户列表', 'action': 'POST /team/usr'}, - {'type': 'team', - 'name': '获取团队消息列表', 'action': 'POST /team/msg'}, - {'type': 'team', - 'name': '发送团队邀请', 'action': 'POST /team/invitation'}, - {'type': 'team', - 'name': '处理用户申请', 'action': 'PUT /usr_msg'}, - {'type': 'team', - 'name': '更新团队信息', 'action': 'PUT /team'}, - {'type': 'team', - 'name': '更新团队用户角色', 'action': 'PUT /team/usr'}, - {'type': 'team', - 'name': '移交团队', 'action': 'PUT /team/author'}, - {'type': 'team', - 'name': '解散团队', 'action': 
'DELETE /team'}, - {'type': 'team', - 'name': '剔除团队用户', 'action': 'DELETE /team/usr'}, - {'type': 'knowledge_base', - 'name': '获取团队下知识库列表', 'action': 'POST /kb/team'}, - {'type': 'knowledge_base', - 'name': '获取知识库文档类型', 'action': 'GET /kb/doc_type'}, - {'type': 'knowledge_base', - 'name': '下载知识库文件', 'action': 'GET /kb/download'}, - {'type': 'knowledge_base', - 'name': '创建知识库', 'action': 'POST /kb'}, + {'type': 'team', 'name': 'Get Team User List', + 'action': 'POST /team/usr'}, + {'type': 'team', 'name': 'Get Team Message List', + 'action': 'POST /team/msg'}, + {'type': 'team', 'name': 'Send Team Invitation', + 'action': 'POST /team/invitation'}, + {"type": "team", "name": "Get user message list", + "action": "POST /usr_msg/list"}, + {'type': 'team', 'name': 'Update User message', + 'action': 'PUT /usr_msg'}, + {'type': 'team', 'name': 'Update Team Information', 'action': 'PUT /team'}, + {'type': 'team', 'name': 'Update Team User Role', + 'action': 'PUT /team/usr'}, + {'type': 'team', 'name': 'Transfer Team Ownership', + 'action': 'PUT /team/author'}, # 移交团队 → 移交团队所有权(更符合权限场景) + {'type': 'team', 'name': 'Disband Team', 'action': 'DELETE /team'}, + {'type': 'team', 'name': 'Remove Team User', + 'action': 'DELETE /team/usr'}, + {'type': 'knowledge_base', 'name': 'Get Knowledge Base List Under Team', + 'action': 'POST /kb/team'}, + {'type': 'knowledge_base', 'name': 'Get Knowledge Base Document Types', + 'action': 'GET /kb/doc_type'}, + {'type': 'knowledge_base', 'name': 'Download Knowledge Base File', + 'action': 'GET /kb/download'}, {'type': 'knowledge_base', - 'name': '导入知识库', 'action': 'POST /kb/import'}, + 'name': 'Create Knowledge Base', 'action': 'POST /kb'}, + {'type': 'knowledge_base', 'name': 'Import Knowledge Base', + 'action': 'POST /kb/import'}, + {'type': 'knowledge_base', 'name': 'Export Knowledge Base', + 'action': 'POST /kb/export'}, {'type': 'knowledge_base', - 'name': '导出知识库', 'action': 'POST /kb/export'}, - {'type': 'knowledge_base', - 'name': '更新知识库信息', 'action': 'PUT /kb'}, - {'type': 'knowledge_base', - 'name': '删除知识库', 'action': 'DELETE /kb'}, - {'type': 'chunk', - 'name': '获取文档解析结果列表', 'action': 'POST /chunk/list'}, - {'type': 'chunk', - 'name': '检索文档解析结果', 'action': 'POST /chunk/search'}, - {'type': 'chunk', - 'name': '更新文档解析结果', 'action': 'PUT /chunk'}, - {'type': 'document', - 'name': '获取文档列表', 'action': 'POST /doc/list'}, - {'type': 'document', - 'name': '下载文档', 'action': 'GET /doc/download'}, - {'type': 'document', - 'name': '获取文档报告', 'action': 'GET /doc/report'}, - {'type': 'document', - 'name': '下载文档报告', 'action': 'GET /doc/report/download'}, - {'type': 'document', - 'name': '创建文档', 'action': 'POST /doc'}, - {'type': 'document', - 'name': '解析文档', 'action': 'POST /doc/parse'}, - {'type': 'document', - 'name': '更新文档信息', 'action': 'PUT /doc'}, - {'type': 'document', - 'name': '删除文档', 'action': 'DELETE /doc'}, - {'type': 'dataset_data', - 'name': '获取数据集列表', 'action': 'POST /dataset/list'}, - {'type': 'dataset_data', - 'name': '获取测试数据列表', 'action': 'POST /dataset/data'}, - {'type': 'dataset_data', - 'name': '获取测试数据下是否有测试任务', 'action': 'GET /dataset/testing/exist'}, - {'type': 'dataset_data', - 'name': '下载数据集', 'action': 'GET /dataset/download'}, - {'type': 'dataset_data', - 'name': '新建数据集', 'action': 'POST /dataset'}, - {'type': 'dataset_data', - 'name': '导入数据集', 'action': 'POST /dataset/import'}, - {'type': 'dataset_data', - 'name': '导出数据集', 'action': 'POST /dataset/export'}, - {'type': 'dataset_data', - 'name': '生成数据集', 'action': 'POST /dataset/generate'}, - 
{'type': 'dataset_data', - 'name': '修改数据集信息', 'action': 'PUT /dataset'}, - {'type': 'dataset_data', - 'name': '修改测试样例', 'action': 'PUT /dataset/data'}, - {'type': 'dataset_data', - 'name': '删除数据集', 'action': 'DELETE /dataset'}, - {'type': 'dataset_data', - 'name': '删除测试样例', 'action': 'DELETE /dataset/data'}, - {'type': 'testing', - 'name': '获取测试列表', 'action': 'POST /testing/list'}, - {'type': 'testing', - 'name': '获取测试用例列表', 'action': 'POST /testing/testcase'}, - {'type': 'testing', - 'name': '下载测试结果', 'action': 'GET /testing/download'}, - {'type': 'testing', - 'name': '创建测试', 'action': 'POST /testing'}, - {'type': 'testing', - 'name': '运行测试', 'action': 'POST /testing/run'}, - {'type': 'testing', - 'name': '更新测试信息', 'action': 'PUT /testing'}, - {'type': 'testing', - 'name': '删除测试', 'action': 'DELETE /testing'}, - {'type': 'role', - 'name': '获取角色操作列表', 'action': 'GET /role/action'}, - {'type': 'role', - 'name': '获取角色列表', 'action': 'POST /role/list'}, - {'type': 'role', - 'name': '创建角色', 'action': 'POST /role'}, - {'type': 'role', - 'name': '更新角色信息', 'action': 'PUT /role'}, - {'type': 'role', - 'name': '删除角色', 'action': 'DELETE /role'}, - {'type': 'task', - 'name': '获取任务列表', 'action': 'POST /task'}, - {'type': 'task', - 'name': '获取任务报告', 'action': 'GET /task/report'}, - {'type': 'task', - 'name': '删除单个任务', 'action': 'DELETE /task/one'}, - {'type': 'task', - 'name': '删除单个任务', 'action': 'DELETE /task/all'}], - "editable": False, - }, { + 'name': 'Update Knowledge Base Information', 'action': 'PUT /kb'}, + {'type': 'knowledge_base', 'name': 'Delete Knowledge Base', + 'action': 'DELETE /kb'}, + {'type': 'chunk', 'name': 'Get Document Parsing Result List', + 'action': 'POST /chunk/list'}, + {'type': 'chunk', 'name': 'Retrieve Document Parsing Results', + 'action': 'POST /chunk/search'}, # 检索 → Retrieve(技术场景常用) + {'type': 'chunk', 'name': 'Update Document Parsing Results', + 'action': 'PUT /chunk'}, + {'type': 'document', 'name': 'Get Document List', + 'action': 'POST /doc/list'}, + {'type': 'document', 'name': 'Download Document', + 'action': 'GET /doc/download'}, + {'type': 'document', 'name': 'Get Document Report', + 'action': 'GET /doc/report'}, + {'type': 'document', 'name': 'Download Document Report', + 'action': 'GET /doc/report/download'}, + {'type': 'document', 'name': 'Create Document', 'action': 'POST /doc'}, + {'type': 'document', 'name': 'Parse Document', + 'action': 'POST /doc/parse'}, + {'type': 'document', 'name': 'Update Document Information', + 'action': 'PUT /doc'}, + {'type': 'document', 'name': 'Delete Document', 'action': 'DELETE /doc'}, + {'type': 'dataset', 'name': 'Get Dataset List', + 'action': 'POST /dataset/list'}, + {'type': 'dataset', 'name': 'Get Test Data List', + 'action': 'POST /dataset/data'}, + {'type': 'dataset', 'name': 'Check for Test Tasks Under Test Data', + 'action': 'GET /dataset/testing/exist'}, # 更简洁的"检查是否存在"表述 + {'type': 'dataset', 'name': 'Download Dataset', + 'action': 'GET /dataset/download'}, + {'type': 'dataset', 'name': 'Create New Dataset', + 'action': 'POST /dataset'}, + {'type': 'dataset', 'name': 'Import Dataset', + 'action': 'POST /dataset/import'}, + {'type': 'dataset', 'name': 'Export Dataset', + 'action': 'POST /dataset/export'}, + {'type': 'dataset', 'name': 'Generate Dataset', + 'action': 'POST /dataset/generate'}, + {'type': 'dataset', 'name': 'Modify Dataset Information', + 'action': 'PUT /dataset'}, + {'type': 'dataset', 'name': 'Modify Test Case', + 'action': 'PUT /dataset/data'}, # 测试样例 → Test Case(行业通用术语) + {'type': 'dataset', 
'name': 'Delete Dataset', + 'action': 'DELETE /dataset'}, + {'type': 'dataset', 'name': 'Delete Test Case', + 'action': 'DELETE /dataset/data'}, + {'type': 'testing', 'name': 'Get Test List', + 'action': 'POST /testing/list'}, + {'type': 'testing', 'name': 'Get Test Case List', + 'action': 'POST /testing/testcase'}, + {'type': 'testing', 'name': 'Download Test Results', + 'action': 'GET /testing/download'}, + {'type': 'testing', 'name': 'Create Test', 'action': 'POST /testing'}, + {'type': 'testing', 'name': 'Run Test', 'action': 'POST /testing/run'}, + {'type': 'testing', 'name': 'Update Test Information', + 'action': 'PUT /testing'}, + {'type': 'testing', 'name': 'Delete Test', 'action': 'DELETE /testing'}, + {'type': 'role', 'name': 'Get Role Operation List', + 'action': 'GET /role/action'}, + {'type': 'role', 'name': 'Get Role List', 'action': 'POST /role/list'}, + {'type': 'role', 'name': 'Create Role', 'action': 'POST /role'}, + {'type': 'role', 'name': 'Update Role Information', 'action': 'PUT /role'}, + {'type': 'role', 'name': 'Delete Role', 'action': 'DELETE /role'}, + {'type': 'task', 'name': 'Get Task List', 'action': 'POST /task'}, + {'type': 'task', 'name': 'Get Task Report', + 'action': 'GET /task/report'}, + {'type': 'task', 'name': 'Delete Single Task', + 'action': 'DELETE /task/one'}, + {'type': 'task', 'name': 'Delete All Tasks', + 'action': 'DELETE /task/all'} # 修正原中文表述,匹配action路径 + ], + "editable": False + }, + { "id": uuid.UUID("00000000-0000-0000-0000-000000000002"), - "name": "管理员", + "name": DeafaultRole.ADMINISTRATOR.value, # 角色名称:管理员 → Administrator(更符合行业通用表述) "is_unique": False, "actions": [ - {'type': 'team', - 'name': '获取团队用户列表', 'action': 'POST /team/usr'}, - {'type': 'team', - 'name': '获取团队消息列表', 'action': 'POST /team/msg'}, - {'type': 'team', - 'name': '发送团队邀请', 'action': 'POST /team/invitation'}, - {'type': 'team', - 'name': '处理用户申请', 'action': 'PUT /usr_msg'}, - {'type': 'team', - 'name': '更新团队用户角色', 'action': 'PUT /team/usr'}, - {'type': 'team', - 'name': '剔除团队用户', 'action': 'DELETE /team/usr'}, + {'type': 'team', 'name': 'Get Team User List', + 'action': 'POST /team/usr'}, + {'type': 'team', 'name': 'Get Team Message List', + 'action': 'POST /team/msg'}, + {'type': 'team', 'name': 'Send Team Invitation', + 'action': 'POST /team/invitation'}, + {"type": "team", "name": "Get user message list", + "action": "POST /usr_msg/list"}, + {'type': 'team', 'name': 'Update User message', + 'action': 'PUT /usr_msg'}, + {'type': 'team', 'name': 'Update Team User Role', + 'action': 'PUT /team/usr'}, + {'type': 'team', 'name': 'Remove Team User', + 'action': 'DELETE /team/usr'}, + {'type': 'knowledge_base', 'name': 'Get Knowledge Base List Under Team', + 'action': 'POST /kb/team'}, + {'type': 'knowledge_base', 'name': 'Get Knowledge Base Document Types', + 'action': 'GET /kb/doc_type'}, + {'type': 'knowledge_base', 'name': 'Download Knowledge Base File', + 'action': 'GET /kb/download'}, {'type': 'knowledge_base', - 'name': '获取团队下知识库列表', 'action': 'POST /kb/team'}, + 'name': 'Create Knowledge Base', 'action': 'POST /kb'}, + {'type': 'knowledge_base', 'name': 'Import Knowledge Base', + 'action': 'POST /kb/import'}, + {'type': 'knowledge_base', 'name': 'Export Knowledge Base', + 'action': 'POST /kb/export'}, {'type': 'knowledge_base', - 'name': '获取知识库文档类型', 'action': 'GET /kb/doc_type'}, - {'type': 'knowledge_base', - 'name': '下载知识库文件', 'action': 'GET /kb/download'}, - {'type': 'knowledge_base', - 'name': '创建知识库', 'action': 'POST /kb'}, - {'type': 'knowledge_base', - 
'name': '导入知识库', 'action': 'POST /kb/import'}, - {'type': 'knowledge_base', - 'name': '导出知识库', 'action': 'POST /kb/export'}, - {'type': 'knowledge_base', - 'name': '更新知识库信息', 'action': 'PUT /kb'}, - {'type': 'knowledge_base', - 'name': '删除知识库', 'action': 'DELETE /kb'}, - {'type': 'chunk', - 'name': '获取文档解析结果列表', 'action': 'POST /chunk/list'}, - {'type': 'chunk', - 'name': '检索文档解析结果', 'action': 'POST /chunk/search'}, - {'type': 'chunk', - 'name': '更新文档解析结果', 'action': 'PUT /chunk'}, - {'type': 'document', - 'name': '获取文档列表', 'action': 'POST /doc/list'}, - {'type': 'document', - 'name': '下载文档', 'action': 'GET /doc/download'}, - {'type': 'document', - 'name': '获取文档报告', 'action': 'GET /doc/report'}, - {'type': 'document', - 'name': '下载文档报告', 'action': 'GET /doc/report/download'}, - {'type': 'document', - 'name': '创建文档', 'action': 'POST /doc'}, - {'type': 'document', - 'name': '解析文档', 'action': 'POST /doc/parse'}, - {'type': 'document', - 'name': '更新文档信息', 'action': 'PUT /doc'}, - {'type': 'document', - 'name': '删除文档', 'action': 'DELETE /doc'}, - {'type': 'dataset_data', - 'name': '获取数据集列表', 'action': 'POST /dataset/list'}, - {'type': 'dataset_data', - 'name': '获取测试数据列表', 'action': 'POST /dataset/data'}, - {'type': 'dataset_data', - 'name': '获取测试数据下是否有测试任务', 'action': 'GET /dataset/testing/exist'}, - {'type': 'dataset_data', - 'name': '下载数据集', 'action': 'GET /dataset/download'}, - {'type': 'dataset_data', - 'name': '新建数据集', 'action': 'POST /dataset'}, - {'type': 'dataset_data', - 'name': '导入数据集', 'action': 'POST /dataset/import'}, - {'type': 'dataset_data', - 'name': '导出数据集', 'action': 'POST /dataset/export'}, - {'type': 'dataset_data', - 'name': '生成数据集', 'action': 'POST /dataset/generate'}, - {'type': 'dataset_data', - 'name': '修改数据集信息', 'action': 'PUT /dataset'}, - {'type': 'dataset_data', - 'name': '修改测试样例', 'action': 'PUT /dataset/data'}, - {'type': 'dataset_data', - 'name': '删除数据集', 'action': 'DELETE /dataset'}, - {'type': 'dataset_data', - 'name': '删除测试样例', 'action': 'DELETE /dataset/data'}, - {'type': 'testing', - 'name': '获取测试列表', 'action': 'POST /testing/list'}, - {'type': 'testing', - 'name': '获取测试用例列表', 'action': 'POST /testing/testcase'}, - {'type': 'testing', - 'name': '下载测试结果', 'action': 'GET /testing/download'}, - {'type': 'testing', - 'name': '创建测试', 'action': 'POST /testing'}, - {'type': 'testing', - 'name': '运行测试', 'action': 'POST /testing/run'}, - {'type': 'testing', - 'name': '更新测试信息', 'action': 'PUT /testing'}, - {'type': 'testing', - 'name': '删除测试', 'action': 'DELETE /testing'}, - {'type': 'role', - 'name': '获取角色操作列表', 'action': 'GET /role/action'}, - {'type': 'role', - 'name': '获取角色列表', 'action': 'POST /role/list'}, - {'type': 'task', - 'name': '获取任务列表', 'action': 'POST /task'}, - {'type': 'task', - 'name': '获取任务报告', 'action': 'GET /task/report'}, - {'type': 'task', - 'name': '删除单个任务', 'action': 'DELETE /task/one'}, - {'type': 'task', - 'name': '删除单个任务', 'action': 'DELETE /task/all'}], - "editable": False, - }, { + 'name': 'Update Knowledge Base Information', 'action': 'PUT /kb'}, + {'type': 'knowledge_base', 'name': 'Delete Knowledge Base', + 'action': 'DELETE /kb'}, + {'type': 'chunk', 'name': 'Get Document Parsing Result List', + 'action': 'POST /chunk/list'}, + {'type': 'chunk', 'name': 'Retrieve Document Parsing Results', + 'action': 'POST /chunk/search'}, + {'type': 'chunk', 'name': 'Update Document Parsing Results', + 'action': 'PUT /chunk'}, + {'type': 'document', 'name': 'Get Document List', + 'action': 'POST /doc/list'}, + {'type': 'document', 'name': 'Download 
Document', + 'action': 'GET /doc/download'}, + {'type': 'document', 'name': 'Get Document Report', + 'action': 'GET /doc/report'}, + {'type': 'document', 'name': 'Download Document Report', + 'action': 'GET /doc/report/download'}, + {'type': 'document', 'name': 'Create Document', 'action': 'POST /doc'}, + {'type': 'document', 'name': 'Parse Document', + 'action': 'POST /doc/parse'}, + {'type': 'document', 'name': 'Update Document Information', + 'action': 'PUT /doc'}, + {'type': 'document', 'name': 'Delete Document', 'action': 'DELETE /doc'}, + {'type': 'dataset', 'name': 'Get Dataset List', + 'action': 'POST /dataset/list'}, + {'type': 'dataset', 'name': 'Get Test Data List', + 'action': 'POST /dataset/data'}, + {'type': 'dataset', 'name': 'Check for Test Tasks Under Test Data', + 'action': 'GET /dataset/testing/exist'}, + {'type': 'dataset', 'name': 'Download Dataset', + 'action': 'GET /dataset/download'}, + {'type': 'dataset', 'name': 'Create New Dataset', + 'action': 'POST /dataset'}, + {'type': 'dataset', 'name': 'Import Dataset', + 'action': 'POST /dataset/import'}, + {'type': 'dataset', 'name': 'Export Dataset', + 'action': 'POST /dataset/export'}, + {'type': 'dataset', 'name': 'Generate Dataset', + 'action': 'POST /dataset/generate'}, + {'type': 'dataset', 'name': 'Modify Dataset Information', + 'action': 'PUT /dataset'}, + {'type': 'dataset', 'name': 'Modify Test Case', + 'action': 'PUT /dataset/data'}, + {'type': 'dataset', 'name': 'Delete Dataset', + 'action': 'DELETE /dataset'}, + {'type': 'dataset', 'name': 'Delete Test Case', + 'action': 'DELETE /dataset/data'}, + {'type': 'testing', 'name': 'Get Test List', + 'action': 'POST /testing/list'}, + {'type': 'testing', 'name': 'Get Test Case List', + 'action': 'POST /testing/testcase'}, + {'type': 'testing', 'name': 'Download Test Results', + 'action': 'GET /testing/download'}, + {'type': 'testing', 'name': 'Create Test', 'action': 'POST /testing'}, + {'type': 'testing', 'name': 'Run Test', 'action': 'POST /testing/run'}, + {'type': 'testing', 'name': 'Update Test Information', + 'action': 'PUT /testing'}, + {'type': 'testing', 'name': 'Delete Test', 'action': 'DELETE /testing'}, + {'type': 'role', 'name': 'Get Role Operation List', + 'action': 'GET /role/action'}, + {'type': 'role', 'name': 'Get Role List', 'action': 'POST /role/list'}, + {'type': 'task', 'name': 'Get Task List', 'action': 'POST /task'}, + {'type': 'task', 'name': 'Get Task Report', + 'action': 'GET /task/report'}, + {'type': 'task', 'name': 'Delete Single Task', + 'action': 'DELETE /task/one'}, + {'type': 'task', 'name': 'Delete All Tasks', + 'action': 'DELETE /task/all'} + ], + "editable": False + }, + { "id": uuid.UUID("00000000-0000-0000-0000-000000000003"), - "name": "成员", + "name": DeafaultRole.MEMBER.value, # 角色名称:成员 → Member(更符合行业通用表述) "is_unique": False, - "actions": - [ - {'type': 'team', - 'name': '获取团队用户列表', 'action': 'POST /team/usr'}, - {'type': 'team', - 'name': '获取团队消息列表', 'action': 'POST /team/msg'}, - {'type': 'knowledge_base', - 'name': '获取团队下知识库列表', 'action': 'POST /kb/team'}, - {'type': 'knowledge_base', - 'name': '获取知识库文档类型', 'action': 'GET /kb/doc_type'}, - {'type': 'chunk', - 'name': '获取文档解析结果列表', 'action': 'POST /chunk/list'}, - {'type': 'document', - 'name': '获取文档列表', 'action': 'POST /doc/list'}, - {'type': 'document', - 'name': '下载文档', 'action': 'GET /doc/download'}, - {'type': 'dataset_data', - 'name': '获取数据集列表', 'action': 'POST /dataset/list'}, - {'type': 'dataset_data', - 'name': '获取测试数据列表', 'action': 'POST /dataset/data'}, - 
{'type': 'testing', - 'name': '获取测试列表', 'action': 'POST /testing/list'}, - {'type': 'testing', - 'name': '获取测试用例列表', 'action': 'POST /testing/testcase'}, - {'type': 'role', - 'name': '获取角色操作列表', 'action': 'GET /role/action'}, - {'type': 'role', - 'name': '获取角色列表', 'action': 'POST /role/list'} + "actions": [ + {'type': 'team', 'name': 'Get Team User List', + 'action': 'POST /team/usr'}, + {'type': 'team', 'name': 'Get Team Message List', + 'action': 'POST /team/msg'}, + {'type': 'knowledge_base', 'name': 'Get Knowledge Base List Under Team', + 'action': 'POST /kb/team'}, + {'type': 'knowledge_base', 'name': 'Get Knowledge Base Document Types', + 'action': 'GET /kb/doc_type'}, + {'type': 'chunk', 'name': 'Get Document Parsing Result List', + 'action': 'POST /chunk/list'}, + {'type': 'document', 'name': 'Get Document List', + 'action': 'POST /doc/list'}, + {'type': 'document', 'name': 'Download Document', + 'action': 'GET /doc/download'}, + {'type': 'dataset', 'name': 'Get Dataset List', + 'action': 'POST /dataset/list'}, + {'type': 'dataset', 'name': 'Get Test Data List', + 'action': 'POST /dataset/data'}, + {'type': 'testing', 'name': 'Get Test List', + 'action': 'POST /testing/list'}, + {'type': 'testing', 'name': 'Get Test Case List', + 'action': 'POST /testing/testcase'}, + {'type': 'role', 'name': 'Get Role Operation List', + 'action': 'GET /role/action'}, + {'type': 'role', 'name': 'Get Role List', 'action': 'POST /role/list'} ], - "editable": False, + "editable": False } ] diff --git a/data_chain/entities/enum.py b/data_chain/entities/enum.py index 90d7c702f7be21caca42e5e8f2b18c2b2bb2441b..3b2e3d0d1b990b6a46ce4761f5dedd5818b47251 100644 --- a/data_chain/entities/enum.py +++ b/data_chain/entities/enum.py @@ -8,6 +8,20 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
from enum import Enum +class EmbeddingType(str, Enum): + """embedding 服务的类型""" + OPENAI = "openai" + MINDIE = "mindie" + + +class RerankType(str, Enum): + """rerank 服务的类型""" + BAILIAN = "bailian" + GUIJILIUDONG = "guijiliudong" + VLLM = "vllm" + ASSECEND = "assecend" + + class TeamType(str, Enum): """团队类型""" MYCREATED = "mycreated" @@ -21,12 +35,49 @@ class TeamStatus(str, Enum): DELETED = "deleted" +class TeamMessageStatus(str, Enum): + """团队消息状态""" + EXISTED = "existed" + DELETED = "deleted" + + +class TeamUserStaus(str, Enum): + """团队用户状态""" + EXISTED = "existed" + DELETED = "deleted" + + +class RoleStatus(str, Enum): + """角色状态""" + EXISTED = "existed" + DELETED = "deleted" + + +class RoleActionStatus(str, Enum): + """角色操作状态""" + EXISTED = "existed" + DELETED = "deleted" + + +class UserRoleStatus(str, Enum): + """用户角色状态""" + EXISTED = "existed" + DELETED = "deleted" + + +class DeafaultRole(str, Enum): + """默认角色""" + CREATOR = "creator" + ADMINISTRATOR = "administrator" + MEMBER = "member" + + class Tokenizer(str, Enum): """分词器""" ZH = "中文" EN = "en" - MIX = "mix" + # MIX = "mix" class Embedding(str, Enum): @@ -53,8 +104,8 @@ class UserStatus(str, Enum): class UserMessageType(str, Enum): """用户消息类型""" - Invitation = "invitation" - Application = "application" + INVITATION = "invitation" + APPLICATION = "application" class UserMessageStatus(str, Enum): @@ -62,6 +113,7 @@ class UserMessageStatus(str, Enum): UNREAD = "unread" ACCEPTED = "accepted" REJECTED = "rejected" + DELETED = "deleted" class KnowledgeBaseStatus(str, Enum): @@ -203,6 +255,29 @@ class ActionType(str, Enum): DOCUMENT = "document" CHUNK = "chunk" DATASET = "dataset" + DATASET_DATA = "dataset_data" TESTING = "testing" + TASK = "task" + + +class IdType(str, Enum): + """ID类型""" + TEAM = "team" + ROLE = "role" + MSG = "msg" + USER = "user" + KNOWLEDGE_BASE = "knowledge_base" + DOCUMENT = "document" + CHUNK = "chunk" + DATASET = "dataset" DATASET_DATA = "dataset_data" + TESTING = "testing" + TEST_CASE = "testing_case" TASK = "task" + + +class LanguageType(str, Enum): + """语言类型""" + + CHINESE = "zh" + ENGLISH = "en" diff --git a/data_chain/entities/request_data.py b/data_chain/entities/request_data.py index 0c0b5825251e4c95f78879e3f9ee4d6c197f5fe1..8c4230971d87d743b8ef355eac6ec8a614642567 100644 --- a/data_chain/entities/request_data.py +++ b/data_chain/entities/request_data.py @@ -23,164 +23,248 @@ from data_chain.entities.enum import ( TaskStatus, OrderType) from data_chain.entities.common import DEFAULT_DOC_TYPE_ID +from data_chain.entities.enum import LanguageType class ListTeamRequest(BaseModel): - team_type: Optional[TeamType] = Field(default=None, description="团队类型", alias="teamType") - team_id: Optional[uuid.UUID] = Field(default=None, description="团队id", alias="teamId") - team_name: Optional[str] = Field(default=None, description="团队名称", alias="teamName") + team_type: Optional[TeamType] = Field( + default=None, description="团队类型", alias="teamType") + team_id: Optional[uuid.UUID] = Field( + default=None, description="团队id", alias="teamId") + team_name: Optional[str] = Field( + default=None, description="团队名称", alias="teamName") page: int = Field(default=1, description="页码") page_size: int = Field(default=40, description="每页数量", alias="pageSize") class ListTeamMsgRequest(BaseModel): - team_id: Optional[uuid.UUID] = Field(default=None, description="团队id", alias="teamId") + team_id: Optional[uuid.UUID] = Field( + default=None, description="团队id", alias="teamId") page: int = Field(default=1, description="页码") page_size: int = 
Field(default=40, description="每页数量", alias="pageSize") class ListTeamUserRequest(BaseModel): team_id: uuid.UUID = Field(description="团队ID", alias="teamId") - user_sub: Optional[str] = Field(default=None, description="用户ID", alias="userSub") - user_name: Optional[str] = Field(default=None, description="用户名", alias="userName") + user_sub: Optional[str] = Field( + default=None, description="用户ID", alias="userSub") + user_name: Optional[str] = Field( + default=None, description="用户名", alias="userName") page: int = Field(default=1, description="页码") page_size: int = Field(default=40, description="每页数量", alias="pageSize") class CreateTeamRequest(BaseModel): - team_name: str = Field(default='这是一个默认的团队名称', min_length=1, max_length=30, alias="teamName") - description: str = Field(default='', max_length=150) + team_name: str = Field(default='这是一个默认的团队名称', + min_length=1, max_length=256, alias="teamName") + description: str = Field(default='', max_length=256) is_public: bool = Field(default=False, alias="isPublic") class UpdateTeamRequest(BaseModel): - team_name: str = Field(default='这是一个默认的团队名称', min_length=1, max_length=30, alias="teamName") - description: str = Field(default='', max_length=150) + team_name: str = Field(default='这是一个默认的团队名称', + min_length=1, max_length=256, alias="teamName") + description: str = Field(default='', max_length=256) is_public: bool = Field(default=False, alias="isPublic") +class DetleteTeamUserRequest(BaseModel): + team_id: uuid.UUID = Field(description="团队ID", alias="teamId") + user_subs: List[str] = Field( + default=[], description="用户ID列表", alias="userSubs") + + +class ListUserMessageRequest(BaseModel): + msg_type: Optional[UserMessageType] = Field( + default=None, description="消息类型", alias="msgType") + page: int = Field(default=1, description="页码") + page_size: int = Field(default=40, description="每页数量", alias="pageSize") + + class DocumentType(BaseModel): doc_type_id: uuid.UUID = Field(description="文档类型的id", alias="docTypeId") - doc_type_name: str = Field(default='这是一个默认的文档类型名称', min_length=1, max_length=20, alias="docTypeName") + doc_type_name: str = Field( + default='这是一个默认的文档类型名称', min_length=1, max_length=256, alias="docTypeName") class ListKnowledgeBaseRequest(BaseModel): team_id: uuid.UUID = Field(description="团队id", alias="teamId") - kb_id: Optional[uuid.UUID] = Field(default=None, description="资产id", alias="kbId") - kb_name: Optional[str] = Field(default=None, description="资产名称", alias="kbName") - author_name: Optional[str] = Field(default=None, description="资产创建者", alias="authorName") + kb_id: Optional[uuid.UUID] = Field( + default=None, description="资产id", alias="kbId") + kb_name: Optional[str] = Field( + default=None, description="资产名称", alias="kbName") + author_name: Optional[str] = Field( + default=None, description="资产创建者", alias="authorName") page: int = Field(default=1, description="页码") page_size: int = Field(default=40, description="每页数量", alias="pageSize") class CreateKnowledgeBaseRequest(BaseModel): - kb_name: str = Field(default='这是一个默认的资产名称', min_length=1, max_length=20, alias="kbName") - description: str = Field(default='', max_length=150) + kb_name: str = Field(default='这是一个默认的资产名称', min_length=1, + max_length=256, alias="kbName") + description: str = Field(default='', max_length=256) tokenizer: Tokenizer = Field(default=Tokenizer.ZH) - embedding_model: str = Field(default='', description="知识库使用的embedding模型", alias="embeddingModel") - default_chunk_size: int = Field(default=512, description="知识库默认文件分块大小", alias="defaultChunkSize", min=128, 
max=2048) + embedding_model: str = Field( + default='', description="知识库使用的embedding模型", alias="embeddingModel") + rerank_model: Optional[str] = Field( + default=None, description="知识库使用的rerank模型", alias="rerankModel") + spearating_characters: Optional[str] = Field( + default=None, description="知识库分块的分隔符", alias="spearatingCharacters") + default_chunk_size: int = Field( + default=512, description="知识库默认文件分块大小", alias="defaultChunkSize", min=128, max=2048) default_parse_method: ParseMethod = Field( default=ParseMethod.GENERAL, description="知识库默认解析方法", alias="defaultParseMethod") - upload_count_limit: int = Field(default=128, description="知识库上传文件数量限制", alias="uploadCountLimit", min=128, max=1024) - upload_size_limit: int = Field(default=512, description="知识库上传文件大小限制", alias="uploadSizeLimit", min=128, max=2048) - doc_types: List[DocumentType] = Field(default=[], description="知识库支持的文档类型", alias="docTypes") + upload_count_limit: int = Field( + default=128, description="知识库上传文件数量限制", alias="uploadCountLimit", min=128, max=1024) + upload_size_limit: int = Field( + default=512, description="知识库上传文件大小限制", alias="uploadSizeLimit", min=128, max=2048) + doc_types: List[DocumentType] = Field( + default=[], description="知识库支持的文档类型", alias="docTypes") class UpdateKnowledgeBaseRequest(BaseModel): - kb_name: str = Field(default='这是一个默认的资产名称', min_length=1, max_length=30, alias="kbName") - description: str = Field(default='', max_length=150) + kb_name: str = Field(default='这是一个默认的资产名称', min_length=1, + max_length=256, alias="kbName") + description: str = Field(default='', max_length=256) tokenizer: Tokenizer = Field(default=Tokenizer.ZH) - default_chunk_size: int = Field(default=512, description="知识库默认文件分块大小", alias="defaultChunkSize", min=128, max=2048) + rerank_model: Optional[str] = Field( + default=None, description="知识库使用的rerank模型", alias="rerankModel") + spearating_characters: Optional[str] = Field( + default=None, description="知识库分块的分隔符", alias="spearatingCharacters") + default_chunk_size: int = Field( + default=512, description="知识库默认文件分块大小", alias="defaultChunkSize", min=128, max=2048) default_parse_method: ParseMethod = Field( default=ParseMethod.GENERAL, description="知识库默认解析方法", alias="defaultParseMethod") - upload_count_limit: int = Field(default=128, description="知识库上传文件数量限制", alias="uploadCountLimit", min=128, max=1024) - upload_size_limit: int = Field(default=512, description="知识库上传文件大小限制", alias="uploadSizeLimit", min=128, max=2048) - doc_types: List[DocumentType] = Field(default=[], description="知识库支持的文档类型", alias="docTypes") + upload_count_limit: int = Field( + default=128, description="知识库上传文件数量限制", alias="uploadCountLimit", min=128, max=1024) + upload_size_limit: int = Field( + default=512, description="知识库上传文件大小限制", alias="uploadSizeLimit", min=128, max=2048) + doc_types: List[DocumentType] = Field( + default=[], description="知识库支持的文档类型", alias="docTypes") class ListDocumentRequest(BaseModel): kb_id: uuid.UUID = Field(description="资产id", alias="kbId") - doc_id: Optional[uuid.UUID] = Field(default=None, description="文档id", alias="docId") - doc_name: Optional[str] = Field(default=None, description="文档名称", alias="docName") - doc_type_ids: Optional[list[uuid.UUID]] = Field(default=None, description="文档类型id", alias="docTypeIds") - parse_status: Optional[list[TaskStatus]] = Field(default=None, description="文档解析状态", alias="parseStatus") - parse_methods: Optional[List[ParseMethod]] = Field(default=None, description="文档解析方法", alias="parseMethods") - author_name: Optional[str] = 
Field(default=None, description="文档创建者", alias="authorName") - created_time_start: Optional[str] = Field(default=None, description="文档创建时间开始", alias="createdTimeStart") - created_time_end: Optional[str] = Field(default=None, description="文档创建时间结束", alias="createdTimeEnd") - created_time_order: OrderType = Field(default=OrderType.DESC, description="文档创建时间排序", alias="createdTimeOrder") - enabled: Optional[bool] = Field(default=None, description="文档是否启用", alias="enabled") + doc_id: Optional[uuid.UUID] = Field( + default=None, description="文档id", alias="docId") + doc_name: Optional[str] = Field( + default=None, description="文档名称", alias="docName") + doc_type_ids: Optional[list[uuid.UUID]] = Field( + default=None, description="文档类型id", alias="docTypeIds") + parse_status: Optional[list[TaskStatus]] = Field( + default=None, description="文档解析状态", alias="parseStatus") + parse_methods: Optional[List[ParseMethod]] = Field( + default=None, description="文档解析方法", alias="parseMethods") + author_name: Optional[str] = Field( + default=None, description="文档创建者", alias="authorName") + created_time_start: Optional[str] = Field( + default=None, description="文档创建时间开始", alias="createdTimeStart") + created_time_end: Optional[str] = Field( + default=None, description="文档创建时间结束", alias="createdTimeEnd") + created_time_order: OrderType = Field( + default=OrderType.DESC, description="文档创建时间排序", alias="createdTimeOrder") + enabled: Optional[bool] = Field( + default=None, description="文档是否启用", alias="enabled") page: int = Field(default=1, description="页码") page_size: int = Field(default=40, description="每页数量", alias="pageSize") class UpdateDocumentRequest(BaseModel): - doc_name: str = Field(default='这是一个默认的文档名称', min_length=1, max_length=150, alias="docName") - doc_type_id: uuid.UUID = Field(default=DEFAULT_DOC_TYPE_ID, description="文档类型的id", alias="docTypeId") + doc_name: str = Field(default='这是一个默认的文档名称', + min_length=1, max_length=256, alias="docName") + doc_type_id: uuid.UUID = Field( + default=DEFAULT_DOC_TYPE_ID, description="文档类型的id", alias="docTypeId") parse_method: ParseMethod = Field( default=ParseMethod.GENERAL, description="知识库默认解析方法", alias="parseMethod") - chunk_size: int = Field(default=512, description="知识库默认文件分块大小", alias="chunkSize", min=128, max=2048) + chunk_size: int = Field( + default=512, description="知识库默认文件分块大小", alias="chunkSize", min=128, max=2048) enabled: bool = Field(default=True, description="文档是否启用") class GetTemporaryDocumentStatusRequest(BaseModel): - ids: List[uuid.UUID] = Field(default=[], description="临时文档id列表", alias="ids") + ids: List[uuid.UUID] = Field( + default=[], description="临时文档id列表", alias="ids") class TemporaryDocument(BaseModel): id: uuid.UUID = Field(description="临时文档id", alias="id") - name: str = Field(default='这是一个默认的临时文档名称', min_length=1, max_length=150, alias="name") + parse_method: ParseMethod = Field( + default=ParseMethod.OCR, description="临时文档解析方法", alias="parseMethod") + name: str = Field(default='这是一个默认的临时文档名称', min_length=1, + max_length=256, alias="name") bucket_name: str = Field(default='default', description="临时文档存储的桶名称") type: str = Field(default='txt', description="临时文档的类型", alias="type") class UploadTemporaryRequest(BaseModel): - document_list: List[TemporaryDocument] = Field(default=[], description="临时文档列表") + document_list: List[TemporaryDocument] = Field( + default=[], description="临时文档列表") class DeleteTemporaryDocumentRequest(BaseModel): - ids: List[uuid.UUID] = Field(default=[], description="临时文档id列表", alias="ids") + ids: List[uuid.UUID] = Field( + 
default=[], description="临时文档id列表", alias="ids") class ListChunkRequest(BaseModel): doc_id: uuid.UUID = Field(description="文档id", alias="docId") - text: Optional[str] = Field(default=None, description="分块文本内容", alias="text") - types: Optional[list[ChunkType]] = Field(default=None, description="分块类型", alias="types") + text: Optional[str] = Field( + default=None, description="分块文本内容", alias="text") + types: Optional[list[ChunkType]] = Field( + default=None, description="分块类型", alias="types") page: int = Field(default=1, description="页码") page_size: int = Field(default=40, description="每页数量", alias="pageSize") class UpdateChunkRequest(BaseModel): - text: Optional[str] = Field(default=None, description="分块文本内容", alias="text") + text: Optional[str] = Field( + default=None, description="分块文本内容", alias="text") enabled: Optional[bool] = Field(default=None, description="分块是否启用") class SearchChunkRequest(BaseModel): - kb_ids: List[uuid.UUID] = Field(default=[], description="资产id", alias="kbIds") + kb_ids: List[uuid.UUID] = Field( + default=[], description="资产id", alias="kbIds") query: str = Field(default='', description="查询内容") top_k: int = Field(default=5, description="返回的结果数量", alias="topK") - doc_ids: Optional[List[uuid.UUID]] = Field(default=None, description="文档id", alias="docIds") - banned_ids: Optional[List[uuid.UUID]] = Field(default=[], description="禁止的分块id", alias="bannedIds") + doc_ids: Optional[List[uuid.UUID]] = Field( + default=None, description="文档id", alias="docIds") + banned_ids: Optional[List[uuid.UUID]] = Field( + default=[], description="禁止的分块id", alias="bannedIds") search_method: SearchMethod = Field(default=SearchMethod.KEYWORD_AND_VECTOR, description="检索方法", alias="searchMethod") - is_related_surrounding: bool = Field(default=True, description="是否关联上下文", alias="isRelatedSurrounding") - is_classify_by_doc: bool = Field(default=False, description="是否按文档分类", alias="isClassifyByDoc") - is_rerank: bool = Field(default=False, description="是否重新排序", alias="isRerank") - is_compress: bool = Field(default=False, description="是否压缩", alias="isCompress") - tokens_limit: int = Field(default=8192, description="token限制", alias="tokensLimit") + is_related_surrounding: bool = Field( + default=True, description="是否关联上下文", alias="isRelatedSurrounding") + is_classify_by_doc: bool = Field( + default=False, description="是否按文档分类", alias="isClassifyByDoc") + is_rerank: bool = Field( + default=False, description="是否重新排序", alias="isRerank") + is_compress: bool = Field( + default=False, description="是否压缩", alias="isCompress") + tokens_limit: int = Field( + default=8192, description="token限制", alias="tokensLimit") class ListDatasetRequest(BaseModel): kb_id: uuid.UUID = Field(description="资产id", alias="kbId") - dataset_id: Optional[uuid.UUID] = Field(default=None, description="数据集id", alias="datasetId") - dataset_name: Optional[str] = Field(default=None, description="数据集名称", alias="datasetName") - data_cnt_order: Optional[OrderType] = Field(default=OrderType.DESC, description="数据集数据数量", alias="dataCnt") - llm_id: Optional[str] = Field(default=None, description="数据集使用的大模型id", alias="llmId") - is_data_cleared: Optional[bool] = Field(default=None, description="数据集是否清洗", alias="isDataCleared") - is_chunk_related: Optional[bool] = Field(default=None, description="数据集是否上下文关联", alias="isChunkRelated") - generate_status: Optional[List[TaskStatus]] = Field(default=None, description="数据集生成状态", alias="generateStatus") - score_order: Optional[OrderType] = Field(default=OrderType.DESC, description="数据集评分的排序方法", 
alias="scoreOrder") - author_name: Optional[str] = Field(default=None, description="数据集创建者", alias="authorName") + dataset_id: Optional[uuid.UUID] = Field( + default=None, description="数据集id", alias="datasetId") + dataset_name: Optional[str] = Field( + default=None, description="数据集名称", alias="datasetName") + data_cnt_order: Optional[OrderType] = Field( + default=OrderType.DESC, description="数据集数据数量", alias="dataCnt") + llm_id: Optional[str] = Field( + default=None, description="数据集使用的大模型id", alias="llmId") + is_data_cleared: Optional[bool] = Field( + default=None, description="数据集是否清洗", alias="isDataCleared") + is_chunk_related: Optional[bool] = Field( + default=None, description="数据集是否上下文关联", alias="isChunkRelated") + generate_status: Optional[List[TaskStatus]] = Field( + default=None, description="数据集生成状态", alias="generateStatus") + score_order: Optional[OrderType] = Field( + default=OrderType.DESC, description="数据集评分的排序方法", alias="scoreOrder") + author_name: Optional[str] = Field( + default=None, description="数据集创建者", alias="authorName") page: int = Field(default=1, description="页码") page_size: int = Field(default=40, description="每页数量", alias="pageSize") @@ -194,37 +278,48 @@ class ListDataInDatasetRequest(BaseModel): class CreateDatasetRequest(BaseModel): kb_id: uuid.UUID = Field(description="资产id", alias="kbId") dataset_name: str = Field(default='这是一个默认的数据集名称', description="测试数据集名称", - min_length=1, max_length=30, alias="datasetName") - description: str = Field(default='', description="测试数据集简介", max_length=200) - document_ids: List[uuid.UUID] = Field(default=[], description="测试数据集关联的文档", alias="documentIds") - data_cnt: int = Field(default=64, alias="dataCnt", description="测试数据集内的数据数量", min=1, max=512) + min_length=1, max_length=256, alias="datasetName") + description: str = Field(default='', description="测试数据集简介", max_length=256) + document_ids: List[uuid.UUID] = Field( + default=[], description="测试数据集关联的文档", alias="documentIds") + data_cnt: int = Field(default=64, alias="dataCnt", + description="测试数据集内的数据数量", min=1, max=512) llm_id: str = Field(description="测试数据集使用的大模型id", alias="llmId") - is_data_cleared: bool = Field(default=False, description="测试数据集是否进行清洗", alias="isDataCleared") - is_chunk_related: bool = Field(default=False, description="测试数据集进行上下文关联", alias="isChunkRelated") + is_data_cleared: bool = Field( + default=False, description="测试数据集是否进行清洗", alias="isDataCleared") + is_chunk_related: bool = Field( + default=False, description="测试数据集进行上下文关联", alias="isChunkRelated") class UpdateDatasetRequest(BaseModel): dataset_name: str = Field(default='这是一个默认的数据集名称', description="测试数据集名称", - min_length=1, max_length=30, alias="datasetName") - description: str = Field(default='', description="测试数据集简介", max_length=200) + min_length=1, max_length=256, alias="datasetName") + description: str = Field(default='', description="测试数据集简介", max_length=256) class UpdateDataRequest(BaseModel): question: str = Field(default='这是一个默认的问题', description="问题", - min_length=1, max_length=200, alias="question") + min_length=1, max_length=256, alias="question") answer: str = Field(default='这是一个默认的答案', description="答案", - min_length=1, max_length=1024, alias="answer") + min_length=1, max_length=4096, alias="answer") class ListTestingRequest(BaseModel): kb_id: uuid.UUID = Field(description="资产id", alias="kbId") - testing_id: Optional[uuid.UUID] = Field(default=None, description="测试id", alias="testingId") - testing_name: Optional[str] = Field(default=None, description="测试名称", alias="testingName") - 
llm_ids: Optional[list[str]] = Field(default=None, description="测试使用的大模型id", alias="llmIds") - search_method: Optional[List[SearchMethod]] = Field(default=None, description="测试使用的检索方法", alias="searchMethod") - run_status: Optional[List[TaskStatus]] = Field(default=None, description="测试运行状态", alias="runStatus") - scores_order: Optional[OrderType] = Field(default=OrderType.DESC, description="测试评分", alias="scoresOrder") - author_name: Optional[str] = Field(default=None, description="测试创建者", alias="authorName") + testing_id: Optional[uuid.UUID] = Field( + default=None, description="测试id", alias="testingId") + testing_name: Optional[str] = Field( + default=None, description="测试名称", alias="testingName") + llm_ids: Optional[list[str]] = Field( + default=None, description="测试使用的大模型id", alias="llmIds") + search_method: Optional[List[SearchMethod]] = Field( + default=None, description="测试使用的检索方法", alias="searchMethod") + run_status: Optional[List[TaskStatus]] = Field( + default=None, description="测试运行状态", alias="runStatus") + scores_order: Optional[OrderType] = Field( + default=OrderType.DESC, description="测试评分", alias="scoresOrder") + author_name: Optional[str] = Field( + default=None, description="测试创建者", alias="authorName") page: int = Field(default=1, description="页码") page_size: int = Field(default=40, description="每页数量", alias="pageSize") @@ -237,8 +332,8 @@ class ListTestCaseRequest(BaseModel): class CreateTestingRequest(BaseModel): testing_name: str = Field(default='这是一个默认的测试名称', description="测试名称", - min_length=1, max_length=30, alias="testingName") - description: str = Field(default='', description="测试简介", max_length=200) + min_length=1, max_length=256, alias="testingName") + description: str = Field(default='', description="测试简介", max_length=256) dataset_id: uuid.UUID = Field(description="测试数据集id", alias="datasetId") llm_id: str = Field(description="测试使用的大模型id", alias="llmId") search_method: SearchMethod = Field(default=SearchMethod.KEYWORD_AND_VECTOR, @@ -248,8 +343,8 @@ class CreateTestingRequest(BaseModel): class UpdateTestingRequest(BaseModel): testing_name: str = Field(default='这是一个默认的测试名称', description="测试名称", - min_length=1, max_length=150, alias="testingName") - description: str = Field(default='', description="测试简介", max_length=200) + min_length=1, max_length=256, alias="testingName") + description: str = Field(default='', description="测试简介", max_length=256) llm_id: str = Field(description="测试使用的大模型id", alias="llmId") search_method: SearchMethod = Field(default=SearchMethod.KEYWORD_AND_VECTOR, description="测试使用的检索方法", alias="searchMethod") @@ -258,32 +353,48 @@ class UpdateTestingRequest(BaseModel): class ListRoleRequest(BaseModel): team_id: uuid.UUID = Field(description="团队id", alias="teamId") - role_id: Optional[uuid.UUID] = Field(default=None, description="角色id", alias="roleId") - role_name: Optional[str] = Field(default=None, description="角色名称", alias="roleName") + role_id: Optional[uuid.UUID] = Field( + default=None, description="角色id", alias="roleId") + role_name: Optional[str] = Field( + default=None, description="角色名称", alias="roleName") + is_editable: Optional[bool] = Field( + default=None, description="是否为编辑模式", alias="isEditable") + language: LanguageType = Field( + default=LanguageType.CHINESE, description="语言类型", alias="language") page: int = Field(default=1, description="页码") page_size: int = Field(default=40, description="每页数量", alias="pageSize") class CreateRoleRequest(BaseModel): - role_name: str = Field(default='这是一个默认的角色名称', min_length=1, max_length=30, 
alias="roleName") - actions: List[str] = Field(default=[], description="角色拥有的操作的列表", alias="actions") + role_name: str = Field(default='这是一个默认的角色名称', + min_length=1, max_length=256, alias="roleName") + actions: List[str] = Field( + default=[], description="角色拥有的操作的列表", alias="actions") class UpdateRoleRequest(BaseModel): - role_name: str = Field(default='这是一个默认的角色名称', min_length=1, max_length=30, alias="roleName") - actions: List[str] = Field(default=[], description="角色拥有的操作的列表", alias="actions") + role_name: Optional[str] = Field(default='这是一个默认的角色名称', + min_length=1, max_length=256, alias="roleName") + actions: Optional[List[str]] = Field( + default=[], description="角色拥有的操作的列表", alias="actions") class ListUserRequest(BaseModel): - user_name: Optional[str] = Field(default=None, description="用户名", alias="userName") + user_sub: Optional[str] = Field( + default=None, description="用户ID", alias="userSub") + user_name: Optional[str] = Field( + default=None, description="用户名", alias="userName") page: int = Field(default=1, description="页码") page_size: int = Field(default=40, description="每页数量", alias="pageSize") class ListTaskRequest(BaseModel): team_id: uuid.UUID = Field(description="团队id", alias="teamId") - task_id: Optional[uuid.UUID] = Field(default=None, description="任务id", alias="taskId") - task_type: Optional[TaskType] = Field(default=None, description="任务类型", alias="taskType") - task_status: Optional[TaskStatus] = Field(default=None, description="任务状态", alias="taskStatus") + task_id: Optional[uuid.UUID] = Field( + default=None, description="任务id", alias="taskId") + task_type: Optional[TaskType] = Field( + default=None, description="任务类型", alias="taskType") + task_status: Optional[TaskStatus] = Field( + default=None, description="任务状态", alias="taskStatus") page: int = Field(default=1, description="页码") page_size: int = Field(default=40, description="每页数量", alias="pageSize") diff --git a/data_chain/entities/response_data.py b/data_chain/entities/response_data.py index 5a3bf03b855be861587efb43f5d144a5007d5895..00d5f168f1b7c440ee183ab978874be6a7cc8ef0 100644 --- a/data_chain/entities/response_data.py +++ b/data_chain/entities/response_data.py @@ -1,6 +1,6 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
-from typing import Any, Optional +from typing import Any, Optional, Union from pydantic import BaseModel, Field import uuid @@ -11,6 +11,7 @@ from data_chain.entities.enum import ( Tokenizer, ParseMethod, UserStatus, + UserMessageStatus, UserMessageType, UserMessageStatus, KnowledgeBaseStatus, @@ -24,6 +25,7 @@ from data_chain.entities.enum import ( TaskType, TaskStatus, OrderType) +from data_chain.parser.parse_result import ParseResult class ResponseData(BaseModel): @@ -37,7 +39,8 @@ class ResponseData(BaseModel): class Team(BaseModel): """团队信息""" team_id: uuid.UUID = Field(description="团队ID", alias="teamId") - team_name: str = Field(min_length=1, max_length=30, description="团队名称", alias="teamName") + team_name: str = Field(min_length=1, max_length=30, + description="团队名称", alias="teamName") description: str = Field(max_length=150, description="团队描述") author_name: str = Field(description="团队创建者的用户ID", alias="authorName") member_cnt: int = Field(description="团队成员数量", alias="memberCount") @@ -59,36 +62,44 @@ class ListTeamResponse(ResponseData): class TeamUser(BaseModel): """团队成员信息""" - user_id: uuid.UUID = Field(description="用户ID", alias="userId") + user_id: str = Field(description="用户ID", alias="userId") user_name: str = Field(description="用户名", alias="userName") role_name: str = Field(description="角色名称", alias="roleName") + is_editable: bool = Field( + default=True, description="是否可编辑", alias="isEditable") class ListTeamUserMsg(BaseModel): """GET /team/usr 数据结构""" total: int = Field(default=0, description="总数") - team_users: list[TeamUser] = Field(default=[], description="团队成员列表", alias="teamUsers") + team_users: list[TeamUser] = Field( + default=[], description="团队成员列表", alias="teamUsers") class ListTeamUserResponse(ResponseData): - result: ListTeamUserMsg = Field(default=ListTeamUserMsg(), description="团队成员列表数据结构") + result: ListTeamUserMsg = Field( + default=ListTeamUserMsg(), description="团队成员列表数据结构") class TeamMsg(BaseModel): """团队信息""" msg_id: uuid.UUID = Field(description="消息ID", alias="msgId") - author_name: str = Field(description="消息发送者的用户名", alias="authorName") - message: str = Field(description="消息内容") + author_name: str = Field(description="消息创建者的用户名", alias="authorName") + zh_msg: str = Field(description="消息内容(中文)", alias="zhMsg") + en_msg: str = Field(description="消息内容(英文)", alias="enMsg") + created_time: str = Field(description="消息创建时间", alias="createdTime") class ListTeamMsgMsg(BaseModel): """GET /team/msg 数据结构""" total: int = Field(default=0, description="总数") - team_msgs: list[TeamMsg] = Field(default=[], description="团队消息列表", alias="teamMsgs") + team_msgs: list[TeamMsg] = Field( + default=[], description="团队消息列表", alias="teamMsgs") class ListTeamMsgResponse(ResponseData): - result: ListTeamMsgMsg = Field(default=ListTeamMsgMsg(), description="团队消息列表数据结构") + result: ListTeamMsgMsg = Field( + default=ListTeamMsgMsg(), description="团队消息列表数据结构") class CreateTeamResponse(ResponseData): @@ -98,12 +109,12 @@ class CreateTeamResponse(ResponseData): class InviteTeamUserResponse(ResponseData): """POST /team/invitation 响应""" - result: Optional[uuid.UUID] = Field(default=None, description="邀请ID") + result: Optional[str] = Field(default=None, description="邀请ID") class JoinTeamResponse(ResponseData): """POST /team/application 响应""" - result: Optional[uuid.UUID] = Field(default=None, description="申请ID") + result: Optional[str] = Field(default=None, description="申请ID") class UpdateTeamResponse(ResponseData): @@ -113,7 +124,7 @@ class UpdateTeamResponse(ResponseData): class 
UpdateTeamUserRoleResponse(ResponseData): """PUT /team/usr 响应""" - result: Optional[uuid.UUID] = Field(default=None, description="团队成员ID") + result: Optional[str] = Field(default=None, description="团队成员ID") class UpdateTeamAuthorResponse(ResponseData): @@ -128,7 +139,7 @@ class DeleteTeamResponse(ResponseData): class DeleteTeamUserResponse(ResponseData): """DELETE /team/usr 响应""" - result: list[uuid.UUID] = Field(default=[], description="团队成员ID列表") + result: list[str] = Field(default=[], description="团队成员ID列表") class DocumentType(BaseModel): @@ -144,37 +155,50 @@ class Knowledgebase(BaseModel): author_name: str = Field(description="知识库创建者的用户名", alias="authorName") tokenizer: Tokenizer = Field(description="分词器", alias="tokenizer") embedding_model: str = Field(description="嵌入模型", alias="embeddingModel") + rerank_model: Optional[str] = Field( + default=None, description="rerank模型", alias="rerankModel") + spearating_characters: Optional[str] = Field(default=None, + description="分隔符", alias="spearatingCharacters") description: str = Field(description="知识库描述", max=150) doc_cnt: int = Field(description="知识库文档数量", alias="docCnt") doc_size: int = Field(description="知识库文档大小", alias="docSize") - upload_count_limit: int = Field(description="知识库单次文件上传数量限制", alias="uploadCountLimit") - upload_size_limit: int = Field(description="知识库单次文件上传大小限制", alias="uploadSizeLimit") - default_parse_method: ParseMethod = Field(description="默认解析方法", alias="defaultParseMethod") - default_chunk_size: int = Field(description="默认分块大小", alias="defaultChunkSize") + upload_count_limit: int = Field( + description="知识库单次文件上传数量限制", alias="uploadCountLimit") + upload_size_limit: int = Field( + description="知识库单次文件上传大小限制", alias="uploadSizeLimit") + default_parse_method: ParseMethod = Field( + description="默认解析方法", alias="defaultParseMethod") + default_chunk_size: int = Field( + description="默认分块大小", alias="defaultChunkSize") created_time: str = Field(description="知识库创建时间", alias="createdTime") - doc_types: list[DocumentType] = Field(default=[], description="知识库文档类型列表", alias="docTypes") + doc_types: list[DocumentType] = Field( + default=[], description="知识库文档类型列表", alias="docTypes") class TeamKnowledgebase(BaseModel): """团队知识库信息""" team_id: uuid.UUID = Field(description="团队ID", alias="teamId") team_name: str = Field(description="团队名称", alias="teamName") - kb_list: list[Knowledgebase] = Field(default=[], description="知识库列表", alias="kbList") + kb_list: list[Knowledgebase] = Field( + default=[], description="知识库列表", alias="kbList") class ListAllKnowledgeBaseMsg(BaseModel): """GET /kb 数据结构""" - team_knowledge_bases: list[TeamKnowledgebase] = Field(default=[], description="团队知识库列表", alias="teamKnowledgebases") + team_knowledge_bases: list[TeamKnowledgebase] = Field( + default=[], description="团队知识库列表", alias="teamKnowledgebases") class ListAllKnowledgeBaseResponse(ResponseData): """GET /kb 响应""" - result: ListAllKnowledgeBaseMsg = Field(default=ListAllKnowledgeBaseMsg(), description="团队知识库列表数据结构") + result: ListAllKnowledgeBaseMsg = Field( + default=ListAllKnowledgeBaseMsg(), description="团队知识库列表数据结构") class ListKnowledgeBaseMsg(BaseModel): total: int = Field(default=0, description="总数") - kb_list: list[Knowledgebase] = Field(default=[], description="知识库列表数据结构", alias="kbList") + kb_list: list[Knowledgebase] = Field( + default=[], description="知识库列表数据结构", alias="kbList") class ListKnowledgeBaseResponse(ResponseData): @@ -195,7 +219,8 @@ class Task(BaseModel): task_status: TaskStatus = Field(description="任务状态", alias="taskStatus") 
task_type: TaskType = Field(description="任务类型", alias="taskType") task_completed: float = Field(description="任务完成度", alias="taskCompleted") - finished_time: Optional[str] = Field(default=None, description="任务完成时间", alias="finishedTime") + finished_time: Optional[str] = Field( + default=None, description="任务完成时间", alias="finishedTime") created_time: str = Field(description="任务创建时间", alias="createdTime") @@ -231,8 +256,10 @@ class Document(BaseModel): doc_type: DocumentType = Field(description="文档类型", alias="docType") chunk_size: int = Field(description="文档分片大小", alias="chunkSize") created_time: str = Field(description="文档创建时间", alias="createdTime") - parse_task: Optional[Task] = Field(default=None, description="文档任务", alias="docTask") - parse_method: ParseMethod = Field(description="文档解析方法", alias="parseMethod") + parse_task: Optional[Task] = Field( + default=None, description="文档任务", alias="docTask") + parse_method: ParseMethod = Field( + description="文档解析方法", alias="parseMethod") enabled: bool = Field(description="文档是否启用", alias="enabled") author_name: str = Field(description="文档创建者的用户名", alias="authorName") status: DocumentStatus = Field(description="文档状态", alias="status") @@ -241,12 +268,14 @@ class Document(BaseModel): class ListDocumentMsg(BaseModel): """GET /doc 数据结构""" total: int = Field(default=0, description="总数") - documents: list[Document] = Field(default=[], description="文档列表", alias="documents") + documents: list[Document] = Field( + default=[], description="文档列表", alias="documents") class ListDocumentResponse(ResponseData): """GET /doc 响应""" - result: ListDocumentMsg = Field(default=ListDocumentMsg(), description="文档列表数据结构") + result: ListDocumentMsg = Field( + default=ListDocumentMsg(), description="文档列表数据结构") class GetDocumentReportResponse(ResponseData): @@ -265,7 +294,8 @@ class DOC_STATUS(BaseModel): class GetTemporaryDocumentStatusResponse(ResponseData): - result: list[DOC_STATUS] = Field(default=[], description="临时文档状态列表", alias="result") + result: list[DOC_STATUS] = Field( + default=[], description="临时文档状态列表", alias="result") class UploadTemporaryDocumentResponse(ResponseData): @@ -273,6 +303,11 @@ class UploadTemporaryDocumentResponse(ResponseData): result: list[uuid.UUID] = Field(default=[], description="临时文档ID列表") +class GetTemporaryDocumentTextResponse(ResponseData): + """GET /doc/temporary/parse_result 响应""" + result: str = Field(default="", description="临时文档解析结果") + + class DeleteTemporaryDocumentResponse(ResponseData): """DELETE /doc/temporary 响应""" result: list[uuid.UUID] = Field(default=[], description="临时文档ID列表") @@ -283,6 +318,12 @@ class ParseDocumentResponse(ResponseData): result: list[uuid.UUID] = Field(default=[], description="文档ID列表") +class ParseDocumentRealTimeResponse(ResponseData): + """POST /doc/parse/realtime 响应""" + result: list[Union[ParseResult, None]] = Field( + default=[], description="文档内容列表") + + class UpdateDocumentResponse(ResponseData): """PUT /doc 响应""" result: uuid.UUID = Field(default=None, description="文档ID") @@ -309,7 +350,8 @@ class ListChunkMsg(BaseModel): class ListChunkResponse(ResponseData): """GET /chunk 响应""" - result: ListChunkMsg = Field(default=ListChunkMsg(), description="分片列表数据结构") + result: ListChunkMsg = Field( + default=ListChunkMsg(), description="分片列表数据结构") class UpdateChunkResponse(ResponseData): @@ -325,21 +367,28 @@ class UpdateChunkEnabledResponse(ResponseData): class DocChunk(BaseModel): """Post /chunk/search 数据结构""" doc_id: uuid.UUID = Field(description="文档ID", alias="docId") - doc_name: str = 
Field(description="文档名称", alias="docName") - doc_abstract: str = Field(default="", description="文档摘要", alias="docAbstract") - doc_extension: str = Field(default="", description="文档扩展名", alias="docExtension") + doc_name: str = Field(default="", description="文档名称", alias="docName") + doc_author: str = Field(default="", description="文档作者", alias="docAuthor") + doc_abstract: str = Field( + default="", description="文档摘要", alias="docAbstract") + doc_extension: str = Field( + default="", description="文档扩展名", alias="docExtension") doc_size: int = Field(default=0, description="文档大小,单位是KB", alias="docSize") + doc_created_at: str = Field( + default="", description="文档创建时间", alias="docCreatedAt") chunks: list[Chunk] = Field(default=[], description="分片列表", alias="chunks") class SearchChunkMsg(BaseModel): """Post /chunk/search 数据结构""" - doc_chunks: list[DocChunk] = Field(default=[], description="文档分片列表", alias="docChunks") + doc_chunks: list[DocChunk] = Field( + default=[], description="文档分片列表", alias="docChunks") class SearchChunkResponse(ResponseData): """POST /chunk/search 响应""" - result: SearchChunkMsg = Field(default=SearchChunkMsg(), description="文档分片列表数据结构") + result: SearchChunkMsg = Field( + default=SearchChunkMsg(), description="文档分片列表数据结构") class LLM(BaseModel): @@ -351,15 +400,22 @@ class LLM(BaseModel): class Dataset(BaseModel): """数据集信息""" dataset_id: uuid.UUID = Field(description="数据集ID", alias="datasetId") - dataset_name: str = Field(description="数据集名称", min=1, max=20, alias="datasetName") + dataset_name: str = Field( + description="数据集名称", min=1, max=20, alias="datasetName") description: str = Field(description="数据集描述", max=150) data_cnt: int = Field(description="数据集条目限制", alias="dataCnt") - data_cnt_existed: int = Field(default=0, description="数据集实际条目", alias="dataCntExisted") - is_data_cleared: bool = Field(default=False, description="数据集是否进行清洗", alias="isDataCleared") - is_chunk_related: bool = Field(default=False, description="数据集进行上下文关联", alias="isChunkRelated") - is_imported: bool = Field(default=False, description="数据集是否导入", alias="isImported") - llm: Optional[LLM] = Field(default=None, description="生成数据集使用的大模型信息", alias="llm") - generate_task: Optional[Task] = Field(default=None, description="数据集生成任务", alias="generateTask") + data_cnt_existed: int = Field( + default=0, description="数据集实际条目", alias="dataCntExisted") + is_data_cleared: bool = Field( + default=False, description="数据集是否进行清洗", alias="isDataCleared") + is_chunk_related: bool = Field( + default=False, description="数据集进行上下文关联", alias="isChunkRelated") + is_imported: bool = Field( + default=False, description="数据集是否导入", alias="isImported") + llm: Optional[LLM] = Field( + default=None, description="生成数据集使用的大模型信息", alias="llm") + generate_task: Optional[Task] = Field( + default=None, description="数据集生成任务", alias="generateTask") score: Optional[float] = Field(description="数据集评分", default=None) author_name: str = Field(description="数据集创建者的用户名", alias="authorName") status: DataSetStatus = Field(description="数据集状态", alias="status") @@ -368,12 +424,14 @@ class Dataset(BaseModel): class ListDatasetMsg(BaseModel): """GET /dataset 数据结构""" total: int = Field(default=0, description="总数") - datasets: list[Dataset] = Field(default=[], description="数据集列表", alias="datasets") + datasets: list[Dataset] = Field( + default=[], description="数据集列表", alias="datasets") class ListDatasetResponse(ResponseData): """GET /dataset 响应""" - result: ListDatasetMsg = Field(default=ListDatasetMsg(), description="数据集列表数据结构") + result: ListDatasetMsg = 
Field( + default=ListDatasetMsg(), description="数据集列表数据结构") class Data(BaseModel): @@ -393,7 +451,8 @@ class ListDataInDatasetMsg(BaseModel): class ListDataInDatasetResponse(ResponseData): """GET /dataset/data 响应""" - result: ListDataInDatasetMsg = Field(default=ListDataInDatasetMsg(), description="数据列表数据结构") + result: ListDataInDatasetMsg = Field( + default=ListDataInDatasetMsg(), description="数据列表数据结构") class IsDatasetHaveTestingResponse(ResponseData): @@ -449,19 +508,30 @@ class DeleteDataResponse(ResponseData): class Testing(BaseModel): """测试信息""" testing_id: uuid.UUID = Field(description="测试ID", alias="testingId") - testing_name: str = Field(description="测试名称", min=1, max=20, alias="testingName") + testing_name: str = Field( + description="测试名称", min=1, max=20, alias="testingName") description: str = Field(description="测试描述", max=150) - llm: Optional[LLM] = Field(default=None, description="测试使用的大模型信息", alias="llm") - search_method: SearchMethod = Field(description="搜索方法", alias="searchMethod") - testing_task: Optional[Task] = Field(default=None, description="测试任务", alias="testingTask") + llm: Optional[LLM] = Field( + default=None, description="测试使用的大模型信息", alias="llm") + search_method: SearchMethod = Field( + description="搜索方法", alias="searchMethod") + testing_task: Optional[Task] = Field( + default=None, description="测试任务", alias="testingTask") ave_score: float = Field(default=-1, description="综合得分", alias="aveScore") - ave_pre: float = Field(default=-1, description="精确率", alias="avePre") # 精确度 - ave_rec: float = Field(default=-1, description="召回率", alias="aveRec") # 召回率 - ave_fai: float = Field(default=-1, description="忠实值", alias="aveFai") # 忠实值 - ave_rel: float = Field(default=-1, description="可解释性", alias="aveRel") # 可解释性 - ave_lcs: float = Field(default=-1, description="最长公共子串得分", alias="aveLcs") # 最长公共子序列得分 - ave_leve: float = Field(default=-1, description="编辑距离得分", alias="aveLeve") # 编辑距离得分 - ave_jac: float = Field(default=-1, description="杰卡德相似系数", alias="aveJac") # 杰卡德相似系数 + ave_pre: float = Field(default=-1, description="精确率", + alias="avePre") # 精确度 + ave_rec: float = Field(default=-1, description="召回率", + alias="aveRec") # 召回率 + ave_fai: float = Field(default=-1, description="忠实值", + alias="aveFai") # 忠实值 + ave_rel: float = Field( + default=-1, description="可解释性", alias="aveRel") # 可解释性 + ave_lcs: float = Field( + default=-1, description="最长公共子串得分", alias="aveLcs") # 最长公共子序列得分 + ave_leve: float = Field( + default=-1, description="编辑距离得分", alias="aveLeve") # 编辑距离得分 + ave_jac: float = Field( + default=-1, description="杰卡德相似系数", alias="aveJac") # 杰卡德相似系数 author_name: str = Field(description="测试创建者的用户名", alias="authorName") topk: int = Field(description="检索到的片段数量", alias="topk") status: TestingStatus = Field(description="测试状态", alias="status") @@ -471,18 +541,21 @@ class DatasetTesting(BaseModel): """数据集测试信息""" dataset_id: uuid.UUID = Field(description="数据集ID", alias="datasetId") dataset_name: str = Field(description="数据集名称", alias="datasetName") - testings: list[Testing] = Field(default=[], description="测试列表", alias="testings") + testings: list[Testing] = Field( + default=[], description="测试列表", alias="testings") class ListTestingMsg(BaseModel): """GET /testing 数据结构""" total: int = Field(default=0, description="总数") - dataset_testings: list[DatasetTesting] = Field(default=[], description="数据集测试列表", alias="datasetTestings") + dataset_testings: list[DatasetTesting] = Field( + default=[], description="数据集测试列表", alias="datasetTestings") class ListTestingResponse(ResponseData): 
"""GET /testing 响应""" - result: ListTestingMsg = Field(default=ListTestingMsg(), description="测试列表数据结构") + result: ListTestingMsg = Field( + default=ListTestingMsg(), description="测试列表数据结构") class TestCase(BaseModel): @@ -505,21 +578,26 @@ class TestCase(BaseModel): class TestingTestCase(BaseModel): """GET /testing/testcase 数据结构""" - ave_score: float = Field(default=-1, description="平均综合得分", alias="aveScore") + ave_score: float = Field( + default=-1, description="平均综合得分", alias="aveScore") ave_pre: float = Field(default=-1, description="平均精确率", alias="avePre") ave_rec: float = Field(default=-1, description="平均召回率", alias="aveRec") ave_fai: float = Field(default=-1, description="平均忠实值", alias="aveFai") ave_rel: float = Field(default=-1, description="平均可解释性", alias="aveRel") - ave_lcs: float = Field(default=-1, description="平均最长公共子串得分", alias="aveLcs") - ave_leve: float = Field(default=-1, description="平均编辑距离得分", alias="aveLeve") + ave_lcs: float = Field( + default=-1, description="平均最长公共子串得分", alias="aveLcs") + ave_leve: float = Field( + default=-1, description="平均编辑距离得分", alias="aveLeve") ave_jac: float = Field(default=-1, description="平均杰卡德相似系数", alias="aveJac") total: int = Field(default=0, description="总数") - test_cases: list[TestCase] = Field(default=[], description="测试用例列表", alias="testCases") + test_cases: list[TestCase] = Field( + default=[], description="测试用例列表", alias="testCases") class ListTestCaseResponse(ResponseData): """GET /testing/testcase 响应""" - result: TestingTestCase = Field(default=TestingTestCase(), description="测试用例列表数据结构") + result: TestingTestCase = Field( + default=TestingTestCase(), description="测试用例列表数据结构") class CreateTestingResponsing(ResponseData): @@ -542,38 +620,55 @@ class DeleteTestingResponse(ResponseData): result: list[uuid.UUID] = Field(default=[], description="测试ID列表") -class action(BaseModel): +class Action(BaseModel): """操作信息""" - action_name: str = Field(description="操作名称", min=1, max=20, alias="actionName") - action: str = Field(description="操作", min=1, max=20) - is_used: bool = Field(description="是否启用", alias="isUsed") + action: str = Field(description="操作") + action_description: str = Field( + description="操作描述", alias="actionDescription") + is_used: bool = Field(default=False, description="是否启用", alias="isUsed") class TypeAction(BaseModel): """不同类别的类别操作""" action_type: ActionType = Field(description="操作类型", alias="actionType") - actions: list[action] = Field(default=[], description="操作列表", alias="actions") + actions: list[Action] = Field( + default=[], description="操作列表", alias="actions") class ListActionMsg(BaseModel): """GET /role/action 数据结构""" - type_actions: list[TypeAction] = Field(default=[], description="操作类型列表", alias="actionTypes") + type_actions: list[TypeAction] = Field( + default=[], description="操作类型列表", alias="TypeActions") class ListActionResponse(ResponseData): - result: ListActionMsg = Field(default=ListActionMsg(), description="操作列表数据结构") + result: ListActionMsg = Field( + default=ListActionMsg(), description="操作列表数据结构") -class role(BaseModel): +class GetUserRoleMsg(BaseModel): + """GET /role 数据结构""" + role_id: uuid.UUID = Field(description="角色ID", alias="roleId") + role_name: str = Field(description="角色名称", min=1, max=20, alias="roleName") + is_owner: bool = Field(description="是否为团队所有者", alias="isOwner") + + +class GetUserRoleResponse(ResponseData): + """GET /role 响应""" + result: GetUserRoleMsg = Field(description="用户角色数据结构") + + +class Role(BaseModel): """角色信息""" role_id: uuid.UUID = Field(description="角色ID", 
alias="roleId") role_name: str = Field(description="角色名称", min=1, max=20, alias="roleName") - type_actions: list[TypeAction] = Field(default=[], description="操作类型列表", alias="typeActions") + type_actions: list[TypeAction] = Field( + default=[], description="操作类型列表", alias="typeActions") class ListRoleMsg(BaseModel): """GET /role 数据结构""" - roles: list[role] = Field(default=[], description="角色列表", alias="roles") + roles: list[Role] = Field(default=[], description="角色列表", alias="roles") class ListRoleResponse(ResponseData): @@ -593,30 +688,38 @@ class UpdateRoleResponse(ResponseData): class DeleteRoleResponse(ResponseData): """DELETE /role 响应""" - result: list[uuid.UUID] = Field(default=[], description="角色ID列表") + result: Optional[uuid.UUID] = Field(default=[], description="角色ID列表") class UserMsg(BaseModel): """用户消息""" team_id: uuid.UUID = Field(description="团队ID", alias="teamId") + team_name: str = Field(description="团队名称", alias="teamName") msg_id: uuid.UUID = Field(description="消息ID", alias="msgId") - sender_id: uuid.UUID = Field(description="发送者ID", alias="senderId") - sender_name: str = Field(description="发送者名称", alias="senderName") - receiver_id: uuid.UUID = Field(description="接收者ID", alias="receiverId") - receiver_name: str = Field(description="接收者名称", alias="receiverName") + sender_id: Optional[str] = Field(description="发送者ID", alias="senderId") + sender_name: Optional[str] = Field(description="发送者名称", alias="senderName") + msg_status_to_sender: UserMessageStatus = Field( + description="发送者消息状态", alias="msgStatusToSender") + receiver_id: Optional[str] = Field(description="接收者ID", alias="receiverId") + receiver_name: Optional[str] = Field( + description="接收者名称", alias="receiverName") + msg_status_to_receiver: UserMessageStatus = Field( + description="接收者消息状态", alias="msgStatusToReceiver") msg_type: UserMessageType = Field(description="消息类型", alias="msgType") - msg_status: UserMessageStatus = Field(description="消息状态", alias="msgStatus") + is_editable: bool = Field(description="消息是否可编辑", alias="isEditable") created_time: str = Field(description="创建时间", alias="createdTime") class ListUserMessageMsg(BaseModel): """GET /usr_msg 数据结构""" total: int = Field(default=0, description="总数") - user_messages: list[UserMsg] = Field(default=[], description="用户消息列表", alias="userMessages") + user_messages: list[UserMsg] = Field( + default=[], description="用户消息列表", alias="userMessages") class ListUserMessageResponse(ResponseData): - result: ListUserMessageMsg = Field(default=ListUserMessageMsg(), description="用户消息列表数据结构") + result: ListUserMessageMsg = Field( + default=ListUserMessageMsg(), description="用户消息列表数据结构") class UpdateUserMessageResponse(ResponseData): @@ -626,13 +729,13 @@ class UpdateUserMessageResponse(ResponseData): class DeleteUserMessageResponse(ResponseData): """DELETE /usr_msg 响应""" - result: list[uuid.UUID] = Field(default=[], description="消息ID列表") + result: Optional[uuid.UUID] = Field(default=[], description="消息ID列表") class User(BaseModel): """用户数据结构""" - user_sub: str = Field(description="用户id") - user_name: str = Field(description="用户名称") + user_sub: str = Field(description="用户id", alias="userSub") + user_name: str = Field(description="用户名称", alias="userName") class ListUserMsg(BaseModel): @@ -664,6 +767,11 @@ class ListEmbeddingResponse(ResponseData): result: list[str] = Field(default=[], description="向量化模型的列表数据结构") +class ListRerankResponse(ResponseData): + """GET /other/rerank 数据结构""" + result: list[str] = Field(default=[], description="重排序模型的列表数据结构") + + class 
ListTokenizerResponse(ResponseData): """GET /other/tokenizer 响应""" result: list[str] = Field(default=[], description="分词器的列表数据结构") diff --git a/data_chain/llm/llm.py b/data_chain/llm/llm.py index b5cd720ad24c049a8bce7e242fc1271b8df76449..031af5bbde694a5d4492174f55a3a65138c5bc3a 100644 --- a/data_chain/llm/llm.py +++ b/data_chain/llm/llm.py @@ -1,12 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. import asyncio -import time -import re -import json -import tiktoken -from langchain_openai import ChatOpenAI -from langchain.schema import SystemMessage, HumanMessage -from data_chain.logger.logger import logger as logging +from openai import AsyncOpenAI +from data_chain.logger.logger import logger class LLM: @@ -17,79 +12,79 @@ class LLM: self.max_tokens = max_tokens self.request_timeout = request_timeout self.temperature = temperature - self.client = ChatOpenAI(model_name=model_name, - openai_api_base=openai_api_base, - openai_api_key=openai_api_key, - request_timeout=request_timeout, - max_tokens=max_tokens, - temperature=temperature) + self._client = AsyncOpenAI( + api_key=self.openai_api_key, + base_url=self.openai_api_base, + ) def assemble_chat(self, chat=None, system_call='', user_call=''): if chat is None: chat = [] - chat.append(SystemMessage(content=system_call)) - chat.append(HumanMessage(content=user_call)) + chat.append({"role": "system", "content": system_call}) + chat.append({"role": "user", "content": user_call}) return chat - async def nostream(self, chat, system_call, user_call,st_str:str=None,en_str:str=None): - try: - chat = self.assemble_chat(chat, system_call, user_call) - response = await self.client.ainvoke(chat) - content = re.sub(r'.*?\n?', '', response.content, flags=re.DOTALL) - content = re.sub(r'.*?\n?', '', content, flags=re.DOTALL) - content=content.strip() - if st_str is not None: - index = content.find(st_str) - if index != -1: - content = content[index:] - if en_str is not None: - index = content[::-1].find(en_str[::-1]) - if index != -1: - content = content[:len(content)-index] - logging.error("[LLM] 非流式输出内容: %s", content) - except Exception as e: - err = f"[LLM] 非流式输出异常: {e}" - logging.error("[LLM] %s", err) - return '' - return content + async def create_stream( + self, message): + return await self._client.chat.completions.create( + model=self.model_name, + messages=message, # type: ignore[] + max_completion_tokens=self.max_tokens, + temperature=self.temperature, + stream=True, + stream_options={"include_usage": True}, + timeout=300, + extra_body={"enable_thinking": False} + ) # type: ignore[] async def data_producer(self, q: asyncio.Queue, history, system_call, user_call): message = self.assemble_chat(history, system_call, user_call) + stream = await self.create_stream(message) try: - async for frame in self.client.astream(message): - await q.put(frame.content) + async for chunk in stream: + if len(chunk.choices) == 0: + continue + if chunk.choices[0].delta.content is not None: + content = chunk.choices[0].delta.content + else: + continue + await q.put(content) except Exception as e: await q.put(None) err = f"[LLM] 流式输出生产者任务异常: {e}" - logging.error("[LLM] %s", err) + logger.error(err) raise e await q.put(None) async def stream(self, chat, system_call, user_call): - st = time.time() q = asyncio.Queue(maxsize=10) # 启动生产者任务 - producer_task = asyncio.create_task(self.data_producer(q, chat, system_call, user_call)) - first_token_reach = False - enc = tiktoken.encoding_for_model("gpt-4") - input_tokens = 
len(enc.encode(system_call)) - output_tokens = 0 + asyncio.create_task(self.data_producer( + q, chat, system_call, user_call)) while True: data = await q.get() if data is None: break - if not first_token_reach: - first_token_reach = True - logging.info(f"大模型回复第一个字耗时 = {time.time() - st}") - output_tokens += len(enc.encode(data)) - yield "data: " + json.dumps( - {'content': data, - 'input_tokens': input_tokens, - 'output_tokens': output_tokens - }, ensure_ascii=False - ) + '\n\n' - await asyncio.sleep(0.03) # 使用异步 sleep + yield data - yield "data: [DONE]" - logging.info(f"大模型回复耗时 = {time.time() - st}") + async def nostream(self, chat, system_call, user_call, st_str: str = None, en_str: str = None): + try: + content = '' + async for chunk in self.stream(chat, system_call, user_call): + content += chunk + content = content.strip() + if st_str is not None: + index = content.find(st_str) + if index != -1: + content = content[index:] + if en_str is not None: + index = content[::-1].find(en_str[::-1]) + if index != -1: + content = content[:len(content)-index] + logger.error(f"LLM nostream content: {content}") + except Exception as e: + err = f"[LLM] 非流式输出异常: {e}" + logger.error("[LLM] %s", err) + return '' + return content diff --git a/data_chain/manager/chunk_manager.py b/data_chain/manager/chunk_manager.py index 277abc3e41a18c3a8ea824c4808893ac8bcf36d0..d9324815e03249e1c35c5f0906f7be52e548082e 100644 --- a/data_chain/manager/chunk_manager.py +++ b/data_chain/manager/chunk_manager.py @@ -2,12 +2,14 @@ from sqlalchemy import select, update, func, text, or_, and_, Float, literal_column from typing import List, Tuple, Dict, Optional import uuid +from datetime import datetime from data_chain.entities.enum import DocumentStatus, ChunkStatus, Tokenizer from data_chain.entities.request_data import ListChunkRequest from data_chain.config.config import config from data_chain.stores.database.database import DocumentEntity, ChunkEntity, DataBase from data_chain.manager.knowledge_manager import KnowledgeBaseManager from data_chain.logger.logger import logger as logging +import logging class ChunkManager(): @@ -133,10 +135,12 @@ class ChunkManager(): if req.text is not None: stmt = stmt.where(ChunkEntity.text.ilike(f"%{req.text}%")) if req.types is not None: - stmt = stmt.where(ChunkEntity.type.in_([t.value for t in req.types])) + stmt = stmt.where(ChunkEntity.type.in_( + [t.value for t in req.types])) count_stmt = select(func.count()).select_from(stmt.subquery()) total = (await session.execute(count_stmt)).scalar() - stmt = stmt.offset((req.page - 1) * req.page_size).limit(req.page_size) + stmt = stmt.offset( + (req.page - 1) * req.page_size).limit(req.page_size) stmt = stmt.order_by(ChunkEntity.global_offset) result = await session.execute(stmt) chunk_entities = result.scalars().all() @@ -169,64 +173,139 @@ class ChunkManager(): kb_id: uuid.UUID, vector: List[float], top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = [], chunk_to_type: str = None, pre_ids: list[uuid.UUID] = None) -> List[ChunkEntity]: - """根据知识库ID和向量查询文档解析结果""" + """根据知识库ID和向量查询文档解析结果(适配OpenGauss强制索引)""" try: if top_k <= 0: return [] + st = datetime.now() async with await DataBase.get_session() as session: - fetch_cnt = top_k + # -------------------------- + # 原有逻辑:构建WHERE条件和params(完全保留) + # -------------------------- + where_conditions = [ + "document.enabled = true", + "document.status != 'deleted'", + "chunk.kb_id = :kb_id", + "chunk.enabled = true", + "chunk.status != 'deleted'", + "chunk.text_vector IS NOT 
NULL" + ] + + params = { + "vector": vector, + "kb_id": kb_id, + "limit": top_k + } + + if banned_ids: + banned_placeholders = [] + for i, banned_id in enumerate(banned_ids): + param_name = f"banned_id_{i}" + banned_placeholders.append(f":{param_name}") + params[param_name] = banned_id + where_conditions.append( + f"chunk.id NOT IN ({','.join(banned_placeholders)})") + + if doc_ids is not None: + doc_placeholders = [] + for i, doc_id in enumerate(doc_ids): + param_name = f"doc_id_{i}" + doc_placeholders.append(f":{param_name}") + params[param_name] = doc_id + where_conditions.append( + f"document.id IN ({','.join(doc_placeholders)})") + + if chunk_to_type is not None: + where_conditions.append( + "chunk.parse_topology_type = :chunk_to_type") + params["chunk_to_type"] = chunk_to_type + + if pre_ids is not None: + pre_placeholders = [] + for i, pre_id in enumerate(pre_ids): + param_name = f"pre_id_{i}" + pre_placeholders.append(f":{param_name}") + params[param_name] = pre_id + where_conditions.append( + f"chunk.pre_id_in_parse_topology IN ({','.join(pre_placeholders)})") + + where_clause = " AND ".join(where_conditions) + + # -------------------------- + # 核心修复:替换为OpenGauss支持的索引提示(二选一,推荐方案1) + # -------------------------- + # 方案1:临时关闭全表扫描(优先推荐,简单有效,会话级生效,不影响其他查询) + # 执行SET命令:关闭全表扫描后,数据库会优先选择可用索引(text_vector_index) + await session.execute(text("SET enable_seqscan = off;")) + + # 方案2(备选):使用OpenGauss查询计划hints(需确保数据库开启hints支持) + # 在SELECT后添加 /*+ IndexScan(chunk text_vector_index) */ 强制索引扫描 + # 若用方案2,需将下面SELECT行改为:SELECT /*+ IndexScan(chunk text_vector_index) */ + + # 构建查询SQL(移除USE INDEX,保留其他原有逻辑) + base_sql = f""" + SELECT + chunk.id, chunk.team_id, chunk.kb_id, chunk.doc_id, chunk.doc_name, + chunk.text, chunk.text_vector, chunk.tokens, chunk.type, + chunk.pre_id_in_parse_topology, chunk.parse_topology_type, + chunk.global_offset, chunk.local_offset, chunk.enabled, + chunk.status, chunk.created_time, chunk.updated_time, + chunk.text_vector <#> :vector AS similarity_score + FROM chunk + JOIN document ON document.id = chunk.doc_id + WHERE {where_clause} + AND (chunk.text_vector <#> :vector) IS NOT NULL + AND (chunk.text_vector <#> :vector) = (chunk.text_vector <#> :vector) + ORDER BY similarity_score ASC NULLS LAST + LIMIT :limit + """ + # -------------------------- + # 原有逻辑:执行查询与结果处理(完全保留) + # -------------------------- + result = await session.execute(text(base_sql), params) + rows = result.fetchall() + chunk_entities = [] - while True: - # 计算相似度分数 - similarity_score = ChunkEntity.text_vector.cosine_distance(vector).label("similarity_score") - - # 构建基础查询条件 - stmt = ( - select(ChunkEntity, similarity_score) - .join(DocumentEntity, - DocumentEntity.id == ChunkEntity.doc_id - ) - .where(DocumentEntity.enabled == True) - .where(DocumentEntity.status != DocumentStatus.DELETED.value) - .where(ChunkEntity.kb_id == kb_id) - .where(ChunkEntity.enabled == True) - .where(ChunkEntity.status != ChunkStatus.DELETED.value) - .where(ChunkEntity.id.notin_(banned_ids)) + for row in rows: + chunk_entity = ChunkEntity( + id=row.id, + team_id=row.team_id, + kb_id=row.kb_id, + doc_id=row.doc_id, + doc_name=row.doc_name, + text=row.text, + text_vector=row.text_vector, + tokens=row.tokens, + type=row.type, + pre_id_in_parse_topology=row.pre_id_in_parse_topology, + parse_topology_type=row.parse_topology_type, + global_offset=row.global_offset, + local_offset=row.local_offset, + enabled=row.enabled, + status=row.status, + created_time=row.created_time, + updated_time=row.updated_time ) + 
chunk_entities.append(chunk_entity) + + # 可选:查询结束后恢复enable_seqscan(避免影响后续查询,会话结束也会自动恢复) - # 添加可选条件 - if doc_ids is not None: - stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) - if chunk_to_type is not None: - stmt = stmt.where(ChunkEntity.parse_topology_type == chunk_to_type) - if pre_ids is not None: - stmt = stmt.where(ChunkEntity.pre_id_in_parse_topology.in_(pre_ids)) - - # 应用排序条件 - stmt = stmt.order_by(similarity_score) - - stmt = stmt.limit(fetch_cnt) - - # 执行最终查询 - result = await session.execute(stmt) - chunk_entities = result.scalars().all() - if chunk_entities: - break - fetch_cnt *= 2 - fetch_cnt = min(fetch_cnt, max(8192, top_k)+1) - chunk_entities = chunk_entities[:top_k] # 确保返回的结果不超过 top_k + logging.error(f"向量查询耗时:{datetime.now()-st}") return chunk_entities + except Exception as e: - err = "根据知识库ID和向量查询文档解析结果失败" + err = f"根据知识库ID和向量查询文档解析结果失败: {str(e)}" logging.exception("[ChunkManager] %s", err) return [] + @staticmethod async def get_top_k_chunk_by_kb_id_keyword( kb_id: uuid.UUID, query: str, top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = [], chunk_to_type: str = None, pre_ids: list[uuid.UUID] = None, is_tight: bool = True) -> List[ChunkEntity]: """根据知识库ID和向量查询文档解析结果""" try: + st = datetime.now() async with await DataBase.get_session() as session: kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id) if kb_entity.tokenizer == Tokenizer.ZH.value: @@ -242,94 +321,95 @@ class ChunkManager(): else: tokenizer = 'zhparser' - # 计算相似度分数并选择它 + # -------------------------- 新增:提前生成 tsquery(复用逻辑,避免重复计算) -------------------------- if is_tight: - similarity_score = func.ts_rank_cd( - func.to_tsvector(tokenizer, ChunkEntity.text), - func.plainto_tsquery(tokenizer, query) - ).label("similarity_score") + # 与原similarity_score中的tsquery逻辑完全一致 + tsquery = func.plainto_tsquery(tokenizer, query) else: - similarity_score = func.ts_rank_cd( - func.to_tsvector(tokenizer, ChunkEntity.text), - func.to_tsquery( - func.replace( - func.text(func.plainto_tsquery(tokenizer, query)), - '&', '|' - ) + # 与原similarity_score中的tsquery逻辑完全一致 + tsquery = func.to_tsquery( + func.replace( + func.text(func.plainto_tsquery(tokenizer, query)), + '&', '|' ) - ).label("similarity_score") + ) + # --------------------------------------------------------------------------------------------------- + + # 计算相似度分数并选择它(逻辑不变,复用上面生成的tsquery) + similarity_score = func.ts_rank_cd( + ChunkEntity.text_ts_vector, + tsquery # 替换原重复的tsquery生成逻辑,直接用提前生成的 + ).label("similarity_score") stmt = ( select(ChunkEntity, similarity_score) .join(DocumentEntity, DocumentEntity.id == ChunkEntity.doc_id ) - .where(similarity_score > 0) + # -------------------------- 核心新增:通过 @@ 条件强制触发 GIN 索引 -------------------------- + .where(ChunkEntity.text_ts_vector.op('@@')(tsquery)) + # --------------------------------------------------------------------------------------------------- + .where(similarity_score > 0) # 原条件保留,顺序不变 .where(DocumentEntity.enabled == True) .where(DocumentEntity.status != DocumentStatus.DELETED.value) .where(ChunkEntity.kb_id == kb_id) .where(ChunkEntity.enabled == True) .where(ChunkEntity.status != ChunkStatus.DELETED.value) - .where(ChunkEntity.id.notin_(banned_ids)) ) - + if banned_ids: + stmt = stmt.where(ChunkEntity.id.notin_(banned_ids)) if doc_ids is not None: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) if chunk_to_type is not None: - stmt = stmt.where(ChunkEntity.parse_topology_type == chunk_to_type) + stmt = stmt.where( + ChunkEntity.parse_topology_type == chunk_to_type) if 
pre_ids is not None: - stmt = stmt.where(ChunkEntity.pre_id_in_parse_topology.in_(pre_ids)) - - # 按相似度分数排序 + stmt = stmt.where( + ChunkEntity.pre_id_in_parse_topology.in_(pre_ids)) + # 按相似度分数排序(逻辑不变) stmt = stmt.order_by(similarity_score.desc()) stmt = stmt.limit(top_k) - # 执行最终查询 + # 执行最终查询(逻辑不变) result = await session.execute(stmt) chunk_entities = result.scalars().all() - + logging.warning( + f"[ChunkManager] get_top_k_chunk_by_kb_id_keyword cost: {(datetime.now()-st).total_seconds()}s") return chunk_entities except Exception as e: err = f"根据知识库ID和向量查询文档解析结果失败: {str(e)}" logging.exception("[ChunkManager] %s", err) return [] + @staticmethod async def get_top_k_chunk_by_kb_id_dynamic_weighted_keyword( kb_id: uuid.UUID, keywords: List[str], weights: List[float], top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = [], chunk_to_type: str = None, pre_ids: list[uuid.UUID] = None) -> List[ChunkEntity]: - """根据知识库ID和关键词和关键词权重查询文档解析结果""" + """根据知识库ID和关键词权重查询文档解析结果(修复NoneType报错+强制索引)""" try: + st = datetime.now() async with await DataBase.get_session() as session: + # 1. 分词器选择(保留原逻辑) kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id) if kb_entity.tokenizer == Tokenizer.ZH.value: - if config['DATABASE_TYPE'].lower() == 'opengauss': - tokenizer = 'chparser' - else: - tokenizer = 'zhparser' + tokenizer = 'chparser' if config['DATABASE_TYPE'].lower( + ) == 'opengauss' else 'zhparser' elif kb_entity.tokenizer == Tokenizer.EN.value: tokenizer = 'english' else: - if config['DATABASE_TYPE'].lower() == 'opengauss': - tokenizer = 'chparser' - else: - tokenizer = 'zhparser' + tokenizer = 'chparser' if config['DATABASE_TYPE'].lower( + ) == 'opengauss' else 'zhparser' - # 构建VALUES子句的参数 + # 2. 构建加权关键词CTE(保留原逻辑) params = {} values_clause = [] - - for i, (term, weight) in enumerate(zip(keywords, weights)): - # 使用单独的参数名,避免与类型转换冲突 - params[f"term_{i}"] = term - params[f"weight_{i}"] = weight - # 在VALUES子句中使用类型转换函数 - values_clause.append(f"(CAST(:term_{i} AS TEXT), CAST(:weight_{i} AS FLOAT8))") - - # 构建VALUES子句 + for idx, (term, weight) in enumerate(zip(keywords, weights)): + params[f"term_{idx}"] = term + params[f"weight_{idx}"] = weight + values_clause.append( + f"(CAST(:term_{idx} AS TEXT), CAST(:weight_{idx} AS FLOAT8))") values_text = f"(VALUES {', '.join(values_clause)}) AS t(term, weight)" - - # 创建weighted_terms CTE weighted_terms = ( select( literal_column("t.term").label("term"), @@ -339,49 +419,81 @@ class ChunkManager(): .cte("weighted_terms") ) - # 计算相似度得分 - similarity_score = func.sum( - func.ts_rank_cd( - func.to_tsvector(tokenizer, ChunkEntity.text), - func.to_tsquery(tokenizer, weighted_terms.c.term) - ) * weighted_terms.c.weight - ).label("similarity_score") - + # 3. 
初始化查询(确保stmt始终是Select对象,不直接赋值None) stmt = ( - select(ChunkEntity, similarity_score) - .join(DocumentEntity, - DocumentEntity.id == ChunkEntity.doc_id - ) + select( + ChunkEntity, + func.sum( + func.ts_rank_cd(ChunkEntity.text_ts_vector, func.to_tsquery( + tokenizer, weighted_terms.c.term)) + * weighted_terms.c.weight + ).label("similarity_score") + ) + # 关联文档表 + .join(DocumentEntity, DocumentEntity.id == ChunkEntity.doc_id) + .join( # 关联CTE+强制触发GIN索引(核心优化) + weighted_terms, + ChunkEntity.text_ts_vector.op( + '@@')(func.to_tsquery(tokenizer, weighted_terms.c.term)), + isouter=False + ) + # 基础过滤条件 .where(DocumentEntity.enabled == True) .where(DocumentEntity.status != DocumentStatus.DELETED.value) .where(ChunkEntity.kb_id == kb_id) .where(ChunkEntity.enabled == True) .where(ChunkEntity.status != ChunkStatus.DELETED.value) - .where(ChunkEntity.id.notin_(banned_ids)) ) - # 添加 GROUP BY 子句,按 ChunkEntity.id 分组 - stmt = stmt.group_by(ChunkEntity.id) - stmt.having(similarity_score > 0) + # 4. 动态条件:禁用ID(修复关键:用if-else确保stmt不被赋值为None) + if banned_ids: + stmt = stmt.where(ChunkEntity.id.notin_(banned_ids)) + # 5. 其他动态条件(同样用if-else确保链式调用不中断) if doc_ids is not None: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) if chunk_to_type is not None: - stmt = stmt.where(ChunkEntity.parse_topology_type == chunk_to_type) + stmt = stmt.where( + ChunkEntity.parse_topology_type == chunk_to_type) if pre_ids is not None: - stmt = stmt.where(ChunkEntity.pre_id_in_parse_topology.in_(pre_ids)) - - # 按相似度分数排序 - stmt = stmt.order_by(similarity_score.desc()) - stmt = stmt.limit(top_k) + stmt = stmt.where( + ChunkEntity.pre_id_in_parse_topology.in_(pre_ids)) + + # 6. 分组、过滤分数、排序、限制行数(链式调用安全) + stmt = (stmt + .group_by(ChunkEntity.id) # 按chunk分组计算总权重 + .having( # 过滤总分数>0的结果 + func.sum( + func.ts_rank_cd(ChunkEntity.text_ts_vector, func.to_tsquery( + tokenizer, weighted_terms.c.term)) + * weighted_terms.c.weight + ) > 0 + ) + .order_by( # 按总分数降序 + func.sum( + func.ts_rank_cd(ChunkEntity.text_ts_vector, func.to_tsquery( + tokenizer, weighted_terms.c.term)) + * weighted_terms.c.weight + ).desc() + ) + .limit(top_k) # 限制返回数量 + ) - # 执行最终查询 + # 7. 执行查询与结果处理(保留原逻辑) result = await session.execute(stmt, params=params) chunk_entities = result.scalars().all() + # 8. 日志输出 + cost = (datetime.now() - st).total_seconds() + logging.warning( + f"[ChunkManager] get_top_k_chunk_by_kb_id_dynamic_weighted_keyword cost: {cost}s " + f"| kb_id: {kb_id} | keywords: {keywords[:2]}... 
| match_count: {len(chunk_entities)}" + ) return chunk_entities + except Exception as e: - err = f"根据知识库ID和向量查询文档解析结果失败: {str(e)}" + # 异常日志补充关键上下文 + err = f"根据知识库ID和关键词权重查询失败: kb_id={kb_id}, keywords={keywords[:2]}..., error={str(e)[:150]}" logging.exception("[ChunkManager] %s", err) return [] @@ -409,6 +521,42 @@ class ChunkManager(): logging.exception("[ChunkManager] %s", err) raise e + @staticmethod + async def update_chunk_text_ts_vector_by_chunk_ids(chunk_ids: List[uuid.UUID]) -> None: + """根据文档ID更新文档解析结果""" + if not chunk_ids: + return + try: + kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id( + (await ChunkManager.get_chunk_by_chunk_id(chunk_ids[0])).kb_id) + if kb_entity.tokenizer == Tokenizer.ZH.value: + if config['DATABASE_TYPE'].lower() == 'opengauss': + tokenizer = 'chparser' + else: + tokenizer = 'zhparser' + elif kb_entity.tokenizer == Tokenizer.EN.value: + tokenizer = 'english' + else: + if config['DATABASE_TYPE'].lower() == 'opengauss': + tokenizer = 'chparser' + else: + tokenizer = 'zhparser' + async with await DataBase.get_session() as session: + stmt = ( + update(ChunkEntity) + .where(ChunkEntity.id.in_(chunk_ids)) + .values({ + ChunkEntity.text_ts_vector: func.to_tsvector( + tokenizer, ChunkEntity.text) + }) + ) + await session.execute(stmt) + await session.commit() + return True + except Exception as e: + err = "根据文档ID更新文档解析结果失败" + logging.exception("[ChunkManager] %s", err) + @staticmethod async def update_chunk_by_doc_id(doc_id: uuid.UUID, chunk_dict: Dict[str, str]) -> bool: """根据文档ID更新文档解析结果""" diff --git a/data_chain/manager/dataset_manager.py b/data_chain/manager/dataset_manager.py index cb61aa8a495a892f7ca7271c2960ba6692f031dd..f607c6bfaf67ba8d19e7c613ed3ad716494db9c2 100644 --- a/data_chain/manager/dataset_manager.py +++ b/data_chain/manager/dataset_manager.py @@ -72,6 +72,23 @@ class DatasetManager: logging.exception("[DatasetManager] %s", err) raise e + @staticmethod + async def get_dataset_by_data_id(data_id: uuid.UUID) -> DataSetEntity: + """根据数据ID查询数据集""" + try: + async with await DataBase.get_session() as session: + stmt = ( + select(DataSetEntity) + .join(QAEntity, QAEntity.dataset_id == DataSetEntity.id) + .where(and_(QAEntity.id == data_id, DataSetEntity.status != DataSetStatus.DELETED.value)) + ) + result = await session.execute(stmt) + return result.scalars().first() + except Exception as e: + err = "根据数据ID查询数据集失败" + logging.exception("[DatasetManager] %s", err) + raise e + @staticmethod async def list_dataset(req: ListDatasetRequest) -> tuple[int, List[DataSetEntity]]: """列出数据集""" @@ -83,34 +100,43 @@ class DatasetManager: select(DataSetEntity) .outerjoin(subq, and_(DataSetEntity.id == subq.c.op_id, subq.c.rn == 1)) ) - stmt = stmt.where(DataSetEntity.status != DataSetStatus.DELETED.value) + stmt = stmt.where(DataSetEntity.status != + DataSetStatus.DELETED.value) if req.kb_id is not None: stmt = stmt.where(DataSetEntity.kb_id == req.kb_id) if req.dataset_id is not None: stmt = stmt.where(DataSetEntity.id == req.dataset_id) if req.dataset_name is not None: - stmt = stmt.where(DataSetEntity.name.ilike(f"%{req.dataset_name}%")) + stmt = stmt.where(DataSetEntity.name.ilike( + f"%{req.dataset_name}%")) if req.llm_id is not None: stmt = stmt.where(DataSetEntity.llm_id == req.llm_id) if req.is_data_cleared is not None: - stmt = stmt.where(DataSetEntity.is_data_cleared == req.is_data_cleared) + stmt = stmt.where( + DataSetEntity.is_data_cleared == req.is_data_cleared) if req.is_chunk_related is not None: - stmt = 
stmt.where(DataSetEntity.is_chunk_related == req.is_chunk_related) + stmt = stmt.where( + DataSetEntity.is_chunk_related == req.is_chunk_related) if req.generate_status is not None: - status_list = [status.value for status in req.generate_status] - status_list += [DataSetStatus.DELETED.value] + status_list = [ + status.value for status in req.generate_status] + if TaskStatus.SUCCESS in req.generate_status: + status_list += [TaskStatus.DELETED.value] stmt = stmt.where(subq.c.status.in_(status_list)) - stmt = stmt.order_by(DataSetEntity.created_at.desc(), DataSetEntity.id.desc()) + stmt = stmt.order_by( + DataSetEntity.created_at.desc(), DataSetEntity.id.desc()) if req.score_order: if req.score_order == "asc": stmt = stmt.order_by(asc(DataSetEntity.score)) else: stmt = stmt.order_by(desc(DataSetEntity.score)) if req.author_name: - stmt = stmt.where(DataSetEntity.author_name.ilike(f"%{req.author_name}%")) + stmt = stmt.where( + DataSetEntity.author_name.ilike(f"%{req.author_name}%")) count_stmt = select(func.count()).select_from(stmt.subquery()) total = (await session.execute(count_stmt)).scalar() - stmt = stmt.offset((req.page - 1) * req.page_size).limit(req.page_size) + stmt = stmt.offset( + (req.page - 1) * req.page_size).limit(req.page_size) result = await session.execute(stmt) dataset_entities = result.scalars().all() return total, dataset_entities @@ -128,7 +154,8 @@ class DatasetManager: select(DataSetEntity) .where(DataSetEntity.kb_id == kb_id) ) - stmt = stmt.where(DataSetEntity.status != DataSetStatus.DELETED.value) + stmt = stmt.where(DataSetEntity.status != + DataSetStatus.DELETED.value) stmt = stmt.order_by(DataSetEntity.id.desc()) result = await session.execute(stmt) return result.scalars().all() @@ -146,7 +173,8 @@ class DatasetManager: select(DataSetEntity) .where(DataSetEntity.id.in_(dataset_ids)) ) - stmt = stmt.where(DataSetEntity.status != DataSetStatus.DELETED.value) + stmt = stmt.where(DataSetEntity.status != + DataSetStatus.DELETED.value) stmt = stmt.order_by(DataSetEntity.id.desc()) stmt = stmt.order_by(DataSetEntity.id) result = await session.execute(stmt) diff --git a/data_chain/manager/document_manager.py b/data_chain/manager/document_manager.py index 55cb87a90397d1d0a6df5db6176b423311b301c2..bad9d72c6321937d9cdf7d17079ed597841cbe46 100644 --- a/data_chain/manager/document_manager.py +++ b/data_chain/manager/document_manager.py @@ -51,31 +51,89 @@ class DocumentManager(): try: if top_k <= 0: return [] + # 构建基础WHERE条件 + where_conditions = [ + "document.kb_id = :kb_id", + "document.status != :deleted_status", + "document.enabled = TRUE", + "document.abstract_vector IS NOT NULL" + ] + + # 构建参数字典 + params = { + "vector": vector, + "kb_id": kb_id, + "deleted_status": DocumentStatus.DELETED.value + } + + # 添加banned_ids条件 + if banned_ids: + banned_placeholders = [] + for i, banned_id in enumerate(banned_ids): + param_name = f"banned_id_{i}" + banned_placeholders.append(f":{param_name}") + params[param_name] = banned_id + where_conditions.append( + f"document.id NOT IN ({','.join(banned_placeholders)})") + + # 添加doc_ids条件 + if doc_ids is not None: + doc_placeholders = [] + for i, doc_id in enumerate(doc_ids): + param_name = f"doc_id_{i}" + doc_placeholders.append(f":{param_name}") + params[param_name] = doc_id + where_conditions.append( + f"document.id IN ({','.join(doc_placeholders)})") + + # 组合WHERE条件 + where_clause = " AND ".join(where_conditions) + + # 构建查询SQL - 添加分数有效性检查 + base_sql = f""" + SELECT + document.id, document.team_id, document.kb_id, document.author_id, + 
document.author_name, document.name, document.extension, + document.size, document.parse_method, document.parse_relut_topology, + document.chunk_size, document.type_id, document.enabled, + document.status, document.full_text, document.abstract, + document.abstract_vector, document.created_time, document.updated_time, + document.abstract_vector <#> :vector AS similarity_score + FROM document + WHERE {where_clause} + AND (document.abstract_vector <#> :vector) IS NOT NULL + AND (document.abstract_vector <#> :vector) = (document.abstract_vector <#> :vector) + ORDER BY similarity_score ASC NULLS LAST + LIMIT :limit + """ async with await DataBase.get_session() as session: - fetch_cnt = top_k + params["limit"] = top_k + result = await session.execute(text(base_sql), params) + rows = result.fetchall() document_entities = [] - while fetch_cnt < max(top_k, 8192): - similarity_score = DocumentEntity.abstract_vector.cosine_distance(vector).label("similarity_score") - stmt = ( - select(DocumentEntity, similarity_score) - .where(DocumentEntity.kb_id == kb_id) - .where(DocumentEntity.id.notin_(banned_ids)) - .where(DocumentEntity.status != DocumentStatus.DELETED.value) - .where(DocumentEntity.enabled == True) + for row in rows: + doc_entity = DocumentEntity( + id=row.id, + team_id=row.team_id, + kb_id=row.kb_id, + author_id=row.author_id, + author_name=row.author_name, + name=row.name, + extension=row.extension, + size=row.size, + parse_method=row.parse_method, + parse_relut_topology=row.parse_relut_topology, + chunk_size=row.chunk_size, + type_id=row.type_id, + enabled=row.enabled, + status=row.status, + full_text=row.full_text, + abstract=row.abstract, + abstract_vector=row.abstract_vector, + created_time=row.created_time, + updated_time=row.updated_time ) - if doc_ids: - stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) - stmt = stmt.order_by( - similarity_score - ) - # 获取所有符合条件的文档数量 - stmt = stmt.limit(fetch_cnt) # Ensure at least 50 results for vector search - result = await session.execute(stmt) - document_entities = result.scalars().all() - if document_entities: - break - fetch_cnt *= 2 # Increase fetch count by 50 until we have enough results - document_entities = document_entities[:top_k] # Limit to top_k results + document_entities.append(doc_entity) return document_entities except Exception as e: err = "获取前K个文档失败" @@ -89,7 +147,7 @@ class DocumentManager(): try: async with await DataBase.get_session() as session: kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id) - tokenizer = '' + # 设置分词器,增加默认值处理 if kb_entity.tokenizer == Tokenizer.ZH.value: if config['DATABASE_TYPE'].lower() == 'opengauss': tokenizer = 'chparser' @@ -97,15 +155,25 @@ class DocumentManager(): tokenizer = 'zhparser' elif kb_entity.tokenizer == Tokenizer.EN.value: tokenizer = 'english' + else: + # 增加默认分词器处理,与第一个方法保持一致 + if config['DATABASE_TYPE'].lower() == 'opengauss': + tokenizer = 'chparser' + else: + tokenizer = 'zhparser' - similarity_score = func.ts_rank_cd( - func.to_tsvector(tokenizer, DocumentEntity.abstract), - func.to_tsquery( - func.replace( - func.text(func.plainto_tsquery(tokenizer, query)), - '&', '|' - ) + # 提前生成tsquery,复用逻辑 + tsquery = func.to_tsquery( + func.replace( + func.text(func.plainto_tsquery(tokenizer, query)), + '&', '|' ) + ) + + # 计算相似度分数,使用提前生成的tsquery + similarity_score = func.ts_rank_cd( + DocumentEntity.abstract_ts_vector, + tsquery ).label("similarity_score") stmt = ( @@ -114,6 +182,10 @@ class DocumentManager(): .where(DocumentEntity.id.notin_(banned_ids)) 
.where(DocumentEntity.status != DocumentStatus.DELETED.value) .where(DocumentEntity.enabled == True) + # 新增:通过@@条件强制触发GIN索引 + .where(DocumentEntity.abstract_ts_vector.op('@@')(tsquery)) + # 新增:过滤相似度大于0的结果 + .where(similarity_score > 0) ) if doc_ids: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) @@ -133,38 +205,30 @@ class DocumentManager(): async def get_top_k_document_by_kb_id_dynamic_weighted_keyword( kb_id: uuid.UUID, keywords: List[str], weights: List[float], top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = []) -> List[DocumentEntity]: - """根据知识库ID和关键词和关键词权重查询文档解析结果""" + """根据知识库ID和关键词权重查询文档(修复NoneType报错+强制索引)""" try: + st = datetime.now() # 新增计时日志 async with await DataBase.get_session() as session: + # 1. 分词器选择(与第一个方法保持一致) kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id) if kb_entity.tokenizer == Tokenizer.ZH.value: - if config['DATABASE_TYPE'].lower() == 'opengauss': - tokenizer = 'chparser' - else: - tokenizer = 'zhparser' + tokenizer = 'chparser' if config['DATABASE_TYPE'].lower( + ) == 'opengauss' else 'zhparser' elif kb_entity.tokenizer == Tokenizer.EN.value: tokenizer = 'english' else: - if config['DATABASE_TYPE'].lower() == 'opengauss': - tokenizer = 'chparser' - else: - tokenizer = 'zhparser' + tokenizer = 'chparser' if config['DATABASE_TYPE'].lower( + ) == 'opengauss' else 'zhparser' - # 构建VALUES子句的参数 + # 2. 构建加权关键词CTE(保留原逻辑) params = {} values_clause = [] - - for i, (term, weight) in enumerate(zip(keywords, weights)): - # 使用单独的参数名,避免与类型转换冲突 - params[f"term_{i}"] = term - params[f"weight_{i}"] = weight - # 在VALUES子句中使用类型转换函数 - values_clause.append(f"(CAST(:term_{i} AS TEXT), CAST(:weight_{i} AS FLOAT8))") - - # 构建VALUES子句 + for idx, (term, weight) in enumerate(zip(keywords, weights)): + params[f"term_{idx}"] = term + params[f"weight_{idx}"] = weight + values_clause.append( + f"(CAST(:term_{idx} AS TEXT), CAST(:weight_{idx} AS FLOAT8))") values_text = f"(VALUES {', '.join(values_clause)}) AS t(term, weight)" - - # 创建weighted_terms CTE weighted_terms = ( select( literal_column("t.term").label("term"), @@ -174,41 +238,73 @@ class DocumentManager(): .cte("weighted_terms") ) - # 计算相似度得分 - similarity_score = func.sum( - func.ts_rank_cd( - func.to_tsvector(tokenizer, DocumentEntity.abstract), - func.to_tsquery(tokenizer, weighted_terms.c.term) - ) * weighted_terms.c.weight - ).label("similarity_score") - + # 3. 初始化查询(确保stmt始终是Select对象) stmt = ( - select(DocumentEntity, similarity_score) + select( + DocumentEntity, + func.sum( + func.ts_rank_cd(DocumentEntity.abstract_ts_vector, func.to_tsquery( + tokenizer, weighted_terms.c.term)) + * weighted_terms.c.weight + ).label("similarity_score") + ) + # 关联CTE+强制触发GIN索引(核心优化) + .join( + weighted_terms, + DocumentEntity.abstract_ts_vector.op( + '@@')(func.to_tsquery(tokenizer, weighted_terms.c.term)), + isouter=False + ) + # 基础过滤条件 .where(DocumentEntity.enabled == True) .where(DocumentEntity.status != DocumentStatus.DELETED.value) .where(DocumentEntity.kb_id == kb_id) - .where(DocumentEntity.id.notin_(banned_ids)) ) - # 添加 GROUP BY 子句,按 ChunkEntity.id 分组 - stmt = stmt.group_by(DocumentEntity.id) - stmt.having(similarity_score > 0) + # 4. 动态条件:禁用ID(确保stmt链式调用不中断) + if banned_ids: + stmt = stmt.where(DocumentEntity.id.notin_(banned_ids)) + # 5. 其他动态条件 if doc_ids is not None: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) - # 按相似度分数排序 - stmt = stmt.order_by(similarity_score.desc()) - stmt = stmt.limit(top_k) + # 6. 
分组、过滤分数、排序、限制行数(链式调用安全) + stmt = (stmt + .group_by(DocumentEntity.id) # 按文档ID分组计算总权重 + .having( # 过滤总分数>0的结果 + func.sum( + func.ts_rank_cd(DocumentEntity.abstract_ts_vector, func.to_tsquery( + tokenizer, weighted_terms.c.term)) + * weighted_terms.c.weight + ) > 0 + ) + .order_by( # 按总分数降序 + func.sum( + func.ts_rank_cd(DocumentEntity.abstract_ts_vector, func.to_tsquery( + tokenizer, weighted_terms.c.term)) + * weighted_terms.c.weight + ).desc() + ) + .limit(top_k) # 限制返回数量 + ) - # 执行最终查询 + # 7. 执行查询与结果处理 result = await session.execute(stmt, params=params) - doc_entites = result.scalars().all() + doc_entities = result.scalars().all() + + # 8. 新增执行时间日志 + cost = (datetime.now() - st).total_seconds() + logging.warning( + f"[DocumentManager] get_top_k_document_by_kb_id_dynamic_weighted_keyword cost: {cost}s " + f"| kb_id: {kb_id} | keywords: {keywords[:2]}... | match_count: {len(doc_entities)}" + ) + return doc_entities - return doc_entites except Exception as e: - err = f"根据知识库ID和关键字动态查询文档失败: {str(e)}" - logging.exception("[ChunkManager] %s", err) + # 异常日志补充关键上下文 + err = f"根据知识库ID和关键词权重查询文档失败: kb_id={kb_id}, keywords={keywords[:2]}..., error={str(e)[:150]}" + logging.exception("[DocumentManager] %s", err) return [] @staticmethod @@ -242,22 +338,27 @@ class DocumentManager(): select(DocumentEntity) .outerjoin(subq, and_(DocumentEntity.id == subq.c.op_id, subq.c.rn == 1)) ) - stmt = stmt.where(DocumentEntity.status != DocumentStatus.DELETED.value) + stmt = stmt.where(DocumentEntity.status != + DocumentStatus.DELETED.value) if req.kb_id is not None: stmt = stmt.where(DocumentEntity.kb_id == req.kb_id) if req.doc_id is not None: stmt = stmt.where(DocumentEntity.id == req.doc_id) if req.doc_name is not None: - stmt = stmt.where(DocumentEntity.name.ilike(f"%{req.doc_name}%")) + stmt = stmt.where( + DocumentEntity.name.ilike(f"%{req.doc_name}%")) if req.doc_type_ids is not None: - stmt = stmt.where(DocumentEntity.type_id.in_(req.doc_type_ids)) + stmt = stmt.where( + DocumentEntity.type_id.in_(req.doc_type_ids)) if req.parse_status is not None: - stmt = stmt.where(subq.c.status.in_([status.value for status in req.parse_status])) + stmt = stmt.where(subq.c.status.in_( + [status.value for status in req.parse_status])) if req.parse_methods is not None: stmt = stmt.where(DocumentEntity.parse_method.in_( [parse_method.value for parse_method in req.parse_methods])) if req.author_name is not None: - stmt = stmt.where(DocumentEntity.author_name.ilike(f"%{req.author_name}%")) + stmt = stmt.where( + DocumentEntity.author_name.ilike(f"%{req.author_name}%")) if req.enabled is not None: stmt = stmt.where(DocumentEntity.enabled == req.enabled) if req.created_time_start and req.created_time_end: @@ -269,7 +370,8 @@ class DocumentManager(): ) count_stmt = select(func.count()).select_from(stmt.subquery()) total = (await session.execute(count_stmt)).scalar() - stmt = stmt.offset((req.page - 1) * req.page_size).limit(req.page_size) + stmt = stmt.offset( + (req.page - 1) * req.page_size).limit(req.page_size) if req.created_time_order == OrderType.DESC: stmt = stmt.order_by(DocumentEntity.created_time.desc()) else: @@ -331,6 +433,36 @@ class DocumentManager(): logging.exception("[DocumentManager] %s", err) raise e + @staticmethod + async def update_document_abstract_ts_vector_by_doc_ids(doc_ids: list[uuid.UUID]) -> None: + """根据文档ID批量更新文档摘要词向量""" + try: + kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id( + (await DocumentManager.get_document_by_doc_id(doc_ids[0])).kb_id) + if kb_entity.tokenizer == 
Tokenizer.ZH.value: + if config['DATABASE_TYPE'].lower() == 'opengauss': + tokenizer = 'chparser' + else: + tokenizer = 'zhparser' + elif kb_entity.tokenizer == Tokenizer.EN.value: + tokenizer = 'english' + else: + if config['DATABASE_TYPE'].lower() == 'opengauss': + tokenizer = 'chparser' + else: + tokenizer = 'zhparser' + async with await DataBase.get_session() as session: + stmt = update(DocumentEntity).where( + and_(DocumentEntity.id.in_(doc_ids), + DocumentEntity.status != DocumentStatus.DELETED.value) + ).values(abstract_ts_vector=func.to_tsvector(tokenizer, DocumentEntity.abstract)) + await session.execute(stmt) + await session.commit() + except Exception as e: + err = "批量更新文档摘要词向量失败" + logging.exception("[DocumentManager] %s", err) + raise e + @staticmethod async def update_doc_type_by_kb_id( kb_id: uuid.UUID, old_doc_type_ids: list[uuid.UUID], diff --git a/data_chain/manager/knowledge_manager.py b/data_chain/manager/knowledge_manager.py index ae51c6f6e77b9db76e1d5c79ac8d8febc33c5b52..4b0a5b7aa785c000d4b6a09c1fa1538602bf3a42 100644 --- a/data_chain/manager/knowledge_manager.py +++ b/data_chain/manager/knowledge_manager.py @@ -46,17 +46,22 @@ class KnowledgeBaseManager(): stmt = select(KnowledgeBaseEntity).where( KnowledgeBaseEntity.status != KnowledgeBaseStatus.DELETED.value) if req.team_id is not None: - stmt = stmt.where(KnowledgeBaseEntity.team_id == req.team_id) + stmt = stmt.where( + KnowledgeBaseEntity.team_id == req.team_id) if req.kb_id is not None: stmt = stmt.where(KnowledgeBaseEntity.id == req.kb_id) if req.kb_name is not None: - stmt = stmt.where(KnowledgeBaseEntity.name.like(f"%{req.kb_name}%")) + stmt = stmt.where( + KnowledgeBaseEntity.name.like(f"%{req.kb_name}%")) if req.author_name is not None: - stmt = stmt.where(KnowledgeBaseEntity.author_name.like(f"%{req.author_name}%")) + stmt = stmt.where( + KnowledgeBaseEntity.author_name.like(f"%{req.author_name}%")) count_stmt = select(func.count()).select_from(stmt.subquery()) total = (await session.execute(count_stmt)).scalar() - stmt = stmt.limit(req.page_size).offset((req.page - 1) * req.page_size) - stmt = stmt.order_by(KnowledgeBaseEntity.created_time.desc(), KnowledgeBaseEntity.id.desc()) + stmt = stmt.limit(req.page_size).offset( + (req.page - 1) * req.page_size) + stmt = stmt.order_by( + KnowledgeBaseEntity.created_time.desc(), KnowledgeBaseEntity.id.desc()) result = await session.execute(stmt) knowledge_base_entities = result.scalars().all() return (total, knowledge_base_entities) @@ -79,7 +84,10 @@ class KnowledgeBaseManager(): if kb_id: stmt = stmt.where(KnowledgeBaseEntity.id == kb_id) if kb_name: - stmt = stmt.where(KnowledgeBaseEntity.name.like(f"%{kb_name}%")) + stmt = stmt.where( + KnowledgeBaseEntity.name.like(f"%{kb_name}%")) + stmt = stmt.order_by( + KnowledgeBaseEntity.created_time.desc(), KnowledgeBaseEntity.id.desc()) result = await session.execute(stmt) knowledge_base_entities = result.scalars().all() return knowledge_base_entities @@ -93,7 +101,8 @@ class KnowledgeBaseManager(): """列出知识库文档类型""" try: async with await DataBase.get_session() as session: - stmt = select(DocumentTypeEntity).where(DocumentTypeEntity.kb_id == kb_id) + stmt = select(DocumentTypeEntity).where( + DocumentTypeEntity.kb_id == kb_id) result = await session.execute(stmt) document_type_entities = result.scalars().all() return document_type_entities @@ -107,10 +116,12 @@ class KnowledgeBaseManager(): """根据知识库ID更新知识库""" try: async with await DataBase.get_session() as session: - stmt = 
update(KnowledgeBaseEntity).where(KnowledgeBaseEntity.id == kb_id).values(**kb_dict) + stmt = update(KnowledgeBaseEntity).where( + KnowledgeBaseEntity.id == kb_id).values(**kb_dict) await session.execute(stmt) await session.commit() - stmt = select(KnowledgeBaseEntity).where(KnowledgeBaseEntity.id == kb_id) + stmt = select(KnowledgeBaseEntity).where( + KnowledgeBaseEntity.id == kb_id) result = await session.execute(stmt) knowledge_base_entity = result.scalars().first() return knowledge_base_entity diff --git a/data_chain/manager/role_manager.py b/data_chain/manager/role_manager.py index 4b97da7e3d6aa9297cddbbfe5951f44a930203cb..2ec89c50f095afbb92eb0960bb3579c41c4ab7fd 100644 --- a/data_chain/manager/role_manager.py +++ b/data_chain/manager/role_manager.py @@ -1,11 +1,18 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. -from sqlalchemy import select, delete, and_ +from sqlalchemy import select, delete, update, and_, func from typing import Dict import uuid from data_chain.logger.logger import logger as logging -from data_chain.entities.request_data import ListTeamRequest -from data_chain.entities.enum import TeamStatus +from data_chain.entities.request_data import ( + ListRoleRequest +) +from data_chain.entities.enum import ( + TeamStatus, + RoleActionStatus, + UserRoleStatus, + RoleStatus +) from data_chain.stores.database.database import DataBase, RoleEntity, ActionEntity, RoleActionEntity, UserRoleEntity @@ -67,6 +74,67 @@ class RoleManager: logging.exception("[RoleManager] %s", err) raise e + @staticmethod + async def get_role_by_id(role_id: uuid.UUID) -> RoleEntity: + """根据角色ID获取角色""" + try: + async with await DataBase.get_session() as session: + stmt = select(RoleEntity).where( + and_( + RoleEntity.id == role_id, + RoleEntity.status != RoleStatus.DELETED.value + ) + ) + result = await session.execute(stmt) + role_entity = result.scalars().first() + return role_entity + except Exception as e: + err = "根据角色ID获取角色失败" + logging.exception("[RoleManager] %s", err) + raise e + + @staticmethod + async def get_user_role_by_user_sub_and_team_id( + user_sub: str, team_id: uuid.UUID) -> UserRoleEntity: + """根据用户ID和团队ID获取用户角色""" + try: + async with await DataBase.get_session() as session: + stmt = select(UserRoleEntity).where( + and_( + UserRoleEntity.user_id == user_sub, + UserRoleEntity.team_id == team_id, + UserRoleEntity.status != UserRoleStatus.DELETED.value + ) + ) + result = await session.execute(stmt) + user_role_entity = result.scalars().first() + return user_role_entity + except Exception as e: + err = "根据用户ID和团队ID获取用户角色失败" + logging.exception("[RoleManager] %s", err) + raise e + + @staticmethod + async def get_role_by_role_name_and_team_id( + role_name: str, team_id: uuid.UUID) -> RoleEntity: + """根据角色名称和团队ID获取角色""" + try: + async with await DataBase.get_session() as session: + stmt = select(RoleEntity).where( + and_( + RoleEntity.name == role_name, + RoleEntity.team_id == team_id, + RoleEntity.status != RoleStatus.DELETED.value + ) + ) + result = await session.execute(stmt) + role_entity = result.scalars().first() + return role_entity + except Exception as e: + err = "根据角色名称和团队ID获取角色失败" + logging.exception("[RoleManager] %s", err) + raise e + @staticmethod async def get_action_by_team_id_user_sub_and_action( user_sub: str, team_id: uuid.UUID, action: str) -> ActionEntity: @@ -80,6 +148,8 @@ class RoleManager: UserRoleEntity.user_id == user_sub, UserRoleEntity.team_id == team_id, ActionEntity.action == action, + RoleActionEntity.status != 
RoleActionStatus.DELETED.value, + UserRoleEntity.status != UserRoleStatus.DELETED.value, ) ) result = await session.execute(stmt) @@ -89,3 +159,178 @@ class RoleManager: err = "根据团队ID、用户ID和操作获取操作失败" logging.exception("[RoleManager] %s", err) raise e + + @staticmethod + async def list_roles(req: ListRoleRequest) -> tuple[int, list[RoleEntity]]: + """根据团队ID获取角色列表""" + try: + async with await DataBase.get_session() as session: + stmt = select(RoleEntity).where( + and_( + RoleEntity.team_id == req.team_id, + RoleEntity.status != RoleStatus.DELETED.value + ) + ) + if req.role_id is not None: + stmt = stmt.where(RoleEntity.id == req.role_id) + if req.role_name is not None: + stmt = stmt.where( + RoleEntity.role_name.iike(f"%{req.role_name}%")) + count_stmt = select( + func.count()).select_from(stmt.subquery()) + result = await session.execute(count_stmt) + total = result.scalar() + stmt = stmt.order_by(RoleEntity.created_time.asc()) + stmt = stmt.offset((req.page - 1) * + req.page_size).limit(req.page_size) + result = await session.execute(stmt) + role_entities = result.scalars().all() + return total, role_entities + except Exception as e: + err = "根据团队ID获取角色列表失败" + logging.exception("[RoleManager] %s", err) + raise e + + @staticmethod + async def list_user_roles_by_team_id_and_user_subs( + team_id: uuid.UUID, user_subs: list[str]) -> list[UserRoleEntity]: + """根据团队ID和用户ID列表列出用户角色""" + try: + async with await DataBase.get_session() as session: + stmt = select(UserRoleEntity).where( + and_( + UserRoleEntity.team_id == team_id, + UserRoleEntity.user_id.in_(user_subs), + UserRoleEntity.status != UserRoleStatus.DELETED.value + ) + ) + result = await session.execute(stmt) + user_role_entities = result.scalars().all() + return user_role_entities + except Exception as e: + err = "根据团队ID和用户ID列表列出用户角色失败" + logging.exception("[RoleManager] %s", err) + raise e + + @staticmethod + async def list_roles_by_role_ids(role_ids: list[uuid.UUID]) -> list[RoleEntity]: + """根据角色ID列表列出角色""" + try: + async with await DataBase.get_session() as session: + stmt = select(RoleEntity).where( + and_( + RoleEntity.id.in_(role_ids), + RoleEntity.status != RoleStatus.DELETED.value + ) + ) + result = await session.execute(stmt) + role_entities = result.scalars().all() + return role_entities + except Exception as e: + err = "根据角色ID列表列出角色失败" + logging.exception("[RoleManager] %s", err) + raise e + + @staticmethod + async def list_role_actions_by_role_ids(role_ids: list[uuid.UUID]) -> list[RoleActionEntity]: + """根据角色ID列表列出角色操作""" + try: + async with await DataBase.get_session() as session: + stmt = select(RoleActionEntity).where( + and_( + RoleActionEntity.role_id.in_(role_ids), + RoleActionEntity.status != RoleActionStatus.DELETED.value + ) + ) + result = await session.execute(stmt) + role_action_entities = result.scalars().all() + return role_action_entities + except Exception as e: + err = "根据角色ID列表列出角色操作失败" + logging.exception("[RoleManager] %s", err) + raise e + + @staticmethod + async def update_role_by_id(role_id: uuid.UUID, role_dict: Dict[str, str]) -> bool: + """通过角色ID更新角色""" + try: + async with await DataBase.get_session() as session: + stmt = update(RoleEntity).where( + RoleEntity.id == role_id + ).values(**role_dict) + await session.execute(stmt) + await session.commit() + return True + except Exception as e: + err = f"通过角色ID更新角色失败 {e}" + logging.warning("[RoleManager] %s", err) + return False + + @staticmethod + async def update_role_actions_by_role_id( + role_id: uuid.UUID, role_action_dict: Dict[str, str]) -> bool: 
+ """通过角色ID更新角色操作""" + try: + async with await DataBase.get_session() as session: + stmt = update(RoleActionEntity).where( + RoleActionEntity.role_id == role_id + ).values(**role_action_dict) + await session.execute(stmt) + await session.commit() + return True + except Exception as e: + err = "通过角色ID更新角色操作失败" + logging.warning("[RoleManager] %s", err) + + @staticmethod + async def update_user_role_by_id(user_role_id: uuid.UUID, user_role_dict: Dict[str, str]) -> bool: + """通过用户角色ID更新用户角色""" + try: + async with await DataBase.get_session() as session: + stmt = update(UserRoleEntity).where( + UserRoleEntity.id == user_role_id + ).values(**user_role_dict) + await session.execute(stmt) + await session.commit() + return True + except Exception as e: + err = "通过用户角色ID更新用户角色失败" + logging.warning("[RoleManager] %s", err) + return False + + @staticmethod + async def update_user_role_by_role_id( + role_id: uuid.UUID, user_role_dict: Dict[str, str]) -> bool: + """通过角色ID更新用户角色""" + try: + async with await DataBase.get_session() as session: + stmt = update(UserRoleEntity).where( + UserRoleEntity.role_id == role_id + ).values(**user_role_dict) + await session.execute(stmt) + await session.commit() + return True + except Exception as e: + err = "通过角色ID更新用户角色失败" + logging.warning("[RoleManager] %s", err) + return False + + @staticmethod + async def update_user_roles_by_team_id_and_user_subs( + team_id: uuid.UUID, user_subs: list[str], user_role_dict: Dict[str, str]) -> bool: + """通过团队ID和用户ID列表更新用户角色""" + try: + async with await DataBase.get_session() as session: + stmt = update(UserRoleEntity).where( + and_( + UserRoleEntity.team_id == team_id, + UserRoleEntity.user_id.in_(user_subs) + ) + ).values(**user_role_dict) + await session.execute(stmt) + await session.commit() + return True + except Exception as e: + err = "通过团队ID和用户ID列表更新用户角色失败" + logging.warning("[RoleManager] %s", err) + return False diff --git a/data_chain/manager/task_queue_mamanger.py b/data_chain/manager/task_queue_mamanger.py index 8f4db40dfc5fa1f28631d4a10711e219168790b6..b0df886ba3e7a15691f3b533d6d3993903c9a638 100644 --- a/data_chain/manager/task_queue_mamanger.py +++ b/data_chain/manager/task_queue_mamanger.py @@ -1,12 +1,10 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
from sqlalchemy import select, delete, update, desc, asc, func, exists, or_, and_ -from sqlalchemy.orm import aliased import uuid from typing import Dict, List, Optional, Tuple from data_chain.logger.logger import logger as logging -from data_chain.stores.database.database import DataBase, TaskEntity -from data_chain.stores.mongodb.mongodb import MongoDB, Task +from data_chain.stores.database.database import DataBase, TaskQueueEntity from data_chain.entities.enum import TaskStatus @@ -14,60 +12,72 @@ class TaskQueueManager(): """任务队列管理类""" @staticmethod - async def add_task(task: Task): + async def add_task(task: TaskQueueEntity): try: - async with MongoDB.get_session() as session, await session.start_transaction(): - task_colletion = MongoDB.get_collection('witchiand_task') - await task_colletion.insert_one(task.model_dump(by_alias=True), session=session) + async with await DataBase.get_session() as session: + session.add(task) + await session.commit() except Exception as e: err = "添加任务到队列失败" logging.exception("[TaskQueueManager] %s", err) + raise e @staticmethod async def delete_task_by_id(task_id: uuid.UUID): """根据任务ID删除任务""" try: - async with MongoDB.get_session() as session, await session.start_transaction(): - task_colletion = MongoDB.get_collection('witchiand_task') - await task_colletion.delete_one({"_id": task_id}, session=session) + async with await DataBase.get_session() as session: + stmt = delete(TaskQueueEntity).where(TaskQueueEntity.id == task_id) + await session.execute(stmt) + await session.commit() except Exception as e: err = "删除任务失败" logging.exception("[TaskQueueManager] %s", err) raise e @staticmethod - async def get_oldest_tasks_by_status(status: TaskStatus) -> Task: + async def get_oldest_tasks_by_status(status: TaskStatus) -> Optional[TaskQueueEntity]: """根据任务状态获取最早的任务""" try: - async with MongoDB.get_session() as session: - task_colletion = MongoDB.get_collection('witchiand_task') - task = await task_colletion.find_one({"status": status}, sort=[("created_time", 1)], session=session) - return Task(**task) if task else None + async with await DataBase.get_session() as session: + stmt = ( + select(TaskQueueEntity) + .where(TaskQueueEntity.status == status.value) + .order_by(asc(TaskQueueEntity.created_time)) + .limit(1) + ) + result = await session.execute(stmt) + return result.scalars().first() except Exception as e: err = "获取最早的任务失败" logging.exception("[TaskQueueManager] %s", err) raise e @staticmethod - async def get_task_by_id(task_id: uuid.UUID) -> Task: + async def get_task_by_id(task_id: uuid.UUID) -> Optional[TaskQueueEntity]: """根据任务ID获取任务""" try: - async with MongoDB.get_session() as session: - task_colletion = MongoDB.get_collection('witchiand_task') - task = await task_colletion.find_one({"_id": task_id}, session=session) - return Task(**task) if task else None + async with await DataBase.get_session() as session: + stmt = select(TaskQueueEntity).where(TaskQueueEntity.id == task_id) + result = await session.execute(stmt) + return result.scalars().first() except Exception as e: err = "获取任务失败" logging.exception("[TaskQueueManager] %s", err) raise e @staticmethod - async def update_task_by_id(task_id: uuid.UUID, task: Task): + async def update_task_by_id(task_id: uuid.UUID, task: TaskQueueEntity): """根据任务ID更新任务""" try: - async with MongoDB.get_session() as session, await session.start_transaction(): - task_colletion = MongoDB.get_collection('witchiand_task') - await task_colletion.update_one({"_id": task_id}, {"$set": task.model_dump(by_alias=True)}, 
session=session) + async with await DataBase.get_session() as session: + stmt = ( + update(TaskQueueEntity) + .where(TaskQueueEntity.id == task_id) + .values(status=task.status) + ) + await session.execute(stmt) + await session.commit() except Exception as e: err = "更新任务失败" logging.exception("[TaskQueueManager] %s", err) diff --git a/data_chain/manager/team_manager.py b/data_chain/manager/team_manager.py index f90a860eb3ed933c03732827a064aee84701f5f3..227d8ad6f7c98b26077ed1ddb85988bfd040c2d8 100644 --- a/data_chain/manager/team_manager.py +++ b/data_chain/manager/team_manager.py @@ -4,9 +4,16 @@ from typing import Dict import uuid from data_chain.logger.logger import logger as logging -from data_chain.entities.request_data import ListTeamRequest -from data_chain.entities.enum import TeamStatus -from data_chain.stores.database.database import DataBase, TeamEntity, TeamUserEntity +from data_chain.entities.request_data import ( + ListTeamUserRequest, + ListTeamRequest +) +from data_chain.entities.enum import ( + TeamStatus, + UserStatus, + TeamUserStaus +) +from data_chain.stores.database.database import DataBase, TeamEntity, UserEntity, TeamUserEntity class TeamManager: @@ -50,10 +57,12 @@ class TeamManager: if req.team_id: stmt = stmt.where(TeamEntity.id == req.team_id) if req.team_name: - stmt = stmt.where(TeamEntity.name.ilike(f"%{req.team_name}%")) + stmt = stmt.where( + TeamEntity.name.ilike(f"%{req.team_name}%")) count_stmt = select(func.count()).select_from(stmt.subquery()) total = (await session.execute(count_stmt)).scalar() - stmt = stmt.limit(req.page_size).offset((req.page - 1) * req.page_size) + stmt = stmt.limit(req.page_size).offset( + (req.page - 1) * req.page_size) stmt = stmt.order_by(TeamEntity.created_time.desc()) result = await session.execute(stmt) team_entities = result.scalars().all() @@ -73,10 +82,12 @@ class TeamManager: if req.team_id: stmt = stmt.where(TeamEntity.id == req.team_id) if req.team_name: - stmt = stmt.where(TeamEntity.name.ilike(f"%{req.team_name}%")) + stmt = stmt.where( + TeamEntity.name.ilike(f"%{req.team_name}%")) count_stmt = select(func.count()).select_from(stmt.subquery()) total = (await session.execute(count_stmt)).scalar() - stmt = stmt.limit(req.page_size).offset((req.page - 1) * req.page_size) + stmt = stmt.limit(req.page_size).offset( + (req.page - 1) * req.page_size) stmt = stmt.order_by(TeamEntity.created_time.desc()) result = await session.execute(stmt) team_entities = result.scalars().all() @@ -116,10 +127,12 @@ class TeamManager: if req.team_id: stmt = stmt.where(TeamEntity.id == req.team_id) if req.team_name: - stmt = stmt.where(TeamEntity.name.ilike(f"%{req.team_name}%")) + stmt = stmt.where( + TeamEntity.name.ilike(f"%{req.team_name}%")) count_stmt = select(func.count()).select_from(stmt.subquery()) total = (await session.execute(count_stmt)).scalar() - stmt = stmt.limit(req.page_size).offset((req.page - 1) * req.page_size) + stmt = stmt.limit(req.page_size).offset( + (req.page - 1) * req.page_size) stmt = stmt.order_by(TeamEntity.created_time.desc()) result = await session.execute(stmt) team_entities = result.scalars().all() @@ -129,6 +142,86 @@ class TeamManager: logging.exception("[TeamManager] %s", err) raise e + @staticmethod + async def list_team_user_by_team_id(req: ListTeamUserRequest) -> tuple[int, list[UserEntity]]: + """列出团队成员""" + try: + async with await DataBase.get_session() as session: + stmt = select(UserEntity).join(TeamUserEntity, UserEntity.id == TeamUserEntity.user_id).where( + and_(TeamUserEntity.team_id == 
req.team_id, UserEntity.status != UserStatus.DELETED.value, TeamUserEntity.status != TeamUserStaus.DELETED.value)) + if req.user_sub: + stmt = stmt.where(UserEntity.id.ilike(f"%{req.user_sub}%")) + if req.user_name: + stmt = stmt.where( + UserEntity.name.ilike(f"%{req.user_name}%")) + count_stmt = select(func.count()).select_from(stmt.subquery()) + total = (await session.execute(count_stmt)).scalar() + stmt = stmt.limit(req.page_size).offset( + (req.page - 1) * req.page_size) + stmt = stmt.order_by(UserEntity.created_time.desc()) + result = await session.execute(stmt) + team_user_entities = result.scalars().all() + return (total, team_user_entities) + except Exception as e: + err = "列出团队成员失败" + logging.exception("[TeamManager] %s", err) + raise e + + @staticmethod + async def list_team_user_by_team_id_and_user_subs(team_id: uuid.UUID, user_subs: list[str]) -> list[TeamUserEntity]: + """列出团队成员通过用户ID列表""" + try: + async with await DataBase.get_session() as session: + stmt = select(TeamUserEntity).where(and_( + TeamUserEntity.team_id == team_id, + TeamUserEntity.user_id.in_(user_subs), + TeamUserEntity.status != TeamUserStaus.DELETED.value + )) + result = await session.execute(stmt) + team_user_entities = result.scalars().all() + return team_user_entities + except Exception as e: + err = "列出团队成员通过用户ID列表失败" + logging.exception("[TeamManager] %s", err) + raise e + + @staticmethod + async def get_team_user_by_user_sub_and_team_id(user_sub: str, team_id: uuid.UUID) -> TeamUserEntity: + """根据用户ID和团队ID获取团队成员""" + try: + async with await DataBase.get_session() as session: + stmt = select(TeamUserEntity).where(and_( + TeamUserEntity.user_id == user_sub, + TeamUserEntity.team_id == team_id, + TeamUserEntity.status != TeamUserStaus.DELETED.value + )) + result = await session.execute(stmt) + team_user_entity = result.scalars().first() + return team_user_entity + except Exception as e: + err = "根据用户ID和团队ID获取团队成员失败" + logging.exception("[TeamManager] %s", err) + raise e + + @staticmethod + async def get_team_by_id(team_id: uuid.UUID) -> TeamEntity: + """根据团队ID获取团队""" + try: + async with await DataBase.get_session() as session: + stmt = select(TeamEntity).where( + and_( + TeamEntity.id == team_id, + TeamEntity.status != TeamStatus.DELETED.value + ) + ) + result = await session.execute(stmt) + team_entity = result.scalars().first() + return team_entity + except Exception as e: + err = "根据团队ID获取团队失败" + logging.exception("[TeamManager] %s", err) + raise e + @staticmethod async def delete_team_by_id(team_id: uuid.UUID) -> uuid.UUID: """删除团队""" @@ -148,7 +241,8 @@ class TeamManager: """删除团队""" try: async with await DataBase.get_session() as session: - stmt = delete(TeamEntity).where(TeamEntity.status == TeamStatus.DELETED.value) + stmt = delete(TeamEntity).where( + TeamEntity.status == TeamStatus.DELETED.value) await session.execute(stmt) await session.commit() except Exception as e: @@ -161,7 +255,8 @@ class TeamManager: """更新团队""" try: async with await DataBase.get_session() as session: - stmt = update(TeamEntity).where(TeamEntity.id == team_id).values(**team_dict) + stmt = update(TeamEntity).where( + TeamEntity.id == team_id).values(**team_dict) await session.execute(stmt) await session.commit() stmt = select(TeamEntity).where(TeamEntity.id == team_id) @@ -172,3 +267,23 @@ class TeamManager: err = "更新团队失败" logging.exception("[TeamManager] %s", err) raise e + + @staticmethod + async def update_team_users_by_team_id_and_user_subs( + team_id: uuid.UUID, user_subs: list[str], user_role_dict: Dict[str, str]) -> 
bool: + """通过团队ID和用户ID列表更新团队成员""" + try: + async with await DataBase.get_session() as session: + stmt = update(TeamUserEntity).where( + and_( + TeamUserEntity.team_id == team_id, + TeamUserEntity.user_id.in_(user_subs) + ) + ).values(**user_role_dict) + await session.execute(stmt) + await session.commit() + return True + except Exception as e: + err = "通过团队ID和用户ID列表更新团队成员失败" + logging.exception("[TeamManager] %s", err) + return False diff --git a/data_chain/manager/team_message_manager.py b/data_chain/manager/team_message_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..c8db33c0db0e2d310a247c72958046622da5f4b8 --- /dev/null +++ b/data_chain/manager/team_message_manager.py @@ -0,0 +1,50 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +from sqlalchemy import select, update, delete, and_, func +from data_chain.logger.logger import logger as logging +from data_chain.entities.request_data import ( + ListTeamMsgRequest +) +from data_chain.entities.enum import ( + TeamMessageStatus +) +from data_chain.stores.database.database import DataBase, TeamMessageEntity + + +class TeamMessageManager: + """团队消息管理器""" + @staticmethod + async def add_team_msg(team_msg_entity: TeamMessageEntity) -> TeamMessageEntity: + """添加团队消息""" + try: + async with await DataBase.get_session() as session: + session.add(team_msg_entity) + await session.commit() + await session.refresh(team_msg_entity) + except Exception as e: + err = "添加团队消息失败" + logging.exception("[TeamMessageManager] %s", err) + raise e + return team_msg_entity + + @staticmethod + async def list_team_msg_by_team_id(req: ListTeamMsgRequest) -> tuple[int, list[TeamMessageEntity]]: + """列出团队消息""" + try: + async with await DataBase.get_session() as session: + stmt = select(TeamMessageEntity).where( + and_(TeamMessageEntity.team_id == req.team_id, + TeamMessageEntity.status != TeamMessageStatus.DELETED.value) + ).order_by(TeamMessageEntity.created_time.desc()) + total_stmt = select(func.count()).select_from(stmt.subquery()) + total_result = await session.execute(total_stmt) + total = total_result.scalar_one() + stmt = stmt.order_by(TeamMessageEntity.created_time.desc()) + stmt = stmt.offset((req.page - 1) * + req.page_size).limit(req.page_size) + result = await session.execute(stmt) + team_msg_entities = result.scalars().all() + return total, team_msg_entities + except Exception as e: + err = "列出团队消息失败" + logging.exception("[TeamMessageManager] %s", err) + raise e diff --git a/data_chain/manager/testcase_manager.py b/data_chain/manager/testcase_manager.py index d3f2770d2d9db047f1e67080994825a1a3009650..08873b7bf2e1812421755d6d253a789e99f92ace 100644 --- a/data_chain/manager/testcase_manager.py +++ b/data_chain/manager/testcase_manager.py @@ -40,6 +40,24 @@ class TestCaseManager(): err = "批量添加测试用例失败" logging.exception("[TestCaseManager] %s", err) + @staticmethod + async def get_test_case_by_id(test_case_id: uuid.UUID) -> Optional[TestCaseEntity]: + """根据测试用例ID获取测试用例""" + try: + async with await DataBase.get_session() as session: + stmt = ( + select(TestCaseEntity) + .where(TestCaseEntity.id == test_case_id, + TestCaseEntity.status != TestCaseStatus.DELETED.value) + ) + result = await session.execute(stmt) + test_case_entity = result.scalars().first() + return test_case_entity + except Exception as e: + err = "根据测试用例ID获取测试用例失败" + logging.exception("[TestCaseManager] %s", err) + raise e + @staticmethod async def list_test_case(req: ListTestCaseRequest) -> Tuple[int, List[TestCaseEntity]]: 
"""根据测试ID查询测试用例""" @@ -54,7 +72,8 @@ class TestCaseManager(): total = (await session.execute(count_stmt)).scalar() stmt = stmt.order_by(TestCaseEntity.created_at.desc()) stmt = stmt.order_by(TestCaseEntity.id.asc()) - stmt = stmt.offset((req.page - 1) * req.page_size).limit(req.page_size) + stmt = stmt.offset( + (req.page - 1) * req.page_size).limit(req.page_size) result = await session.execute(stmt) testcase_entities = result.scalars().all() return (total, testcase_entities) diff --git a/data_chain/manager/touch b/data_chain/manager/touch deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/data_chain/manager/user_manager.py b/data_chain/manager/user_manager.py index 43391fe5b8a3538d139d340a1d53efe290d44ce4..2f933ee76d6560cc0c7777467aa35ea3c75478ea 100644 --- a/data_chain/manager/user_manager.py +++ b/data_chain/manager/user_manager.py @@ -1,7 +1,8 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. -from sqlalchemy import select, delete +from sqlalchemy import select, delete, func from data_chain.logger.logger import logger as logging +from data_chain.entities.request_data import ListUserRequest from data_chain.entities.enum import UserStatus from data_chain.stores.database.database import DataBase, UserEntity @@ -18,5 +19,30 @@ class UserManager: return True except Exception as e: err = "用户添加失败" - logging.error("[UserManger] %s", err) + logging.warning("[UserManger] %s", err) return False + + @staticmethod + async def list_user(req: ListUserRequest) -> tuple[int, list[UserEntity]]: + try: + async with await DataBase.get_session() as session: + stmt = select(UserEntity).where( + UserEntity.status == UserStatus.ACTIVE) + if req.user_sub: + stmt = stmt.where(UserEntity.id.ilike(f"%{req.user_sub}%")) + if req.user_name: + stmt = stmt.where( + UserEntity.name.ilike(f"%{req.user_name}%")) + count_stmt = select( + func.count()).select_from(stmt.subquery()) + total = (await session.execute(count_stmt)).scalar() + stmt = stmt.offset((req.page - 1) * req.page_size).limit( + req.page_size) + stmt=stmt.order_by(UserEntity.created_time.desc()) + result = (await session.execute(stmt)).scalars().all() + return total, result + except Exception as e: + err = "用户列表获取失败" + logging.warning("[UserManger] %s", err) + raise e + return [] diff --git a/data_chain/manager/user_message_manager.py b/data_chain/manager/user_message_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..e5f7ce9544c77102944dfbde7a6a42f0da61bf07 --- /dev/null +++ b/data_chain/manager/user_message_manager.py @@ -0,0 +1,151 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
+from sqlalchemy import select, update, delete, and_, or_, func +from typing import Dict +import uuid + +from data_chain.logger.logger import logger as logging +from data_chain.entities.request_data import ( + ListUserMessageRequest +) +from data_chain.entities.enum import ( + TeamStatus, + UserRoleStatus, + RoleStatus, + UserMessageStatus +) +from data_chain.manager.team_manager import TeamManager +from data_chain.stores.database.database import ( + DataBase, + TeamEntity, + UserRoleEntity, + RoleActionEntity, + UserMessageEntity +) + + +class UserMessageManager: + """用户消息管理器""" + + @staticmethod + async def add_user_message(user_message_entity: UserMessageEntity) -> bool: + try: + async with await DataBase.get_session() as session: + session.add(user_message_entity) + await session.commit() + await session.refresh(user_message_entity) + return True + except Exception as e: + err = "用户消息添加失败" + logging.warning("[UserMessageManager] %s", err) + return False + + @staticmethod + async def list_user_messages( + user_sub: str, req: ListUserMessageRequest) -> tuple[int, list[UserMessageEntity]]: + """根据用户标识和消息类型列出用户消息""" + try: + async with await DataBase.get_session() as session: + # 查询用户加入或者创建的团队,加入的团队要求用户的角色有 POST /usr_msg/list 权限 + stmt = select(TeamEntity.id).where( + TeamEntity.author_id == user_sub, + TeamEntity.status == TeamStatus.EXISTED.value + ) + team_ids_created = [team_id for team_id, in await session.execute(stmt)] + stmt = select(TeamEntity.id).join( + UserRoleEntity, TeamEntity.id == UserRoleEntity.team_id + ).join( + RoleActionEntity, UserRoleEntity.role_id == RoleActionEntity.role_id + ).where( + and_( + TeamEntity.author_id != user_sub, + UserRoleEntity.user_id == user_sub, + TeamEntity.status == TeamStatus.EXISTED.value, + UserRoleEntity.status == UserRoleStatus.EXISTED.value, + RoleActionEntity.action == 'POST /usr_msg/list' + ) + ).distinct() + team_ids_joined = [team_id for team_id, in await session.execute(stmt)] + team_ids = team_ids_created + team_ids_joined + + # 创建基础查询条件 + base_conditions = or_( + and_( + UserMessageEntity.sender_id == user_sub, + UserMessageEntity.status_to_sender != UserMessageStatus.DELETED.value + ), + and_( + UserMessageEntity.receiver_id == user_sub, + UserMessageEntity.status_to_receiver != UserMessageStatus.DELETED.value + ), + and_( + UserMessageEntity.sender_id != user_sub, + UserMessageEntity.receiver_id != user_sub, + UserMessageEntity.team_id.in_(team_ids), + UserMessageEntity.is_to_all == True, + UserMessageEntity.status_to_receiver != UserMessageStatus.DELETED.value + ) + ) + + # 单独构建计数查询 + count_stmt = select(func.count()).select_from( + UserMessageEntity).where(base_conditions) + if req.msg_type: + count_stmt = count_stmt.where( + UserMessageEntity.type == req.msg_type.value) + total = (await session.execute(count_stmt)).scalar() + + # 构建数据查询 + data_stmt = select(UserMessageEntity).where(base_conditions) + if req.msg_type: + data_stmt = data_stmt.where( + UserMessageEntity.type == req.msg_type.value) + + # 添加排序和分页 + data_stmt = data_stmt.order_by(UserMessageEntity.created_time.desc())\ + .offset((req.page - 1) * req.page_size)\ + .limit(req.page_size) + + result = await session.execute(data_stmt) + user_message_entities = result.scalars().all() + + return total, user_message_entities + except Exception as e: + err = f"根据用户标识和消息类型列出用户消息失败 {e}" + logging.warning("[UserMessageManager] %s", err) + return 0, [] + + @staticmethod + async def get_user_message_by_msg_id(msg_id: uuid.UUID) -> UserMessageEntity: + """通过消息ID获取用户消息""" + try: + 
async with await DataBase.get_session() as session: + stmt = select(UserMessageEntity).where( + UserMessageEntity.id == msg_id, + or_( + UserMessageEntity.status_to_sender != UserMessageStatus.DELETED.value, + UserMessageEntity.status_to_receiver != UserMessageStatus.DELETED.value + ) + ) + result = await session.execute(stmt) + user_message_entity = result.scalars().first() + return user_message_entity + except Exception as e: + err = f"通过消息ID获取用户消息失败 {e}" + logging.warning("[UserMessageManager] %s", err) + return None + + @staticmethod + async def update_user_message_by_msg_id(msg_id: uuid.UUID, msg_dict: Dict[str, str]) -> bool: + """通过消息ID更新用户消息""" + try: + async with await DataBase.get_session() as session: + stmt = update(UserMessageEntity).where( + UserMessageEntity.id == msg_id + ).values(**msg_dict) + await session.execute(stmt) + await session.commit() + return True + except Exception as e: + err = "通过消息ID更新用户消息失败" + logging.warning("[UserMessageManager] %s", err) + return False diff --git a/data_chain/parser/handler/deep_pdf_parser.py b/data_chain/parser/handler/deep_pdf_parser.py index bf4db0867147f29cc1145acc4f0aab66db1ef370..9eee88938ee812569132856498e7a78a16a0243e 100644 --- a/data_chain/parser/handler/deep_pdf_parser.py +++ b/data_chain/parser/handler/deep_pdf_parser.py @@ -645,3 +645,5 @@ class DeepPdfParser(BaseParser): # 处理图片节点 continue return parse_result +import asyncio +result=asyncio.run(DeepPdfParser.parser("./test.pdf")) \ No newline at end of file diff --git a/data_chain/parser/handler/html_parser.py b/data_chain/parser/handler/html_parser.py index 9a98c196a6dffcd5d2ce83f0f4cfb825e1f6f3c0..5520e133130d140630af5a395294eb35203c318b 100644 --- a/data_chain/parser/handler/html_parser.py +++ b/data_chain/parser/handler/html_parser.py @@ -20,7 +20,8 @@ class HTMLParser(BaseParser): table_data = [] for row in rows: cells = row.find_all(['th', 'td']) - row_data = [cell.get_text(strip=True, separator=' ') for cell in cells] + row_data = [cell.get_text(strip=True, separator=' ') + for cell in cells] if row_data: table_data.append(row_data) return table_data @@ -67,7 +68,8 @@ class HTMLParser(BaseParser): continue if element.name == 'div' or element.name == 'head' or element.name == 'header' or \ element.name == 'body' or element.name == 'section' or element.name == 'article' or \ - element.name == 'nav' or element.name == 'main' or element.name == 'p' or element.name == 'ol': + element.name == 'nav' or element.name == 'main' or element.name == 'p' or element.name == 'ol'\ + or element.name == 'hr' or element.name == 'ul': # 处理div内部元素 inner_html = ''.join(str(child) for child in element.children) child_subtree = await HTMLParser.build_subtree(inner_html, current_level+1) @@ -116,9 +118,26 @@ class HTMLParser(BaseParser): content_html = ''.join(str(el) for el in content_elements) child_subtree = await HTMLParser.build_subtree(content_html, level) parse_topology_type = ChunkParseTopology.TREENORMAL + node = ParseNode( + id=uuid.uuid4(), + title=title, + lv=level, + parse_topology_type=parse_topology_type, + content="", + type=ChunkType.TEXT, + link_nodes=child_subtree + ) else: - child_subtree = [] parse_topology_type = ChunkParseTopology.TREELEAF + text = title + node = ParseNode( + id=uuid.uuid4(), + lv=current_level, + parse_topology_type=parse_topology_type, + content=text, + type=ChunkType.TEXT, + link_nodes=[] + ) node = ParseNode( id=uuid.uuid4(), diff --git a/data_chain/parser/handler/md_parser.py b/data_chain/parser/handler/md_parser.py index 
97ff373160e3f7f5a6ae58813075476e3e7060c3..c835abc72aeaa56f23e4c7b1d062ad516d8409f7 100644 --- a/data_chain/parser/handler/md_parser.py +++ b/data_chain/parser/handler/md_parser.py @@ -29,7 +29,8 @@ class MdParser(BaseParser): cells = row.find_all(['th', 'td']) # 提取单元格中的文本,并去除多余的空白字符 - row_data = [cell.get_text(strip=True, separator=' ') for cell in cells] + row_data = [cell.get_text(strip=True, separator=' ') + for cell in cells] if row_data: # 如果该行有数据 table_data.append(row_data) @@ -79,7 +80,7 @@ class MdParser(BaseParser): logging.error(f"[MdParser] 处理非标签节点失败: {e}") continue - if element.name == 'p' or element.name == 'ol' or element.name == 'hr': + if element.name == 'p' or element.name == 'ol' or element.name == 'hr' or element.name == 'ul' or element.name == 'div': inner_html = ''.join(str(child) for child in element.children) child_subtree = await MdParser.build_subtree(inner_html, current_level+1) parse_topology_type = ChunkParseTopology.TREENORMAL if len( @@ -99,7 +100,7 @@ class MdParser(BaseParser): node = ParseNode( id=uuid.uuid4(), lv=current_level, - parse_topology_type=ChunkParseTopology.TREELEAF, + parse_topology_type=parse_topology_type, content=text, type=ChunkType.TEXT, link_nodes=[] @@ -110,7 +111,7 @@ class MdParser(BaseParser): level = int(element.name[1:]) except Exception: level = current_level - title = element.get_text() + title = element.get_text().strip() content_elements = [] while current_level_elements: @@ -127,19 +128,26 @@ class MdParser(BaseParser): content_html = ''.join(str(el) for el in content_elements) child_subtree = await MdParser.build_subtree(content_html, level) parse_topology_type = ChunkParseTopology.TREENORMAL + node = ParseNode( + id=uuid.uuid4(), + title=title, + lv=level, + parse_topology_type=parse_topology_type, + content="", + type=ChunkType.TEXT, + link_nodes=child_subtree + ) else: - child_subtree = [] parse_topology_type = ChunkParseTopology.TREELEAF - - node = ParseNode( - id=uuid.uuid4(), - title=title, - lv=level, - parse_topology_type=parse_topology_type, - content="", - type=ChunkType.TEXT, - link_nodes=child_subtree - ) + text = title + node = ParseNode( + id=uuid.uuid4(), + lv=current_level, + parse_topology_type=parse_topology_type, + content=text, + type=ChunkType.TEXT, + link_nodes=[] + ) subtree.append(node) elif element.name == 'code': code_text = element.get_text().strip() diff --git a/data_chain/parser/handler/md_zip_parser.py b/data_chain/parser/handler/md_zip_parser.py index c80dc860ac98f89173766ed97d6374ba831f0382..0b01a1fd225883841220844cb6b7f3d862fa78e7 100644 --- a/data_chain/parser/handler/md_zip_parser.py +++ b/data_chain/parser/handler/md_zip_parser.py @@ -30,7 +30,8 @@ class MdZipParser(BaseParser): cells = row.find_all(['th', 'td']) # 提取单元格中的文本,并去除多余的空白字符 - row_data = [cell.get_text(strip=True, separator=' ') for cell in cells] + row_data = [cell.get_text(strip=True, separator=' ') + for cell in cells] if row_data: # 如果该行有数据 table_data.append(row_data) @@ -100,7 +101,7 @@ class MdZipParser(BaseParser): except Exception as e: logging.error(f"[MdZipParser] 处理非标签节点失败: {e}") continue - if element.name == 'p' or element.name == 'ol' or element.name == 'hr': + if element.name == 'p' or element.name == 'ol' or element.name == 'hr' or element.name == 'ul' or element.name == 'div': inner_html = ''.join(str(child) for child in element.children) child_subtree = await MdZipParser.build_subtree(file_path, inner_html, current_level+1) parse_topology_type = ChunkParseTopology.TREENORMAL if len( @@ -146,21 +147,28 @@ class 
MdZipParser(BaseParser): # 如果有内容,处理这些内容 if content_elements: content_html = ''.join(str(el) for el in content_elements) - child_subtree = await MdZipParser.build_subtree(file_path, content_html, level) + child_subtree = await MdZipParser.build_subtree(content_html, level) parse_topology_type = ChunkParseTopology.TREENORMAL + node = ParseNode( + id=uuid.uuid4(), + title=title, + lv=level, + parse_topology_type=parse_topology_type, + content="", + type=ChunkType.TEXT, + link_nodes=child_subtree + ) else: - child_subtree = [] parse_topology_type = ChunkParseTopology.TREELEAF - - node = ParseNode( - id=uuid.uuid4(), - title=title, - lv=level, - parse_topology_type=parse_topology_type, - content="", - type=ChunkType.TEXT, - link_nodes=child_subtree - ) + text = title + node = ParseNode( + id=uuid.uuid4(), + lv=current_level, + parse_topology_type=parse_topology_type, + content=text, + type=ChunkType.TEXT, + link_nodes=[] + ) subtree.append(node) elif element.name == 'code': code_text = element.get_text().strip() diff --git a/data_chain/parser/parse_result.py b/data_chain/parser/parse_result.py index b69e3f451ae5c3335e2a2c4578ae5d4535f023b1..1f4db212187348033615929bb39db4f425344ef8 100644 --- a/data_chain/parser/parse_result.py +++ b/data_chain/parser/parse_result.py @@ -24,5 +24,6 @@ class ParseNode(BaseModel): class ParseResult(BaseModel): """解析结果""" + doc_hash: str = Field(default='', description="文档hash值") parse_topology_type: DocParseRelutTopology = Field(..., description="解析拓扑类型") nodes: list[ParseNode] = Field(..., description="节点列表") diff --git a/data_chain/parser/tools/ocr_tool.py b/data_chain/parser/tools/ocr_tool.py index 858517dab74091ccc2f6d9badcf86049021e17cb..d2db21503883bee61e5455259e111fb42df88750 100644 --- a/data_chain/parser/tools/ocr_tool.py +++ b/data_chain/parser/tools/ocr_tool.py @@ -2,11 +2,13 @@ from PIL import Image, ImageEnhance import yaml import cv2 import numpy as np +import requests from data_chain.parser.tools.token_tool import TokenTool from data_chain.logger.logger import logger as logging from data_chain.config.config import config from data_chain.llm.llm import LLM from data_chain.parser.tools.instruct_scan_tool import InstructScanTool +from data_chain.config.config import config class OcrTool: @@ -14,7 +16,7 @@ class OcrTool: rec_model_dir = 'data_chain/parser/model/ocr/ch_PP-OCRv4_rec_infer' cls_model_dir = 'data_chain/parser/model/ocr/ch_ppocr_mobile_v2.0_cls_infer' # 优化 OCR 参数配置 - if InstructScanTool.check_avx512_support(): + if InstructScanTool.check_avx512_support() and config['OCR_METHOD'] == "offline": from paddleocr import PaddleOCR model = PaddleOCR( det_model_dir=det_model_dir, @@ -30,6 +32,10 @@ class OcrTool: async def ocr_from_image_path(image_path: str) -> list: try: # 打开图片 + if config['OCR_METHOD'] == 'online' and config['OCR_API_URL']: + result = requests.get(config['OCR_API_URL'], files={'file': ( + image_path, open(image_path, 'rb'), 'image/jpeg')}).json() + return result.get("result", []) if OcrTool.model is None: err = "[OCRTool] 当前机器不支持 AVX-512,无法进行OCR识别" logging.error(err) @@ -58,6 +64,8 @@ class OcrTool: async def merge_text_from_ocr_result(ocr_result: list) -> str: text = '' try: + if ocr_result[0] is None or len(ocr_result[0]) == 0: + return "" for _ in ocr_result[0]: text += str(_[1][0]) return text @@ -67,16 +75,20 @@ class OcrTool: return '' @staticmethod - async def enhance_ocr_result(ocr_result, image_related_text='', llm: LLM = None) -> str: + async def enhance_ocr_result(ocr_result, image_related_text='', llm: LLM = None, 
language: str = "中文") -> str: try: with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) - prompt_template = prompt_dict.get('OCR_ENHANCED_PROMPT', '') + prompt_template = prompt_dict.get('OCR_ENHANCED_PROMPT', {}) + prompt_template = prompt_template.get(language, '') pre_part_description = "" token_limit = llm.max_tokens//2 image_related_text = TokenTool.get_k_tokens_words_from_content(image_related_text, token_limit) ocr_result_parts = TokenTool.split_str_with_slide_window(str(ocr_result), token_limit) - user_call = '请详细输出图片的摘要,不要输出其他内容' + if language == 'en': + user_call = 'Please provide a English summary of the image content, do not output anything else' + else: + user_call = '请详细输出图片的中文摘要,不要输出其他内容' for part in ocr_result_parts: pre_part_description_cp = pre_part_description try: @@ -96,19 +108,16 @@ class OcrTool: return OcrTool.merge_text_from_ocr_result(ocr_result) @staticmethod - async def image_to_text(image: np.ndarray, image_related_text: str = '', llm: LLM = None) -> str: + async def image_to_text( + image_file_path: str, image_related_text: str = '', llm: LLM = None, language: str = '中文') -> str: try: - if OcrTool.model is None: - err = "[OCRTool] 当前机器不支持 AVX-512,无法进行OCR识别" - logging.error(err) - return '' - ocr_result = await OcrTool.ocr_from_image(image) + ocr_result = await OcrTool.ocr_from_image_path(image_file_path) if ocr_result is None: return '' if llm is None: text = await OcrTool.merge_text_from_ocr_result(ocr_result) else: - text = await OcrTool.enhance_ocr_result(ocr_result, image_related_text, llm) + text = await OcrTool.enhance_ocr_result(ocr_result, image_related_text, llm, language) if "图片内容为空" in text: return "" return text diff --git a/data_chain/parser/tools/token_tool.py b/data_chain/parser/tools/token_tool.py index a9a050fd135d1878d1b46579e29814dcbc8466e2..db3fc829f94ef0fcb9f379b64898a7c36fbfe502 100644 --- a/data_chain/parser/tools/token_tool.py +++ b/data_chain/parser/tools/token_tool.py @@ -261,20 +261,24 @@ class TokenTool: return [sentence for index, sentence, score in top_k_sentence_and_score_list] @staticmethod - async def get_abstract_by_llm(content: str, llm: LLM) -> str: + async def get_abstract_by_llm(content: str, llm: LLM, language: str) -> str: """ 使用llm进行内容摘要 """ try: with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) - prompt_template = prompt_dict.get('CONTENT_TO_ABSTRACT_PROMPT', '') + prompt_template = prompt_dict.get('CONTENT_TO_ABSTRACT_PROMPT', {}) + prompt_template = prompt_template.get(language, '') sentences = TokenTool.split_str_with_slide_window(content, llm.max_tokens//3*2) abstract = '' for sentence in sentences: abstract = TokenTool.get_k_tokens_words_from_content(abstract, llm.max_tokens//3) sys_call = prompt_template.format(content=sentence, abstract=abstract) - user_call = '请结合文本和摘要输出新的摘要' + if language == 'en': + user_call = 'Please output a new English abstract based on the text and the existing abstract' + else: + user_call = '请结合文本和已有摘要生成新的中文摘要' abstract = await llm.nostream([], sys_call, user_call) return abstract except Exception as e: @@ -282,17 +286,21 @@ class TokenTool: logging.exception("[TokenTool] %s", err) @staticmethod - async def get_title_by_llm(content: str, llm: LLM) -> str: + async def get_title_by_llm(content: str, llm: LLM, language: str = '中文') -> str: """ 使用llm进行标题生成 """ try: with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: prompt_dict = yaml.load(f, 
Loader=yaml.SafeLoader) - prompt_template = prompt_dict.get('CONTENT_TO_TITLE_PROMPT', '') + prompt_template = prompt_dict.get('CONTENT_TO_TITLE_PROMPT', {}) + prompt_template = prompt_template.get(language, '') content = TokenTool.get_k_tokens_words_from_content(content, llm.max_tokens) sys_call = prompt_template.format(content=content) - user_call = '请结合文本输出标题' + if language == 'en': + user_call = 'Please generate a English title based on the text' + else: + user_call = '请结合文本生成一个中文标题' title = await llm.nostream([], sys_call, user_call) return title except Exception as e: @@ -300,7 +308,7 @@ class TokenTool: logging.exception("[TokenTool] %s", err) @staticmethod - async def cal_recall(answer_1: str, answer_2: str, llm: LLM) -> float: + async def cal_recall(answer_1: str, bac_info: str, llm: LLM, language: str) -> float: """ 计算recall 参数: @@ -311,10 +319,11 @@ class TokenTool: try: with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) - prompt_template = prompt_dict.get('ANSWER_TO_ANSWER_PROMPT', '') - answer_1 = TokenTool.get_k_tokens_words_from_content(answer_1, llm.max_tokens//2) - answer_2 = TokenTool.get_k_tokens_words_from_content(answer_2, llm.max_tokens//2) - prompt = prompt_template.format(text_1=answer_1, text_2=answer_2) + prompt_template = prompt_dict.get('ANSWER_TO_ANSWER_PROMPT', {}) + prompt_template = prompt_template.get(language, '') + answer_1 = TokenTool.get_k_tokens_words_from_content(answer_1, llm.max_tokens//8) + bac_info = TokenTool.get_k_tokens_words_from_content(bac_info, llm.max_tokens-llm.max_tokens//8) + prompt = prompt_template.format(text_1=answer_1, text_2=bac_info) sys_call = prompt user_call = '请输出相似度' similarity = await llm.nostream([], sys_call, user_call) @@ -325,7 +334,7 @@ class TokenTool: return -1 @staticmethod - async def cal_precision(question: str, content: str, llm: LLM) -> float: + async def cal_precision(question: str, content: str, llm: LLM, language: str) -> float: """ 计算precision 参数: @@ -335,17 +344,19 @@ class TokenTool: try: with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) - prompt_template = prompt_dict.get('CONTENT_TO_STATEMENTS_PROMPT', '') + prompt_template = prompt_dict.get('CONTENT_TO_STATEMENTS_PROMPT', {}) + prompt_template = prompt_template.get(language, '') content = TokenTool.compress_tokens(content, llm.max_tokens) sys_call = prompt_template.format(content=content) user_call = '请结合文本输出陈诉列表' statements = await llm.nostream([], sys_call, user_call, st_str='[', - en_str=']') + en_str=']') statements = json.loads(statements) if len(statements) == 0: return 0 score = 0 - prompt_template = prompt_dict.get('STATEMENTS_TO_QUESTION_PROMPT', '') + prompt_template = prompt_dict.get('STATEMENTS_TO_QUESTION_PROMPT', {}) + prompt_template = prompt_template.get(language, '') for statement in statements: statement = TokenTool.get_k_tokens_words_from_content(statement, llm.max_tokens) prompt = prompt_template.format(statement=statement, question=question) @@ -353,7 +364,7 @@ class TokenTool: user_call = '请结合文本输出YES或NO' yn = await llm.nostream([], sys_call, user_call) yn = yn.lower() - if yn == 'yes': + if 'yes' in yn: score += 1 return score/len(statements)*100 except Exception as e: @@ -362,7 +373,7 @@ class TokenTool: return -1 @staticmethod - async def cal_faithfulness(question: str, answer: str, content: str, llm: LLM) -> float: + async def cal_faithfulness(question: str, answer: str, content: str, llm: LLM, language: 
str) -> float: """ 计算faithfulness 参数: @@ -372,15 +383,17 @@ class TokenTool: try: with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) - prompt_template = prompt_dict.get('QA_TO_STATEMENTS_PROMPT', '') + prompt_template = prompt_dict.get('QA_TO_STATEMENTS_PROMPT', {}) + prompt_template = prompt_template.get(language, '') question = TokenTool.get_k_tokens_words_from_content(question, llm.max_tokens//8) answer = TokenTool.get_k_tokens_words_from_content(answer, llm.max_tokens//8*7) prompt = prompt_template.format(question=question, answer=answer) sys_call = prompt user_call = '请结合问题和答案输出陈诉' - statements = await llm.nostream([], sys_call, user_call,st_str='[', - en_str=']') - prompt_template = prompt_dict.get('STATEMENTS_TO_FRAGMENT_PROMPT', '') + statements = await llm.nostream([], sys_call, user_call, st_str='[', + en_str=']') + prompt_template = prompt_dict.get('STATEMENTS_TO_FRAGMENT_PROMPT', {}) + prompt_template = prompt_template.get(language, '') statements = json.loads(statements) if len(statements) == 0: return 0 @@ -394,7 +407,7 @@ class TokenTool: user_call = user_call yn = await llm.nostream([], sys_call, user_call) yn = yn.lower() - if yn == 'yes': + if 'yes' in yn: score += 1 return score/len(statements)*100 except Exception as e: @@ -416,7 +429,7 @@ class TokenTool: return cosine_dist @staticmethod - async def cal_relevance(question: str, answer: str, llm: LLM) -> float: + async def cal_relevance(question: str, answer: str, llm: LLM, language: str) -> float: """ 计算relevance 参数: @@ -426,7 +439,8 @@ class TokenTool: try: with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) - prompt_template = prompt_dict.get('GENREATE_QUESTION_FROM_CONTENT_PROMPT', '') + prompt_template = prompt_dict.get('GENERATE_QUESTION_FROM_CONTENT_PROMPT', {}) + prompt_template = prompt_template.get(language, '') answer = TokenTool.get_k_tokens_words_from_content(answer, llm.max_tokens) sys_call = prompt_template.format(k=5, content=answer) user_call = '请结合文本输出问题列表' diff --git a/data_chain/rag/base_searcher.py b/data_chain/rag/base_searcher.py index ef96e0f61ca6d1f2b5a1e65937d3426638eb62bb..6e44b9fd18d1c345b3d92e91a3758ec367d14922 100644 --- a/data_chain/rag/base_searcher.py +++ b/data_chain/rag/base_searcher.py @@ -2,13 +2,14 @@ import uuid from pydantic import BaseModel, Field import random +from typing import Union from data_chain.logger.logger import logger as logging from data_chain.apps.base.convertor import Convertor from data_chain.stores.database.database import ChunkEntity from data_chain.parser.tools.token_tool import TokenTool from data_chain.manager.chunk_manager import ChunkManager from data_chain.entities.response_data import Chunk, DocChunk - +from data_chain.rerank.rerank import Rerank class BaseSearcher: @staticmethod @@ -39,21 +40,33 @@ class BaseSearcher: err = f"[BaseSearch] 检索器不存在,search_method: {search_method}" logging.exception(err) raise Exception(err) - + @staticmethod - async def rerank(chunk_entities: list[ChunkEntity], query: str) -> list[ChunkEntity]: + async def rerank(chunk_entities: list[ChunkEntity],rerank_method:Union[None,str], query: str) -> list[ChunkEntity]: """ 重新排序 :param list: 检索结果 :param query: 查询 :return: 重新排序后的结果 """ - score_chunk_entities = [] - for chunk_entity in chunk_entities: - score = TokenTool.cal_jac(chunk_entity.text, query) - score_chunk_entities.append((score, chunk_entity)) - score_chunk_entities.sort(key=lambda x: x[0], 
reverse=True) - sorted_chunk_entities = [chunk_entity for _, chunk_entity in score_chunk_entities] + if rerank_method is None: + score_chunk_entities = [] + for chunk_entity in chunk_entities: + score = TokenTool.cal_jac(chunk_entity.text, query) + score_chunk_entities.append((score, chunk_entity)) + score_chunk_entities.sort(key=lambda x: x[0], reverse=True) + sorted_chunk_entities = [chunk_entity for _, chunk_entity in score_chunk_entities] + else: + text=[] + for chunk_entity in chunk_entities: + text.append(chunk_entity.text) + try: + rerank_index = await Rerank.rerank(query, text, top_k=len(text)) + except Exception as e: + err = f"[BaseSearch] 重新排序失败,error: {e}" + logging.exception(err) + return chunk_entities + sorted_chunk_entities = [chunk_entities[i] for i in rerank_index] return sorted_chunk_entities @staticmethod diff --git a/data_chain/rag/doc2chunk_bfs_searcher.py b/data_chain/rag/doc2chunk_bfs_searcher.py index c72e7bd1c1a05b2c9ea68e027dac522714871210..629e8d230b3095ec7e7d8352e673718e244f9ce5 100644 --- a/data_chain/rag/doc2chunk_bfs_searcher.py +++ b/data_chain/rag/doc2chunk_bfs_searcher.py @@ -37,7 +37,7 @@ class Doc2ChunkBfsSearcher(BaseSearcher): root_chunk_entities_vector = [] for _ in range(3): try: - root_chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(root_chunk_entities_keyword), doc_ids, banned_ids, ChunkParseTopology.TREEROOT.value), timeout=3) + root_chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(root_chunk_entities_keyword), doc_ids, banned_ids, ChunkParseTopology.TREEROOT.value), timeout=20) break except Exception as e: err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}" @@ -54,7 +54,7 @@ class Doc2ChunkBfsSearcher(BaseSearcher): root_chunk_entities_vector = [] for _ in range(3): try: - root_chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(root_chunk_entities_keyword), doc_ids, banned_ids, None, pre_ids), timeout=3) + root_chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(root_chunk_entities_keyword), doc_ids, banned_ids, None, pre_ids), timeout=20) break except Exception as e: err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}" diff --git a/data_chain/rag/doc2chunk_searcher.py b/data_chain/rag/doc2chunk_searcher.py index 40510d513c9f2355a87a3ac3efe64f1fb008794c..6b3aca75cb1cc0d915acabe029a8d100364cd9ff 100644 --- a/data_chain/rag/doc2chunk_searcher.py +++ b/data_chain/rag/doc2chunk_searcher.py @@ -37,7 +37,7 @@ class Doc2ChunkSearcher(BaseSearcher): doc_entities_vector = [] for _ in range(3): try: - doc_entities_vector = await asyncio.wait_for(DocumentManager.get_top_k_document_by_kb_id_vector(kb_id, vector, top_k-len(doc_entities_keyword), use_doc_ids, banned_ids), timeout=3) + doc_entities_vector = await asyncio.wait_for(DocumentManager.get_top_k_document_by_kb_id_vector(kb_id, vector, top_k-len(doc_entities_keyword), use_doc_ids, banned_ids), timeout=10) break except Exception as e: err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}" @@ -53,7 +53,7 @@ class Doc2ChunkSearcher(BaseSearcher): chunk_entities_vector = [] for _ in range(3): try: - chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_keyword), use_doc_ids, banned_ids), timeout=3) + chunk_entities_vector = await 
asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_keyword), use_doc_ids, banned_ids), timeout=10) break except Exception as e: err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}" diff --git a/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py b/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py index 5efe05ee080f65920d5fcdff60fe0ae80745ccc8..3f2e6d61f6b86b22108729790cb689197cc52c4e 100644 --- a/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py +++ b/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py @@ -42,14 +42,17 @@ class KeywordVectorSearcher(BaseSearcher): try: import time start_time = time.time() - chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword)-len(chunk_entities_get_by_dynamic_weighted_keyword), doc_ids, banned_ids), timeout=3) + logging.error(f"[KeywordVectorSearcher] 开始进行向量检索,top_k: {top_k-len(chunk_entities_get_by_keyword)-len(chunk_entities_get_by_dynamic_weighted_keyword)}") + chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword)-len(chunk_entities_get_by_dynamic_weighted_keyword), doc_ids, banned_ids), timeout=20) end_time = time.time() logging.info(f"[KeywordVectorSearcher] 向量检索成功完成,耗时: {end_time - start_time:.2f}秒") break except Exception as e: - err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}" + import traceback + err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}, traceback: {traceback.format_exc()}" logging.error(err) continue + logging.error(f"[KeywordVectorSearcher] chunk_entities_get_by_keyword: {len(chunk_entities_get_by_keyword)}, chunk_entities_get_by_dynamic_weighted_keyword: {len(chunk_entities_get_by_dynamic_weighted_keyword)}, chunk_entities_get_by_vector: {len(chunk_entities_get_by_vector)}") chunk_entities = chunk_entities_get_by_keyword + chunk_entities_get_by_dynamic_weighted_keyword + chunk_entities_get_by_vector except Exception as e: err = f"[KeywordVectorSearcher] 关键词向量检索失败,error: {e}" diff --git a/data_chain/rag/enhanced_by_llm_searcher.py b/data_chain/rag/enhanced_by_llm_searcher.py index 00b7bae3afb3397ad67c77402ad7568ffdaf3416..738eaac974ef3121aff613532edec4bff75afe8b 100644 --- a/data_chain/rag/enhanced_by_llm_searcher.py +++ b/data_chain/rag/enhanced_by_llm_searcher.py @@ -13,6 +13,7 @@ from data_chain.entities.enum import SearchMethod from data_chain.parser.tools.token_tool import TokenTool from data_chain.llm.llm import LLM from data_chain.config.config import config +from data_chain.manager.knowledge_manager import KnowledgeBaseManager class EnhancedByLLMSearcher(BaseSearcher): @@ -36,7 +37,9 @@ class EnhancedByLLMSearcher(BaseSearcher): try: with open('./data_chain/common/prompt.yaml', 'r', encoding='utf-8') as f: prompt_dict = yaml.safe_load(f) - prompt_template = prompt_dict['CHUNK_QUERY_MATCH_PROMPT'] + knowledge_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id) + prompt_template = prompt_dict.get('CHUNK_QUERY_MATCH_PROMPT', {}) + prompt_template = prompt_template.get(knowledge_entity.tokenizer, '') chunk_entities = [] rd = 0 max_retry = 5 @@ -56,7 +59,7 @@ class EnhancedByLLMSearcher(BaseSearcher): sub_chunk_entities_vector = [] for _ in range(3): try: - sub_chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k, doc_ids, banned_ids), timeout=3) + 
sub_chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k, doc_ids, banned_ids), timeout=20) break except Exception as e: err = f"[EnhancedByLLMSearcher] 向量检索失败,error: {e}" diff --git a/data_chain/rag/keyword_and_vector_searcher.py b/data_chain/rag/keyword_and_vector_searcher.py index 86b3b4f5cfca9065c6318caa45ab39c2ae517f74..9a0c7de20b63cae03044d217406caa4f6c1a3939 100644 --- a/data_chain/rag/keyword_and_vector_searcher.py +++ b/data_chain/rag/keyword_and_vector_searcher.py @@ -40,7 +40,7 @@ class KeywordVectorSearcher(BaseSearcher): try: import time start_time = time.time() - chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword), doc_ids, banned_ids), timeout=3) + chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword), doc_ids, banned_ids), timeout=20) end_time = time.time() logging.info(f"[KeywordVectorSearcher] 向量检索成功完成,耗时: {end_time - start_time:.2f}秒") break diff --git a/data_chain/rag/query_extend_searcher.py b/data_chain/rag/query_extend_searcher.py index a09f660baeebd3e16ce61b31ae85e9723bb2fd3f..bbccd51035af5b933ed85ed77395a52cfb97ad6f 100644 --- a/data_chain/rag/query_extend_searcher.py +++ b/data_chain/rag/query_extend_searcher.py @@ -14,6 +14,7 @@ from data_chain.entities.enum import SearchMethod from data_chain.parser.tools.token_tool import TokenTool from data_chain.llm.llm import LLM from data_chain.config.config import config +from data_chain.manager.knowledge_manager import KnowledgeBaseManager class QueryExtendSearcher(BaseSearcher): @@ -35,7 +36,9 @@ class QueryExtendSearcher(BaseSearcher): """ with open('./data_chain/common/prompt.yaml', 'r', encoding='utf-8') as f: prompt_dict = yaml.safe_load(f) - prompt_template = prompt_dict['QUERY_EXTEND_PROMPT'] + konwledge_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id) + prompt_template = prompt_dict.get('QUERY_EXTEND_PROMPT', {}) + prompt_template = prompt_template.get(konwledge_entity.tokenizer, '') chunk_entities = [] llm = LLM( openai_api_key=config['OPENAI_API_KEY'], @@ -61,7 +64,7 @@ class QueryExtendSearcher(BaseSearcher): chunk_entities_get_by_vector = [] for _ in range(3): try: - chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword), doc_ids, banned_ids), timeout=3) + chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword), doc_ids, banned_ids), timeout=20) break except Exception as e: err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}" diff --git a/data_chain/rag/vector_searcher.py b/data_chain/rag/vector_searcher.py index dad5e8676792927fa28f27a0ec9b8ac0cb08a079..1bd1d0cac655c2196db2232d84af26da3b3e02fe 100644 --- a/data_chain/rag/vector_searcher.py +++ b/data_chain/rag/vector_searcher.py @@ -29,7 +29,7 @@ class VectorSearcher(BaseSearcher): chunk_entities = [] for _ in range(3): try: - chunk_entities = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k, doc_ids, banned_ids), timeout=3) + chunk_entities = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k, doc_ids, banned_ids), timeout=20) break except Exception as e: err = f"[VectorSearcher] 向量检索失败,error: {e}" diff 
--git a/data_chain/rerank/rerank.py b/data_chain/rerank/rerank.py new file mode 100644 index 0000000000000000000000000000000000000000..853dfdf967368c756ed86a12ec12d87e4a9afd9c --- /dev/null +++ b/data_chain/rerank/rerank.py @@ -0,0 +1,78 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +import requests +import json +import urllib3 +from data_chain.config.config import config +from data_chain.logger.logger import logger as logging +from data_chain.entities.enum import RerankType + +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) +class Rerank(): + @staticmethod + async def assemble_data(query:str, documents:list[str], top_k:int=3)->dict: + if config['RERANK_TYPE'] == RerankType.BAILIAN: + data={ + "model": config["RERANK_MODEL_NAME"], + "input":{ + "query": query, + "documents": documents + }, + "parameters": { + "return_documents": True, + "top_n": top_k + } + } + elif config['RERANK_TYPE'] == RerankType.GUIJILIUDONG: + data={ + "model": config["RERANK_MODEL_NAME"], + "query": query, + "documents": documents + } + elif config['RERANK_TYPE'] == RerankType.VLLM: + data={ + "model": config["RERANK_MODEL_NAME"], + "text_1": query, + "text_2": documents + } + elif config['RERANK_TYPE'] == RerankType.ASSECEND: + data={ + "query": query, + "texts": documents + } + return data + @staticmethod + async def parse_response(response: requests.Response, top_k:int=3)->list[int]: + documents_index=[] + if config['RERANK_TYPE'] == RerankType.BAILIAN: + for item in response.json()["output"]["results"]: + documents_index.append(item['index']) + elif config['RERANK_TYPE'] == RerankType.GUIJILIUDONG: + for item in response.json()['results']: + documents_index.append(item['index']) + elif config['RERANK_TYPE'] == RerankType.VLLM: + for item in response.json()['data']: + documents_index.append(item['index']) + elif config['RERANK_TYPE'] == RerankType.ASSECEND: + for i in range(len(response.json())): + documents_index.append(response.json()[i]['index']) + return documents_index[:top_k] + @staticmethod + async def rerank(query:str, documents:list[str],top_k:int=3)->list[int]: + # 返回按相关度排序的文档下标;单个文档无需重排,直接返回原始顺序的下标 + if len(documents) <= 1: + return list(range(len(documents))) + api_key = config["RERANK_API_KEY"] + url = config["RERANK_ENDPOINT"] + headers = { + 'Authorization': f'Bearer {api_key}', + 'Content-Type' : 'application/json' + } + + data = await Rerank.assemble_data(query, documents, top_k) + response = requests.post(url, headers=headers, json=data) + if response.status_code != 200: + err = f"[Rerank] 重排序失败 ,error: {response.text}" + logging.error(err) + return list(range(len(documents)))[:top_k] + documents_index = await Rerank.parse_response(response, top_k) + return documents_index + diff --git a/data_chain/stores/database/database.py b/data_chain/stores/database/database.py index 4e8ae10d581bdeb25159026b668404bb1f3f08db..432ddd4c0656ea4999a1047fd40e9eb28ef970dc 100644 --- a/data_chain/stores/database/database.py +++ b/data_chain/stores/database/database.py @@ -1,16 +1,25 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
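For reference, a minimal usage sketch (not taken from the patch) of how the new Rerank helper is meant to be consumed by BaseSearcher.rerank above; it assumes the RERANK_TYPE, RERANK_ENDPOINT, RERANK_API_KEY and RERANK_MODEL_NAME settings are populated, and it keeps the incoming chunk order whenever the rerank endpoint is unavailable:

# Sketch only: reorder retrieved chunk texts with the remote rerank service,
# falling back to the original order on any failure (mirrors BaseSearcher.rerank).
import asyncio
from data_chain.rerank.rerank import Rerank

async def rerank_texts(query: str, texts: list[str]) -> list[str]:
    try:
        # Rerank.rerank returns indices into `texts`, most relevant first
        order = await Rerank.rerank(query, texts, top_k=len(texts))
    except Exception:
        return texts
    return [texts[i] for i in order]

# e.g. asyncio.run(rerank_texts("openEuler 是什么?", ["chunk A", "chunk B"]))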
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from sqlalchemy import Index +from datetime import datetime +import uuid from uuid import uuid4 import urllib.parse from data_chain.logger.logger import logger as logging from pgvector.sqlalchemy import Vector from sqlalchemy import Boolean, Column, ForeignKey, Integer, Float, String, func from sqlalchemy.types import TIMESTAMP, UUID -from sqlalchemy.orm import declarative_base +from sqlalchemy.dialects.postgresql import TSVECTOR +from sqlalchemy.orm import declarative_base, DeclarativeBase, MappedAsDataclass, Mapped, mapped_column from data_chain.config.config import config from data_chain.entities.enum import (Tokenizer, ParseMethod, + TeamStatus, + TeamMessageStatus, + TeamUserStaus, + RoleStatus, + RoleActionStatus, + UserRoleStatus, UserStatus, UserMessageType, UserMessageStatus, @@ -42,7 +51,7 @@ class TeamEntity(Base): description = Column(String) member_cnt = Column(Integer, default=0) is_public = Column(Boolean, default=True) - status = Column(String) + status = Column(String, default=TeamStatus.EXISTED.value) created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -62,7 +71,9 @@ class TeamMessageEntity(Base): team_id = Column(UUID, ForeignKey('team.id')) author_id = Column(String) author_name = Column(String) - message = Column(String, default='') + zh_message = Column(String, default='') + en_message = Column(String, default='') + status = Column(String, default=TeamMessageStatus.EXISTED.value) created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -83,6 +94,7 @@ class RoleEntity(Base): name = Column(String) is_unique = Column(Boolean, default=False) editable = Column(Boolean, default=False) + status = Column(String, default=RoleStatus.EXISTED.value) # 角色状态 created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -119,6 +131,7 @@ class RoleActionEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) role_id = Column(UUID, ForeignKey('role.id', ondelete="CASCADE")) action = Column(String) + status = Column(String, default=RoleActionStatus.EXISTED.value) created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -154,13 +167,17 @@ class UserMessageEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) team_id = Column(UUID) + team_name = Column(String) + role_id = Column(UUID) sender_id = Column(String) sender_name = Column(String) + status_to_sender = Column(String, default=UserMessageStatus.UNREAD.value) receiver_id = Column(String) receiver_name = Column(String) + is_to_all = Column(Boolean, default=False) + status_to_receiver = Column(String, default=UserMessageStatus.UNREAD.value) message = Column(String) type = Column(String) - status = Column(String, default=UserMessageStatus.UNREAD.value) created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -174,6 +191,7 @@ class TeamUserEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) team_id = Column(UUID, ForeignKey('team.id', ondelete="CASCADE")) # 团队id user_id = Column(String) # 用户id + status = Column(String, default=TeamUserStaus.EXISTED.value) # 用户在团队中的状态 created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -192,6 +210,7 @@ class UserRoleEntity(Base): team_id = Column(UUID, ForeignKey('team.id', ondelete="CASCADE")) # 团队id user_id = Column(String) # 用户id role_id = Column(UUID) # 角色id + status = Column(String, default=UserRoleStatus.EXISTED.value) # 用户角色状态 created_time = Column( TIMESTAMP(timezone=True), nullable=True, @@ -208,18 +227,22 @@ class 
KnowledgeBaseEntity(Base): __tablename__ = 'knowledge_base' id = Column(UUID, default=uuid4, primary_key=True) - team_id = Column(UUID, ForeignKey('team.id', ondelete="CASCADE"), nullable=True) # 团队id + team_id = Column(UUID, ForeignKey( + 'team.id', ondelete="CASCADE"), nullable=True) # 团队id author_id = Column(String) # 作者id author_name = Column(String) # 作者名称 name = Column(String, default='') # 知识库名资产名 tokenizer = Column(String, default=Tokenizer.ZH.value) # 分词器 description = Column(String, default='') # 资产描述 embedding_model = Column(String) # 资产向量化模型 + rerank_model = Column(String) # 资产rerank模型 + spearating_characters = Column(String) # 资产分块的分隔符 doc_cnt = Column(Integer, default=0) # 资产文档个数 doc_size = Column(Integer, default=0) # 资产下所有文档大小(TODO: 单位kb或者字节) upload_count_limit = Column(Integer, default=128) # 更新次数限制 upload_size_limit = Column(Integer, default=512) # 更新大小限制 - default_parse_method = Column(String, default=ParseMethod.GENERAL.value) # 默认解析方法 + default_parse_method = Column( + String, default=ParseMethod.GENERAL.value) # 默认解析方法 default_chunk_size = Column(Integer, default=1024) # 默认分块大小 status = Column(String, default=KnowledgeBaseStatus.IDLE.value) created_time = Column( @@ -238,7 +261,8 @@ class DocumentTypeEntity(Base): __tablename__ = 'document_type' id = Column(UUID, default=uuid4, primary_key=True) - kb_id = Column(UUID, ForeignKey('knowledge_base.id', ondelete="CASCADE"), nullable=True) + kb_id = Column(UUID, ForeignKey('knowledge_base.id', + ondelete="CASCADE"), nullable=True) name = Column(String) created_time = Column( TIMESTAMP(timezone=True), @@ -257,20 +281,23 @@ class DocumentEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) team_id = Column(UUID) # 文档所属团队id - kb_id = Column(UUID, ForeignKey('knowledge_base.id', ondelete="CASCADE")) # 文档所属资产id + kb_id = Column(UUID, ForeignKey( + 'knowledge_base.id', ondelete="CASCADE")) # 文档所属资产id author_id = Column(String) # 文档作者id author_name = Column(String) # 文档作者名称 name = Column(String) # 文档名 extension = Column(String) # 文件后缀 size = Column(Integer) # 文档大小 parse_method = Column(String, default=ParseMethod.GENERAL.value) # 文档解析方法 - parse_relut_topology = Column(String, default=DocParseRelutTopology.LIST.value) # 文档解析结果拓扑结构 + parse_relut_topology = Column( + String, default=DocParseRelutTopology.LIST.value) # 文档解析结果拓扑结构 chunk_size = Column(Integer) # 文档分块大小 type_id = Column(UUID) # 文档类别 enabled = Column(Boolean) # 文档是否启用 status = Column(String, default=DocumentStatus.IDLE.value) # 文档状态 full_text = Column(String) # 文档全文 abstract = Column(String) # 文档摘要 + abstract_ts_vector = Column(TSVECTOR) # 文档摘要词向量 abstract_vector = Column(Vector(1024)) # 文档摘要向量 created_time = Column( TIMESTAMP(timezone=True), @@ -283,6 +310,8 @@ class DocumentEntity(Base): onupdate=func.current_timestamp() ) __table_args__ = ( + Index('abstract_ts_vector_index', + abstract_ts_vector, postgresql_using='gin'), Index( 'abstract_vector_index', abstract_vector, @@ -299,16 +328,19 @@ class ChunkEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) # chunk id team_id = Column(UUID) # 团队id kb_id = Column(UUID) # 知识库id - doc_id = Column(UUID, ForeignKey('document.id', ondelete="CASCADE")) # 片段所属文档id + doc_id = Column(UUID, ForeignKey( + 'document.id', ondelete="CASCADE")) # 片段所属文档id doc_name = Column(String) # 片段所属文档名称 text = Column(String) # 片段文本内容 + text_ts_vector = Column(TSVECTOR) # 片段文本词向量 text_vector = Column(Vector(1024)) # 文本向量 tokens = Column(Integer) # 片段文本token数 type = Column(String, default=ChunkType.TEXT.value) # 片段类型 # 
前一个chunk的id(假如解析结果为链表,那么这里是前一个节点的id,如果文档解析结果为树,那么这里是父节点的id) pre_id_in_parse_topology = Column(UUID) # chunk的在解析结果中的拓扑类型(假如解析结果为链表,那么这里为链表头、中间和尾;假如解析结果为树,那么这里为树根、树的中间节点和叶子节点) - parse_topology_type = Column(String, default=ChunkParseTopology.LISTHEAD.value) + parse_topology_type = Column( + String, default=ChunkParseTopology.LISTHEAD.value) global_offset = Column(Integer) # chunk在文档中的相对偏移 local_offset = Column(Integer) # chunk在块中的相对偏移 enabled = Column(Boolean) # chunk是否启用 @@ -323,6 +355,7 @@ class ChunkEntity(Base): server_default=func.current_timestamp(), onupdate=func.current_timestamp()) __table_args__ = ( + Index('text_ts_vector_index', text_ts_vector, postgresql_using='gin'), Index( 'text_vector_index', text_vector, @@ -358,7 +391,8 @@ class DataSetEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) # 数据集id team_id = Column(UUID) # 数据集所属团队id - kb_id = Column(UUID, ForeignKey('knowledge_base.id', ondelete="CASCADE")) # 数据集所属资产id + kb_id = Column(UUID, ForeignKey('knowledge_base.id', + ondelete="CASCADE")) # 数据集所属资产id author_id = Column(String) # 数据的创建者id author_name = Column(String) # 数据的创建者名称 llm_id = Column(String) # 数据的生成使用的大模型的id @@ -386,7 +420,8 @@ class DataSetDocEntity(Base): __tablename__ = 'dataset_doc' id = Column(UUID, default=uuid4, primary_key=True) # 数据集文档id - dataset_id = Column(UUID, ForeignKey('dataset.id', ondelete="CASCADE")) # 数据集id + dataset_id = Column(UUID, ForeignKey( + 'dataset.id', ondelete="CASCADE")) # 数据集id doc_id = Column(UUID) # 文档id created_at = Column( TIMESTAMP(timezone=True), @@ -404,7 +439,8 @@ class QAEntity(Base): __tablename__ = 'qa' id = Column(UUID, default=uuid4, primary_key=True) # 数据id - dataset_id = Column(UUID, ForeignKey('dataset.id', ondelete="CASCADE")) # 数据所属数据集id + dataset_id = Column(UUID, ForeignKey( + 'dataset.id', ondelete="CASCADE")) # 数据所属数据集id doc_id = Column(UUID) # 数据关联的文档id doc_name = Column(String, default="未知文档") # 数据关联的文档名称 question = Column(String) # 数据的问题 @@ -430,13 +466,15 @@ class TestingEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) # 测试任务的id team_id = Column(UUID) # 测试任务所属团队id kb_id = Column(UUID) # 测试任务所属资产id - dataset_id = Column(UUID, ForeignKey('dataset.id', ondelete="CASCADE")) # 测试任务使用数据集的id + dataset_id = Column(UUID, ForeignKey( + 'dataset.id', ondelete="CASCADE")) # 测试任务使用数据集的id author_id = Column(String) # 测试任务的创建者id author_name = Column(String) # 测试任务的创建者名称 name = Column(String) # 测试任务的名称 description = Column(String) # 测试任务的描述 llm_id = Column(String) # 测试任务的使用的大模型 - search_method = Column(String, default=SearchMethod.KEYWORD_AND_VECTOR.value) # 测试任务的使用的检索增强模式类型 + search_method = Column( + String, default=SearchMethod.KEYWORD_AND_VECTOR.value) # 测试任务的使用的检索增强模式类型 top_k = Column(Integer, default=5) # 测试任务的检索增强模式的top_k status = Column(String, default=TestingStatus.IDLE.value) # 测试任务的状态 ave_score = Column(Float, default=-1) # 测试任务的综合得分 @@ -463,7 +501,8 @@ class TestCaseEntity(Base): __tablename__ = 'testcase' id = Column(UUID, default=uuid4, primary_key=True) # 测试case的id - testing_id = Column(UUID, ForeignKey('testing.id', ondelete="CASCADE")) # 测试 + testing_id = Column(UUID, ForeignKey( + 'testing.id', ondelete="CASCADE")) # 测试 question = Column(String) # 数据的问题 answer = Column(String) # 数据的答案 chunk = Column(String) # 数据的片段 @@ -496,7 +535,8 @@ class TaskEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) team_id = Column(UUID) # 团队id - user_id = Column(String, ForeignKey('users.id', ondelete="CASCADE")) # 创建者id + user_id = Column(String, ForeignKey( + 'users.id', 
ondelete="CASCADE")) # 创建者id op_id = Column(UUID) # 任务关联的实体id, 资产或者文档id op_name = Column(String) # 任务关联的实体名称 type = Column(String) # 任务类型 @@ -534,6 +574,23 @@ class TaskReportEntity(Base): ) +class TaskQueueEntity(Base): + __tablename__ = 'task_queue' + + id = Column(UUID, default=uuid4, primary_key=True) # 任务ID + status = Column(String) # 任务状态 + created_time = Column( + TIMESTAMP(timezone=True), + nullable=True, + server_default=func.current_timestamp() + ) + # 添加索引以提高查询性能 + __table_args__ = ( + Index('idx_task_queue_status', 'status'), + Index('idx_task_queue_created_time', 'created_time'), + ) + + class DataBase: # 对密码进行 URL 编码 @@ -569,7 +626,8 @@ class DataBase: if DataBase.init_all_table_flag is False: await DataBase.init_all_table() DataBase.init_all_table_flag = True - connection = async_sessionmaker(DataBase.engine, expire_on_commit=False)() + connection = async_sessionmaker( + DataBase.engine, expire_on_commit=False)() return cls._ConnectionManager(connection) class _ConnectionManager: diff --git a/ocr_server/init.py b/ocr_server/init.py new file mode 100644 index 0000000000000000000000000000000000000000..19012fa73b9771ef13cbc7e372ec8e97a02f3d61 --- /dev/null +++ b/ocr_server/init.py @@ -0,0 +1,6 @@ +from paddleocr import PaddleOCR +import cv2 +ocr = PaddleOCR(use_angle_cls=True, lang="ch") +image_path = 'test.jpg' +image = cv2.imread(image_path) +result = ocr.predict(image) diff --git a/ocr_server/requiremenets.text b/ocr_server/requiremenets.text new file mode 100644 index 0000000000000000000000000000000000000000..0066d61d9aa587a9f34b015fa36ad82a8bb1d1a5 --- /dev/null +++ b/ocr_server/requiremenets.text @@ -0,0 +1,5 @@ +aiofiles 24.1.0 +fastapi 0.116.1 +paddleocr 3.2.0 +paddlepaddle 3.1.1 +uvicorn 0.35.0 \ No newline at end of file diff --git a/ocr_server/requiremenets.txt b/ocr_server/requiremenets.txt new file mode 100644 index 0000000000000000000000000000000000000000..d61d9906ba30f3cdff775d7939b0eda976a00fbf --- /dev/null +++ b/ocr_server/requiremenets.txt @@ -0,0 +1,61 @@ +aiofiles==24.1.0 +aistudio_sdk==0.3.5 +annotated-types==0.7.0 +anyio==4.10.0 +bce-python-sdk==0.9.42 +certifi==2025.8.3 +chardet==5.2.0 +charset-normalizer==3.4.3 +click==8.2.1 +colorlog==6.9.0 +fastapi==0.116.1 +filelock==3.19.1 +fsspec==2025.7.0 +future==1.0.0 +h11==0.16.0 +hf-xet==1.1.8 +httpcore==1.0.9 +httpx==0.28.1 +huggingface-hub==0.34.4 +idna==3.10 +imagesize==1.4.1 +mdc==1.2.1 +modelscope==1.29.1 +networkx==3.5 +numpy==2.3.2 +opencv-contrib-python==4.10.0.84 +opt-einsum==3.3.0 +packaging==25.0 +paddleocr==3.2.0 +paddlepaddle==3.1.1 +paddlex==3.2.0 +pandas==2.3.2 +pillow==11.3.0 +prettytable==3.16.0 +protobuf==6.32.0 +py-cpuinfo==9.0.0 +pyclipper==1.3.0.post6 +pycryptodome==3.23.0 +pydantic==2.11.7 +pydantic_core==2.33.2 +pypdfium2==4.30.0 +python-dateutil==2.9.0.post0 +python-json-logger==2.0.7 +python-multipart==0.0.20 +pytz==2025.2 +PyYAML==6.0.2 +requests==2.32.5 +ruamel.yaml==0.18.15 +ruamel.yaml.clib==0.2.12 +shapely==2.1.1 +six==1.17.0 +sniffio==1.3.1 +starlette==0.47.3 +tqdm==4.67.1 +typing-inspection==0.4.1 +typing_extensions==4.15.0 +tzdata==2025.2 +ujson==5.11.0 +urllib3==2.5.0 +uvicorn==0.35.0 +wcwidth==0.2.13 diff --git a/ocr_server/server.py b/ocr_server/server.py new file mode 100644 index 0000000000000000000000000000000000000000..3b53d95a66bb815227c7c3547341f5944ccc2e23 --- /dev/null +++ b/ocr_server/server.py @@ -0,0 +1,152 @@ +from typing import Any +from pydantic import BaseModel +from fastapi import FastAPI, UploadFile, File +from fastapi.responses import JSONResponse +import 
logging +import os +import aiofiles +import cv2 +import numpy as np +from paddleocr import PaddleOCR +import uuid +from datetime import datetime + +import os +# 强制离线模式 +os.environ["PADDLEX_OFFLINE"] = "True" +# 禁用Paddle的网络请求 +os.environ["PADDLE_NO_NETWORK"] = "True" +# 指定模型缓存路径(确保已放置模型) +os.environ["PADDLEX_HOME"] = "/root/.paddlex" + + +class ResponseData(BaseModel): + """基础返回数据结构""" + + code: int + message: str + result: list + + +# 配置日志 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# 初始化FastAPI应用 +app = FastAPI() + +# 创建保存上传文件的目录 +UPLOAD_DIR = "uploaded_files" +os.makedirs(UPLOAD_DIR, exist_ok=True) + +# 初始化PaddleOCR,使用角度分类,中文识别 +ocr = PaddleOCR( + # 1. 指定本地文本检测模型目录(PP-OCRv5_server_det) + text_detection_model_dir="/root/.paddlex/official_models/PP-OCRv5_server_det", + # 2. 指定本地文本识别模型目录(PP-OCRv5_server_rec) + text_recognition_model_dir="/root/.paddlex/official_models/PP-OCRv5_server_rec", + # (可选)若需要文档方向分类/文本行方向分类,也可指定对应本地模型 + doc_orientation_classify_model_dir="/root/.paddlex/official_models/PP-LCNet_x1_0_doc_ori", + textline_orientation_model_dir="/root/.paddlex/official_models/PP-LCNet_x1_0_textline_ori", + # (可选)关闭不需要的功能(如文档矫正,根据需求调整) + use_doc_unwarping=False, # 若不需要 UVDoc 文档矫正,可关闭 + lang=None, # 因已指定本地模型,lang/ocr_version 会被自动忽略(符合原代码逻辑) + ocr_version=None, + device="npu:0" +) + + +@app.get("/ocr", response_model=ResponseData) +async def ocr_recognition(file: UploadFile = File(...)) -> JSONResponse: + """ + 接收上传的图片文件,先保存到本地,再进行OCR识别并返回结果字符串 + """ + try: + # 生成唯一的文件名,避免冲突 + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + file_extension = os.path.splitext(file.filename)[1] + unique_filename = f"{timestamp}_{uuid.uuid4().hex[:8]}{file_extension}" + file_path = os.path.join(UPLOAD_DIR, unique_filename) + + # 异步保存文件到本地 + async with aiofiles.open(file_path, 'wb') as out_file: + content = await file.read() # 读取文件内容 + await out_file.write(content) # 写入到本地文件 + + logger.info(f"文件已保存到: {file_path}, 大小: {len(content)} bytes") + + # 使用cv2读取本地文件 + image = cv2.imread(file_path) + if image is None: + if os.path.exists(file_path): + os.remove(file_path) + logger.error("无法读取上传的图片文件,可能格式不支持或文件损坏") + return JSONResponse( + status_code=400, + content=ResponseData( + code=400, + message="无法读取上传的图片文件,可能格式不支持或文件损坏", + result=[] + ).model_dump(exclude_none=True) + ) + + logger.info(f"图片读取成功,尺寸: {image.shape}") + + # 进行OCR识别 + # PaddleOCR可以直接处理numpy数组(cv2格式) + result = ocr.predict(image) + if not result: + if os.path.exists(file_path): + os.remove(file_path) + return JSONResponse( + status_code=200, + content=ResponseData( + code=200, + message="OCR识别完成,但未识别到任何文本", + result=[] + ).model_dump(exclude_none=True) + ) + if not result[0]: + if os.path.exists(file_path): + os.remove(file_path) + return JSONResponse( + status_code=200, + content=ResponseData( + code=200, + message="OCR识别完成,但未识别到任何文本", + result=[] + ).model_dump(exclude_none=True) + ) + rec_texts = result[0].get("rec_texts", []) + rec_scores = result[0].get("rec_scores", []) + rec_polys = result[0].get("rec_polys", []) + rt = [[]] + for i, text in enumerate(rec_texts): + rt[0].append([rec_polys[i].tolist(), [text, float(f"{rec_scores[i]:.4f}")]]) + if os.path.exists(file_path): + os.remove(file_path) + return JSONResponse( + status_code=200, + content=ResponseData( + code=200, + message="OCR识别成功", + result=rt + ).model_dump(exclude_none=True) + ) + + except Exception as e: + if os.path.exists(file_path): + os.remove(file_path) + logger.error(f"处理过程出错: {str(e)}", exc_info=True) + return JSONResponse( + 
status_code=500, + content=ResponseData( + code=500, + message=f"处理过程出错: {str(e)}", + result=[] + ).model_dump(exclude_none=True) + ) +if __name__ == "__main__": + import uvicorn + # 在9999端口启动服务,允许外部访问 + uvicorn.run(app, host="0.0.0.0", port=9999) diff --git a/ocr_server/test.jpg b/ocr_server/test.jpg new file mode 100644 index 0000000000000000000000000000000000000000..614fa5c83ed355d60a1119a32ee74300648845c6 Binary files /dev/null and b/ocr_server/test.jpg differ diff --git a/ocr_server/test.py b/ocr_server/test.py new file mode 100644 index 0000000000000000000000000000000000000000..f11c4c9da33713da802525e9ac8e41805ef943ee --- /dev/null +++ b/ocr_server/test.py @@ -0,0 +1,52 @@ +import requests + + +def call_ocr_api(image_path, api_url="http://localhost:9999/ocr"): + """ + 调用OCR接口识别图片中的文字 + + 参数: + image_path: 本地图片文件路径 + api_url: OCR接口的URL地址 + + 返回: + 识别到的文字字符串 + """ + try: + # 打开图片文件并准备上传 + with open(image_path, 'rb') as file: + # 构造表单数据,键名需与接口中的参数名一致 + files = {'file': (image_path, file, 'image/jpeg')} + # 发送GET请求 + response = requests.get(api_url, files=files) + + # 检查响应状态 + if response.status_code == 200: + # 返回识别结果 + return response.json() + else: + print(f"请求失败,状态码: {response.status_code}") + print(f"错误信息: {response.text}") + return None + + except FileNotFoundError: + print(f"错误: 找不到图片文件 {image_path}") + return None + except Exception as e: + print(f"调用接口时发生错误: {str(e)}") + return None + + +# 使用示例 +if __name__ == "__main__": + # 替换为你的图片路径 + image_path = "test.jpg" + # 调用OCR接口 + result = call_ocr_api(image_path) + + if result: + print("OCR识别结果:") + print("-" * 50) + print(type(result)) + print(result) + print("-" * 50) diff --git a/openGauss-sqlalchemy.tar.gz b/openGauss-sqlalchemy.tar.gz deleted file mode 100644 index d40502b58c7a200a9c3eac489eb0e7cd788ec043..0000000000000000000000000000000000000000 Binary files a/openGauss-sqlalchemy.tar.gz and /dev/null differ diff --git a/requirements.txt b/requirements.txt index e2b1e3981564323a264e26273bca90b2c45f447d..f97dae85049f6f64d8260f48fc241e046838ac26 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,14 +11,12 @@ fastapi-pagination==0.12.19 httpx==0.27.0 itsdangerous==2.1.2 jieba==0.42.1 -langchain==0.3.7 -langchain-openai==0.2.5 minio==7.2.4 markdown2==2.5.2 markdown==3.3.4 more-itertools==10.1.0 numpy==1.26.4 -openai==1.65.2 +openai==1.91.0 opencv-python==4.9.0.80 openpyxl==3.1.2 paddleocr==2.9.1 @@ -48,4 +46,5 @@ uvicorn==0.21.0 xlrd==2.0.1 py-cpuinfo==9.0.0 opengauss-sqlalchemy==2.4.0 -#marker-pdf==1.8.0 \ No newline at end of file +#marker-pdf==1.8.0 +motor==3.7.1 \ No newline at end of file diff --git a/run.sh b/run.sh index 56374de5e257fa6ea17027f6992391dfa6204573..706e3107068fb05753185a86ef484920df4b67f3 100644 --- a/run.sh +++ b/run.sh @@ -1,9 +1,7 @@ #!/usr/bin/env sh java -jar tika-server-standard-2.9.2.jar & -python3 /rag-service/chat2db/app/app.py & +python3 /rag-service/chat2db/main.py & python3 /rag-service/data_chain/apps/app.py & -sleep 5 -python3 /rag-service/chat2db/common/init_sql_example.py while true do diff --git a/test.pdf b/test.pdf index a64e0a48ef0f81bd2dde554984850f61956cfc41..4b18ba4cebc4c00f682109b641cac29b92b5df4a 100644 Binary files a/test.pdf and b/test.pdf differ diff --git a/test/config.py b/test/config.py new file mode 100644 index 0000000000000000000000000000000000000000..ba2a4bcc32e88811a959f25f76a3124a8d129ba3 --- /dev/null +++ b/test/config.py @@ -0,0 +1,51 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
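As a reading aid for the /ocr response shape built in ocr_server/server.py (the nested `rt` list), here is a small client-side sketch; the endpoint, field names and call style follow ocr_server/test.py above, and none of it belongs to the patch itself:

# Sketch: call /ocr the same way ocr_server/test.py does and unpack the result,
# which is a single list of [polygon, [text, score]] entries.
import requests

def extract_texts(image_path: str, api_url: str = "http://localhost:9999/ocr") -> list[str]:
    with open(image_path, "rb") as f:
        resp = requests.get(api_url, files={"file": (image_path, f, "image/jpeg")})
    resp.raise_for_status()
    rows = resp.json().get("result") or [[]]
    texts = []
    for polygon, (text, score) in rows[0]:
        texts.append(text)  # polygon: 4-point box, score: recognition confidence
    return texts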
+"""配置文件处理模块""" +import toml +from enum import Enum +from typing import Any +from pydantic import BaseModel, Field +from pathlib import Path +from copy import deepcopy +import sys +import os + + +class LLMConfig(BaseModel): + """LLM配置模型""" + llm_endpoint: str = Field(default="https://dashscope.aliyuncs.com/compatible-mode/v1", description="LLM远程主机地址") + llm_api_key: str = Field(default="", description="LLM API Key") + llm_model_name: str = Field(default="qwen3-coder-480b-a35b-instruct", description="LLM模型名称") + max_tokens: int = Field(default=8192, description="LLM最大Token数") + temperature: float = Field(default=0.7, description="LLM温度参数") + + +class EmbeddingType(str, Enum): + OPENAI = "openai" + MINDIE = "mindie" + + +class EmbeddingConfig(BaseModel): + """Embedding配置模型""" + embedding_type: EmbeddingType = Field(default=EmbeddingType.OPENAI, description="向量化类型") + embedding_endpoint: str = Field(default="", description="向量化API地址") + embedding_api_key: str = Field(default="", description="向量化API Key") + embedding_model_name: str = Field(default="text-embedding-3-small", description="向量化模型名称") + + +class ConfigModel(BaseModel): + """公共配置模型""" + embedding: EmbeddingConfig = Field(default=EmbeddingConfig(), description="向量化配置") + llm: LLMConfig = Field(default=LLMConfig(), description="LLM配置") + + +class BaseConfig(): + """配置文件读取和使用Class""" + + def __init__(self) -> None: + """读取配置文件;当PROD环境变量设置时,配置文件将在读取后删除""" + config_file = os.path.join("config.toml") + self._config = ConfigModel.model_validate(toml.load(config_file)) + + def get_config(self) -> ConfigModel: + """获取配置文件内容""" + return deepcopy(self._config) diff --git a/test/config.toml b/test/config.toml new file mode 100644 index 0000000000000000000000000000000000000000..64cf50923f1898d92979c26c9f7e4d5d5de66ef6 --- /dev/null +++ b/test/config.toml @@ -0,0 +1,12 @@ +[embedding] +embedding_type = "openai" +embedding_endpoint = "https://api.siliconflow.cn/v1" +embedding_api_key = "sk-123456" +embedding_model_name = "BAAI/bge-m3" + +[llm] +llm_endpoint = "https://dashscope.aliyuncs.com/compatible-mode/v1" +llm_api_key = "sk-123456" +llm_model_name = "qwen3-coder-480b-a35b-instruct" +max_tokens = 8192 +temperature = 0.7 \ No newline at end of file diff --git a/chat2db/app/base/vectorize.py b/test/embedding.py similarity index 34% rename from chat2db/app/base/vectorize.py rename to test/embedding.py index 5362047fa0fd407a523bba76e1862e77aa6ef389..0637189d01e014cbf7afd9a5643020018cfc6b5f 100644 --- a/chat2db/app/base/vectorize.py +++ b/test/embedding.py @@ -1,47 +1,50 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
import requests -import urllib3 -from chat2db.config.config import config import json -import sys -import logging - -logging.basicConfig(stream=sys.stdout, level=logging.INFO, - format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') +import urllib3 +from config import BaseConfig, EmbeddingType urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) -class Vectorize(): +class Embedding(): @staticmethod async def vectorize_embedding(text): - if config['EMBEDDING_TYPE']=='openai': + vector = None + if BaseConfig().get_config().embedding.embedding_type == EmbeddingType.OPENAI: headers = { - "Authorization": f"Bearer {config['EMBEDDING_API_KEY']}" - } + "Authorization": f"Bearer {BaseConfig().get_config().embedding.embedding_api_key}", + } data = { "input": text, - "model": config["EMBEDDING_MODEL_NAME"], + "model": BaseConfig().get_config().embedding.embedding_model_name, "encoding_format": "float" } try: - res = requests.post(url=config["EMBEDDING_ENDPOINT"],headers=headers, json=data, verify=False) + res = requests.post(url=BaseConfig().get_config().embedding.embedding_endpoint, + headers=headers, json=data, verify=False) if res.status_code != 200: return None - return res.json()['data'][0]['embedding'] + vector = res.json()['data'][0]['embedding'] except Exception as e: - logging.error(f"Embedding error failed due to: {e}") + err = f"[Embedding] 向量化失败 ,error: {e}" + print(err) return None - elif config['EMBEDDING_TYPE'] =='mindie': + elif BaseConfig().get_config().embedding.embedding_type == 'mindie': try: data = { - "inputs": text, + "inputs": text, } - res = requests.post(url=config["EMBEDDING_ENDPOINT"], json=data, verify=False) + res = requests.post(url=BaseConfig().get_config().embedding.embedding_endpoint, json=data, verify=False) if res.status_code != 200: return None - return json.loads(res.text)[0] + vector = json.loads(res.text)[0] except Exception as e: - logging.error(f"Embedding error failed due to: {e}") - return None + err = f"[Embedding] 向量化失败 ,error: {e}" + print(err) + return None else: return None + while len(vector) < 1024: + vector.append(0) + return vector[:1024] diff --git a/test/tools/llm.py b/test/llm.py similarity index 51% rename from test/tools/llm.py rename to test/llm.py index 103f4ff1f4577ca1fd6d800256c56b6d10445b66..f9e19cc651d65621d7a38610c527237ffd6e90d4 100644 --- a/test/tools/llm.py +++ b/test/llm.py @@ -1,60 +1,93 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
-import asyncio -import time -import json -from langchain_openai import ChatOpenAI -from langchain.schema import SystemMessage, HumanMessage - - -class LLM: - def __init__(self, openai_api_key, openai_api_base, model_name, max_tokens, request_timeout=60, temperature=0.1): - self.client = ChatOpenAI(model_name=model_name, - openai_api_base=openai_api_base, - openai_api_key=openai_api_key, - request_timeout=request_timeout, - max_tokens=max_tokens, - temperature=temperature) - print(model_name) - def assemble_chat(self, chat=None, system_call='', user_call=''): - if chat is None: - chat = [] - chat.append(SystemMessage(content=system_call)) - chat.append(HumanMessage(content=user_call)) - return chat - - async def nostream(self, chat, system_call, user_call): - chat = self.assemble_chat(chat, system_call, user_call) - response = await self.client.ainvoke(chat) - return response.content - - async def data_producer(self, q: asyncio.Queue, history, system_call, user_call): - message = self.assemble_chat(history, system_call, user_call) - try: - async for frame in self.client.astream(message): - await q.put(frame.content) - except Exception as e: - await q.put(None) - print(f"Error in data producer due to: {e}") - return - await q.put(None) - - async def stream(self, chat, system_call, user_call): - st = time.time() - q = asyncio.Queue(maxsize=10) - - # 启动生产者任务 - producer_task = asyncio.create_task(self.data_producer(q, chat, system_call, user_call)) - first_token_reach = False - while True: - data = await q.get() - if data is None: - break - if not first_token_reach: - first_token_reach = True - print(f"大模型回复第一个字耗时 = {time.time() - st}") - for char in data: - yield "data: " + json.dumps({'content': char}, ensure_ascii=False) + '\n\n' - await asyncio.sleep(0.03) # 使用异步 sleep - - yield "data: [DONE]" - print(f"大模型回复耗时 = {time.time() - st}") +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+import asyncio +import time +import re +import json +import tiktoken +from langchain_openai import ChatOpenAI +from langchain.schema import SystemMessage, HumanMessage + + +class LLM: + def __init__(self, openai_api_key, openai_api_base, model_name, max_tokens, request_timeout=60, temperature=0.1): + self.openai_api_key = openai_api_key + self.openai_api_base = openai_api_base + self.model_name = model_name + self.max_tokens = max_tokens + self.request_timeout = request_timeout + self.temperature = temperature + self.client = ChatOpenAI(model_name=model_name, + openai_api_base=openai_api_base, + openai_api_key=openai_api_key, + request_timeout=request_timeout, + max_tokens=max_tokens, + temperature=temperature) + + def assemble_chat(self, chat=None, system_call='', user_call=''): + if chat is None: + chat = [] + chat.append(SystemMessage(content=system_call)) + chat.append(HumanMessage(content=user_call)) + return chat + + async def nostream(self, chat, system_call, user_call, st_str: str = None, en_str: str = None): + try: + chat = self.assemble_chat(chat, system_call, user_call) + response = await self.client.ainvoke(chat) + content = re.sub(r'.*?\n?', '', response.content, flags=re.DOTALL) + content = re.sub(r'.*?\n?', '', content, flags=re.DOTALL) + content = content.strip() + if st_str is not None: + index = content.find(st_str) + if index != -1: + content = content[index:] + if en_str is not None: + index = content[::-1].find(en_str[::-1]) + if index != -1: + content = content[:len(content)-index] + except Exception as e: + err = f"[LLM] 非流式输出异常: {e}" + print("[LLM] %s", err) + return '' + return content + + async def data_producer(self, q: asyncio.Queue, history, system_call, user_call): + message = self.assemble_chat(history, system_call, user_call) + try: + async for frame in self.client.astream(message): + await q.put(frame.content) + except Exception as e: + await q.put(None) + err = f"[LLM] 流式输出生产者任务异常: {e}" + print("[LLM] %s", err) + raise e + await q.put(None) + + async def stream(self, chat, system_call, user_call): + st = time.time() + q = asyncio.Queue(maxsize=10) + + # 启动生产者任务 + producer_task = asyncio.create_task(self.data_producer(q, chat, system_call, user_call)) + first_token_reach = False + enc = tiktoken.encoding_for_model("gpt-4") + input_tokens = len(enc.encode(system_call)) + output_tokens = 0 + while True: + data = await q.get() + if data is None: + break + if not first_token_reach: + first_token_reach = True + print(f"大模型回复第一个字耗时 = {time.time() - st}") + output_tokens += len(enc.encode(data)) + yield "data: " + json.dumps( + {'content': data, + 'input_tokens': input_tokens, + 'output_tokens': output_tokens + }, ensure_ascii=False + ) + '\n\n' + await asyncio.sleep(0.03) # 使用异步 sleep + + yield "data: [DONE]" + print(f"大模型回复耗时 = {time.time() - st}") diff --git a/test/prompt.yaml b/test/prompt.yaml new file mode 100644 index 0000000000000000000000000000000000000000..25ff89c0a46ec731936edb7dde5bb728e5069036 --- /dev/null +++ b/test/prompt.yaml @@ -0,0 +1,708 @@ +ACC_ANALYSIS_RESULT_MERGE_PROMPT: + en: | + You are a text analysis expert. Your task is to combine two analysis results and output a new one. Note: + #01 Please combine the content of the two analysis results to produce a new analysis result. + #02 Please analyze using the four metrics of recall, precision, faithfulness, and interpretability. + #03 The new analysis result must be no longer than 500 characters. + #04 Please output only the new analysis result; do not output any other content. 
+ Example: + Input 1: + Analysis Result 1: + Recall: Currently, the recall is 95.00, with room for improvement. We will optimize the vectorized search algorithm to further mine information in the original fragment that is relevant to the question but not retrieved, such as some specific practical cases in the openEuler ecosystem. The embedding model bge-m3 will be adjusted to more comprehensively and accurately capture semantics, expand the search scope, improve recall, and make the generated answers closer to the standard answer. + Accuracy: The accuracy is 99.00, which is quite high. However, further optimization is possible, including deeper semantic analysis of the retrieved snippets. By combining the features of the large model qwen2.5-32b, this can precisely match the question semantics and avoid subtle semantic deviations. For example, this can more precisely illustrate the specific manifestations of OpenEuler's high performance in cloud computing and edge computing. + Fidelity: The fidelity value is 90.00, indicating that some answers are not fully derived from the retrieved snippets. Optimizing the rag retrieval algorithm, improving the recall rate of the embedding model, and adjusting the text chunk size to 512 may be inappropriate and require re-evaluation based on the content. This ensures that the retrieved snippets contain sufficient context to support the answer, ensuring that the generated answer content is fully derived from the retrieved snippets. For example, regarding the development of the openEuler ecosystem, relevant technical details should be obtained from the retrieved snippets. + Interpretability: The interpretability is 85.00, which is relatively low. Improve the compliance of the large model qwen2.5-32b and optimize the recall of the rag retrieval algorithm and the embedding model bge-m3. This ensures that retrieved snippets better support answer generation and clearly answer questions. For example, when answering questions related to OpenEuler, this makes the answer logic clearer and more targeted, improving overall interpretability. + + Analysis Result 2: + The recall rate is currently 95.00. Further optimization of the rag retrieval algorithm and embedding model can be used to increase the semantic similarity between the generated answers and the standard answers, approaching or achieving a higher recall rate. For example, the algorithm can be continuously optimized to better match relevant snippets. + The precision is 99.00, close to the maximum score, indicating that the generated answers are semantically similar to the questions. However, further improvement is possible. This can be achieved by refining the embedding model to better understand the question semantics, optimizing the contextual completeness of the retrieved snippets, and reducing fluctuations in precision caused by insufficient context. + The faithfulness score is currently 90.00, indicating that some content in the generated answer is not fully derived from the retrieved snippet. The rag retrieval algorithm can be optimized to improve its recall rate. The text chunk size can also be adjusted appropriately to ensure that the retrieved snippet fully answers the question, thereby improving the faithfulness score. + Regarding interpretability, it is currently 85.00, indicating that the generated answer has room for improvement in terms of answering questions. On the one hand, the large model used can be optimized to improve its compliance, making the generated answer more accurate. 
On the other hand, the recall rates of the rag retrieval algorithm and embedding model can be further optimized to ensure that the retrieved snippet fully supports the answer and improve interpretability. + + Output: + Recall: Currently at 95.00, there is room for improvement. The vectorized retrieval algorithm can be optimized to further uncover information in the original snippet that is relevant to the question but not retrieved, as demonstrated in some specific practical cases within the openEuler ecosystem. Adjustments were made to the embedding model bge-m3 to enable it to more comprehensively and accurately capture semantics, expand the search scope, improve recall, and bring the generated answers closer to the standard answer. + Accuracy: The accuracy reached 99.00, which is already high. However, further optimization is needed to conduct deeper semantic analysis of the retrieved snippets. By combining the features of the large model qwen2.5-32b, this can precisely match the question semantics and avoid subtle semantic deviations. For example, this could more accurately demonstrate the specific characteristics of OpenEuler's high performance in cloud computing and edge computing. + Fidelity: The fidelity value was 90.00, indicating that some answer content was not fully derived from the retrieved snippet. The rag retrieval algorithm was optimized to improve the recall of the embedding model. Adjusting the text chunk size to 512 may be unreasonable and requires re-evaluation based on the content to ensure that the retrieved snippets contain sufficient context to support the answer, ensuring that the generated answer content is fully derived from the retrieved snippet. For example, relevant technical details regarding the development of the OpenEuler ecosystem should be obtained from the retrieved snippet. + Interpretability: The interpretability value was 85.00, which is relatively low. Improve the compliance of the large qwen2.5-32b model and optimize the recall of the rag retrieval algorithm and the embedding model bge-m3. This ensures that retrieval fragments can better support answer generation and clearly answer questions. For example, when answering questions related to OpenEuler, this improves answer logic, makes it more targeted, and improves overall interpretability. 
+ + The following two analysis results: + Analysis Result 1: {analysis_result_1} + Analysis Result 2: {analysis_result_2} + + 中文: | + 你是一个文本分析专家,你的任务融合两条分析结果输出一份新的分析结果。注意: + #01 请根据两条分析结果中的内容融合出一条新的分析结果 + #02 请结合召回率、精确度、忠实值和可解释性四个指标进行分析 + #03 新的分析结果长度不超过500字 + #04 请仅输出新的分析结果,不要输出其他内容 + 例子: + 输入1: + 分析结果1: + 召回率:目前召回率为 95.00,有提升空间。优化向量化检索算法,进一步挖掘原始片段中与问题相关但未被检索到的信息,如 openEuler 生态中一些具体实践案例等。调整 embedding 模型 bge-m3,使其能更全面准确地捕捉语义,扩大检索范围,提高召回率,使生成答案更接近标准答案。 + 精确度:精确度达 99.00,已较高。但可进一步优化,对检索到的片段进行更深入的语义分析,结合大模型 qwen2.5-32b 的特点,精准匹配问题语义,避免细微语义偏差,例如更精确阐述 openEuler 在云计算和边缘计算中高性能等特性的具体表现。 + 忠实值:忠实值为 90.00,说明部分答案内容未完全源于检索片段。优化 rag 检索算法,提高 embedding 模型召回率,调整文本分块大小为 512 可能存在不合理,需根据内容重新评估,确保检索片段包含足够上下文以支撑答案,使生成答案内容均来自检索片段,如关于 openEuler 生态建设中相关技术细节应从检索片段获取。 + 可解释性:可解释性为 85.00,相对较低。提升大模型 qwen2.5-32b 的遵从度,优化 rag 检索算法和 embedding 模型 bge-m3 的召回率,使检索片段能更好支撑生成答案,保证答案能清晰回答问题,例如在回答 openEuler 相关问题时,使答案逻辑更清晰、针对性更强,提高整体可解释性。 + + 分析结果2: + 从召回率来看,目前为 95.00,可进一步优化 rag 检索算法和 embedding 模型,以提高生成答案与标准回答之间的语义相似程度,接近或达到更高的召回率,例如可以持续优化算法来更好地匹配相关片段。 + 从精确度来看,为 99.00,接近满分,说明生成的答案与问题语义相似程度较高,但仍可进一步提升,可通过完善 embedding 模型来更好地理解问题语义,优化检索到的片段的上下文完整性,减少因上下文不足导致的精确度波动。 + 对于忠实值,目前为 90.00,说明生成的答案中部分内容未完全来自检索到的片段。可优化 rag 检索算法,提高其召回率,同时合理调整文本分块大小,确保检索到的片段能充分回答问题,从而提高忠实值。 + 关于可解释性,当前为 85.00,说明生成的答案在用于回答问题方面有一定提升空间。一方面可以优化使用的大模型,提高其遵从度,使其生成的答案更准确地回答问题;另一方面,继续优化 rag 检索算法和 embedding 模型的召回率,保证检索到的片段能全面支撑问题的回答,提高可解释性。 + + 输出: + 召回率:目前召回率为 95.00,有提升空间。优化向量化检索算法,进一步挖掘原始片段中与问题相关但未被检索到的信息,如 openEuler 生态中一些具体实践案例等。调整 embedding 模型 bge-m3,使其能更全面准确地捕捉语义,扩大检索范围,提高召回率,使生成答案更接近标准答案。 + 精确度:精确度达 99.00,已较高。但可进一步优化,对检索到的片段进行更深入的语义分析,结合大模型 qwen2.5-32b 的特点,精准匹配问题语义,避免细微语义偏差,例如更精确阐述 openEuler 在云计算和边缘计算中高性能等特性的具体表现。 + 忠实值:忠实值为 90.00,说明部分答案内容未完全源于检索片段。优化 rag 检索算法,提高 embedding 模型召回率,调整文本分块大小为 512 可能存在不合理,需根据内容重新评估,确保检索片段包含足够上下文以支撑答案,使生成答案内容均来自检索片段,如关于 openEuler 生态建设中相关技术细节应从检索片段获取。 + 可解释性:可解释性为 85.00,相对较低。提升大模型 qwen2.5-32b 的遵从度,优化 rag 检索算法和 embedding 模型 bge-m3 的召回率,使检索片段能更好支撑生成答案,保证答案能清晰回答问题,例如在回答 openEuler 相关问题时,使答案逻辑更清晰、针对性更强,提高整体可解释性。 + + 下面两条分析结果: + 分析结果1:{analysis_result_1} + 分析结果2:{analysis_result_2} + +ACC_RESULT_ANALYSIS_PROMPT: + en: | + You are a text analysis expert. Your task is to: analyze the large model used in the test, the embedding model used in the test, the parsing method and chunk size of related documents, the snippets matched by the RAG algorithm for a single test result, and propose methods to improve the accuracy of question-answering in the current knowledge base. + + The test results include the following information: + - Question: The question used in the test + - Standard answer: The standard answer used in the test + - Generated answer: The answer output by the large model in the test results + - Original snippet: The original snippet provided in the test results + - Retrieved snippet: The snippet retrieved by the RAG algorithm in the test results + + The four evaluation metrics are defined as follows: + - Precision: Evaluates the semantic similarity between the generated answer and the question. A lower score indicates lower compliance of the large model; additionally, it may mean the snippets retrieved by the RAG algorithm lack context and are insufficient to support the answer. + - Recall: Evaluates the semantic similarity between the generated answer and the standard answer. A lower score indicates lower compliance of the large model. + - Fidelity: Evaluates whether the content of the generated answer is derived from the retrieved snippet. 
A lower score indicates lower recall of the RAG retrieval algorithm and embedding model (resulting in retrieved snippets insufficient to answer the question); additionally, it may mean the text chunk size is inappropriate. + - Interpretability: Evaluates whether the generated answer is useful for answering the question. A lower score indicates lower recall of the RAG retrieval algorithm and embedding model (resulting in retrieved snippets insufficient to answer the question); additionally, it may mean lower compliance of the used large model. + + Notes: + #01 Analyze methods to improve the accuracy of current knowledge base question-answering based on the test results. + #02 Conduct the analysis using the four metrics: Recall, Precision, Fidelity, and Interpretability. + #03 The analysis result must not exceed 500 words. + #04 Output only the analysis result; do not include any other content. + + Example: + Input: + Model name: qwen2.5-32b + Embedding model: bge-m3 + Text chunk size: 512 + Used RAG algorithm: Vectorized retrieval + Question: What is OpenEuler? + Standard answer: OpenEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Generated answer: OpenEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Original snippet: openEuler is an open source operating system incubated and operated by the Open Atom Open Source Foundation. Its mission is to build an open source operating system ecosystem for digital infrastructure and provide solid underlying support for cutting-edge fields such as cloud computing and edge computing. In cloud computing scenarios, openEuler can fully optimize resource scheduling and allocation mechanisms. Through a lightweight kernel design and efficient virtualization technology, it significantly improves the responsiveness and throughput of cloud services. In edge computing, its exceptional low resource consumption and real-time processing capabilities ensure the timeliness and accuracy of data processing at edge nodes in complex environments. openEuler boasts a series of exceptional features: In terms of performance, its independently developed intelligent scheduling algorithm dynamically adapts to different load scenarios, and combined with deep optimization of hardware resources, significantly improves system efficiency. Regarding security, its built-in multi-layered security system, including mandatory access control, vulnerability scanning, and remediation mechanisms, provides a solid defense for system data and applications. Regarding reliability, its distributed storage, automatic fault detection, and rapid recovery technologies ensure stable system operation in the face of unexpected situations such as network fluctuations and hardware failures, minimizing the risk of service interruptions. These features make openEuler a crucial technological cornerstone for promoting high-quality development of the digital economy, helping enterprises and developers seize the initiative in digital transformation. + Retrieved snippet: As a pioneer in the open source operating system field, openEuler deeply integrates the wisdom of community developers and continuously iterates and upgrades to adapt to the rapidly changing technological environment. 
In the current era of prevalent microservices architectures, openEuler Through deep optimization of containerization technology and support for mainstream orchestration tools such as Kubernetes, it makes application deployment and management more convenient and efficient, significantly enhancing the flexibility of enterprise business deployments. At the same time, it actively embraces the AI era. By adapting and optimizing machine learning frameworks, it provides powerful computing power for AI model training and inference, effectively reducing the development and operating costs of AI applications. Regarding ecosystem development, openEuler boasts a large and active open source community, bringing together technology enthusiasts and industry experts from around the world, forming a complete ecosystem from kernel development and driver adaptation to application optimization. The community regularly hosts technical exchanges and developer conferences to promote knowledge sharing and technological innovation, providing developers with a wealth of learning resources and practical opportunities. Numerous hardware and software manufacturers have joined the openEuler ecosystem, launching solutions and products based on the system across key industries such as finance, telecommunications, and energy. These efforts, validated through real-world application scenarios and feeding back into openEuler's technological development, have fostered a virtuous cycle of innovation, making openEuler not just an operating system but a powerful engine driving collaborative industry development. + Recall: 95.00 + Precision: 99.00 + Fidelity: 90.00 + Interpretability: 85.00 + + Output: + Based on the test results, methods for improving the accuracy of current knowledge base question-answering can be analyzed from the following aspects: Recall: The current recall is 95.00, with room for improvement. Optimize the vectorized retrieval algorithm to further mine question-related but unretrieved information in the original snippets, such as some specific practical cases in the openEuler ecosystem. Adjust the embedding model bge-m3 to more comprehensively and accurately capture semantics, expand the search scope, improve recall, and make the generated answers closer to the standard answer. Precision: The accuracy reached 99.00, which is already high. However, further optimization is possible, including deeper semantic analysis of retrieved snippets. By combining the features of the large model qwen2.5-32b, this can accurately match the question semantics and avoid subtle semantic deviations. For example, more precise demonstration of openEuler's high performance in cloud computing and edge computing can be achieved. Fidelity: A fidelity score of 90.00 indicates that some answers are not fully derived from the search snippet. We optimized the rag retrieval algorithm, improved the recall of the embedding model, and adjusted the text chunk size to 512. This may be inappropriate and requires reassessment based on the content. We need to ensure that the search snippet contains sufficient context to support the answer, ensuring that the generated answer content is derived from the search snippet. For example, relevant technical details regarding the development of the openEuler ecosystem should be obtained from the search snippet. Interpretability: The interpretability score is 85.00, which is relatively low. 
We improved the compliance of the large model qwen2.5-32b and optimized the recall of the rag retrieval algorithm and the embedding model bge-m3. This ensures that the search snippet better supports answer generation and clearly answers the question. For example, when answering openEuler-related questions, the answer logic is made clearer and more targeted, improving overall interpretability. + + The following is the test result content: + Used large model: {model_name} + Embedding model: {embedding_model} + Text chunk size: {chunk_size} + Used RAG parsing algorithm: {rag_algorithm} + Question: {question} + Standard answer: {standard_answer} + Generated answer: {generated_answer} + Original fragment: {original_fragment} + Retrieved fragment: {retrieved_fragment} + Recall: {recall} + Precision: {precision} + Faithfulness: {faithfulness} + Interpretability: {relevance} + + 中文: | + 你是一个文本分析专家,你的任务是:根据给出的测试使用的大模型、embedding模型、测试相关文档的解析方法和分块大小、单条测试结果分析RAG算法匹配到的片段,并分析当前知识库问答准确率的提升方法。 + + 测试结果包含以下内容: + - 问题:测试使用的问题 + - 标准答案:测试使用的标准答案 + - 生成的答案:测试结果中大模型输出的答案 + - 原始片段:测试结果中的原始片段 + - 检索的片段:测试结果中RAG算法检索到的片段 + + 四个评估指标定义如下: + - 精确率:评估生成的答案与问题之间的语义相似程度。评分越低,说明使用的大模型遵从度越低;其次可能是RAG检索到的片段缺少上下文,不足以支撑问题的回答。 + - 召回率:评估生成的答案与标准回答之间的语义相似程度。评分越低,说明使用的大模型遵从度越低。 + - 忠实值:评估生成的答案中的内容是否来自于检索到的片段。评分越低,说明RAG检索算法和embedding模型的召回率越低(导致检索到的片段不足以回答问题);其次可能是文本分块大小不合理。 + - 可解释性:评估生成的答案是否能用于回答问题。评分越低,说明RAG检索算法和embedding模型的召回率越低(导致检索到的片段不足以回答问题);其次可能是使用的大模型遵从度越低。 + + 注意: + #01 请根据测试结果中的内容分析当前知识库问答准确率的提升方法。 + #02 请结合召回率、精确率、忠实值和可解释性四个指标进行分析。 + #03 分析结果长度不超过500字。 + #04 请仅输出分析结果,不要输出其他内容。 + + 例子: + 输入: + 模型名称:qwen2.5-32b + embedding模型:bge-m3 + 文本的分块大小:512 + 使用解析的RAG算法:向量化检索 + 问题:openEuler是什么操作系统? + 标准答案:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 生成的答案:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 原始片段:openEuler是由开放原子开源基金会孵化及运营的开源操作系统,以构建面向数字基础设施的开源操作系统生态为使命,致力于为云计算、边缘计算等前沿领域提供坚实的底层支持。在云计算场景中,openEuler能够充分优化资源调度与分配机制,通过轻量化的内核设计和高效的虚拟化技术,显著提升云服务的响应速度与吞吐量;在边缘计算领域,它凭借出色的低资源消耗特性与实时处理能力,保障了边缘节点在复杂环境下数据处理的及时性与准确性。openEuler具备一系列卓越特性:在性能方面,其自主研发的智能调度算法能够动态适配不同负载场景,结合对硬件资源的深度优化利用,大幅提升系统运行效率;安全性上,通过内置的多层次安全防护体系,包括强制访问控制、漏洞扫描与修复机制,为系统数据与应用程序构筑起坚实的安全防线;可靠性层面,基于分布式存储、故障自动检测与快速恢复技术,确保系统在面对网络波动、硬件故障等突发状况时,依然能够稳定运行,最大限度降低服务中断风险。这些特性使openEuler成为推动数字经济高质量发展的重要技术基石,助力企业与开发者在数字化转型进程中抢占先机。 + 检索的片段:openEuler作为开源操作系统领域的先锋力量,深度融合了社区开发者的智慧结晶,不断迭代升级以适应快速变化的技术环境。在微服务架构盛行的当下,openEuler通过对容器化技术的深度优化,支持Kubernetes等主流编排工具,让应用部署与管理变得更加便捷高效,极大提升了企业的业务部署灵活性。同时,它积极拥抱AI时代,通过对机器学习框架的适配与优化,为AI模型训练和推理提供强大的算力支撑,有效降低了AI应用的开发与运行成本。在生态建设方面,openEuler拥有庞大且活跃的开源社区,汇聚了来自全球的技术爱好者与行业专家,形成了从内核开发、驱动适配到应用优化的完整生态链。社区定期举办技术交流与开发者大会,推动知识共享与技术创新,为开发者提供了丰富的学习资源与实践机会。众多硬件厂商和软件企业纷纷加入openEuler生态,推出基于该系统的解决方案和产品,涵盖金融、电信、能源等关键行业,以实际应用场景验证并反哺openEuler的技术发展,形成了良性循环的创新生态,让openEuler不仅是一个操作系统,更成为推动产业协同发展的强大引擎。 + 召回率:95.00 + 精确率:99.00 + 忠实值:90.00 + 可解释性:85.00 + + 输出: + 根据测试结果中的内容,当前知识库问答准确率提升的方法可以从以下几个方面进行分析:召回率:目前召回率为95.00,有提升空间。优化向量化检索算法,进一步挖掘原始片段中与问题相关但未被检索到的信息,如openEuler生态中一些具体实践案例等。调整embedding模型bge-m3,使其能更全面准确地捕捉语义,扩大检索范围,提高召回率,使生成答案更接近标准答案。精确率:精确率达99.00,已较高。但可进一步优化,对检索到的片段进行更深入的语义分析,结合大模型qwen2.5-32b的特点,精准匹配问题语义,避免细微语义偏差,例如更精确阐述openEuler在云计算和边缘计算中高性能等特性的具体表现。忠实值:忠实值为90.00,说明部分答案内容未完全源于检索片段。优化RAG检索算法,提高embedding模型召回率,文本分块大小为512可能存在不合理,需根据内容重新评估,确保检索片段包含足够上下文以支撑答案,使生成答案内容均来自检索片段,如关于openEuler生态建设中相关技术细节应从检索片段获取。可解释性:可解释性为85.00,相对较低。提升大模型qwen2.5-32b的遵从度,优化RAG检索算法和embedding模型bge-m3的召回率,使检索片段能更好支撑生成答案,保证答案能清晰回答问题,例如在回答openEuler相关问题时,使答案逻辑更清晰、针对性更强,提高整体可解释性。 + + 下面是测试结果中的内容: + 使用的大模型:{model_name} + embedding模型:{embedding_model} + 文本的分块大小:{chunk_size} + 
使用解析的RAG算法:{rag_algorithm} + 问题:{question} + 标准答案:{standard_answer} + 生成的答案:{generated_answer} + 原始片段:{original_fragment} + 检索的片段:{retrieved_fragment} + 召回率:{recall} + 精确率:{precision} + 忠实值:{faithfulness} + 可解释性:{relevance} + +ANSWER_TO_ANSWER_PROMPT: + # 英文文本相似度评分提示词 + en: | + You are a text analysis expert. Your task is to compare the similarity between two documents and output a score between 0 and 100 with two decimal places. + + Note: + #01 Score based on text similarity in three dimensions: semantics, word order, and keywords. + #02 If the core expressions of the two documents are consistent, the score will be relatively high. + #03 If one document contains the core content of the other, the score will also be relatively high. + #04 If there is content overlap between the two documents, the score will be determined by the proportion of the overlap. + #05 Output only the score (no other content). + + Example 1: + Input - Text 1: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Text 2: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Output: 100.00 + + Example 2: + Input - Text 1: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Text 2: openEuler is an open-source operating system designed to support cloud computing and edge computing. It features high performance and high security. + Output: 90.00 + + Example 3: + Input - Text 1: openEuler is an open-source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Text 2: A white horse is not a horse + Output: 00.00 + + The following are the given texts: + Text 1: {text_1} + Text 2: {text_2} + + # 中文文本相似度评分提示词 + 中文: | + 你是一个文本分析专家,你的任务是对比两个文本之间的相似度,并输出一个 0-100 之间的分数(保留两位小数)。 + + 注意: + #01 请根据文本在语义、语序和关键字三个维度的相似度进行打分。 + #02 如果两个文本在核心表达上一致,那么分数将相对较高。 + #03 如果一个文本包含另一个文本的核心内容,那么分数也将相对较高。 + #04 如果两个文本间存在内容重合,那么将按照重合内容的比例确定分数。 + #05 仅输出分数,不要输出其他任何内容。 + + 例子 1: + 输入 - 文本 1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 文本 2:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 输出:100.00 + + 例子 2: + 输入 - 文本 1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 文本 2:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能和高安全性等特点。 + 输出:90.00 + + 例子 3: + 输入 - 文本 1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 文本 2:白马非马 + 输出:00.00 + + 下面是给出的文本: + 文本 1:{text_1} + 文本 2:{text_2} + +CAL_QA_SCORE_PROMPT: + en: >- + You are a text analysis expert. Your task is to evaluate the questions and answers generated from a given fragment, and assign a score between 0 and 100 (retaining two decimal places). Please evaluate based on the following criteria: + + ### 1. Question Evaluation + - **Relevance**: Is the question closely related to the topic of the given fragment? Is it accurately based on the fragment content? Does it deviate from or distort the core message of the fragment? + - **Plausibility**: Is the question formulated clearly and logically coherently? Does it conform to normal language and thinking habits? Is it free of semantic ambiguity, vagueness, or self-contradiction? 
+ - **Variety**: If there are multiple questions, are their angles and types sufficiently varied to avoid being overly monotonous or repetitive? Can they explore the fragment content from different perspectives? + - **Difficulty**: Is the question difficulty appropriate? Not too easy (where answers can be directly copied from the fragment), nor too difficult (where respondents cannot find clues or evidence from the fragment)? + + ### 2. Answer Evaluation + - **Accuracy**: Does the answer accurately address the question? Is it consistent with the information in the fragment? Does it contain errors or omit key points? + - **Completeness**: Is the answer comprehensive, covering all aspects of the question? For questions requiring elaboration, does it provide sufficient details and explanations? + - **Succinctness**: On the premise of ensuring completeness and accuracy, is the answer concise and clear? Does it avoid lengthy or redundant expressions, and convey key information in concise language? + - **Coherence**: Is the answer logically clear? Are transitions between content sections natural and smooth? Are there any jumps or confusion? + + ### 3. Overall Assessment + - **Consistency**: Do the question and answer match each other? Does the answer address the raised question? Are they consistent in content and logic? + - **Integration**: Does the answer effectively integrate information from the fragment? Is it not just a simple excerpt, but rather an integrated, refined presentation in a logical manner? + - **Innovation**: In some cases, evaluate whether the answer demonstrates innovation or unique insights? Does it appropriately expand or deepen the information in the fragment? + + ### Note + #01 Please output only the score (without any other content). + + ### Example + Input 1: + Question: What operating system is openEuler? + Answer: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Snippet: openEuler is an open source operating system designed to support cloud and edge computing. It features high performance, high security, and high reliability. + Output 1: 100.00 + + Below is the given question, answer, and snippet: + Question: {question} + Answer: {answer} + Snippet: {fragment} + 中文: >- + 你是文本分析专家,任务是评估由给定片段生成的问题与答案,输出 0-100 之间的分数(保留两位小数)。请根据以下标准进行评估: + + ### 1. 问题评估 + - **相关性**:问题是否与给定片段的主题紧密相关?是否准确基于片段内容提出?有无偏离或曲解片段的核心信息? + - **合理性**:问题表述是否清晰、逻辑连贯?是否符合正常的语言表达和思维习惯?不存在语义模糊、歧义或自相矛盾的情况? + - **多样性**:若存在多个问题,问题之间的角度和类型是否具有足够多样性(避免过于单一或重复)?能否从不同方面挖掘片段内容? + - **难度**:问题难度是否适中?既不过于简单(答案可直接从片段中照搬),也不过于困难(回答者难以从片段中找到线索或依据)? + + ### 2. 答案评估 + - **准确性**:答案是否准确无误地回答了问题?与片段中的信息是否一致?有无错误或遗漏关键要点? + - **完整性**:答案是否完整,涵盖问题涉及的各个方面?对于需要详细阐述的问题,是否提供了足够的细节和解释? + - **简洁性**:在保证回答完整、准确的前提下,答案是否简洁明了?是否避免冗长、啰嗦的表述,能否以简洁语言传达关键信息? + - **连贯性**:答案逻辑是否清晰?各部分内容之间的衔接是否自然流畅?有无跳跃或混乱的情况? + + ### 3. 整体评估 + - **一致性**:问题与答案之间是否相互匹配?答案是否针对所提出的问题进行回答?两者在内容和逻辑上是否保持一致? + - **融合性**:答案是否能很好地融合片段中的信息?是否并非简单摘抄,而是经过整合、提炼后以合理方式呈现? + - **创新性**:在某些情况下,评估答案是否具有一定创新性或独特见解?是否能在片段信息基础上进行适当拓展或深入思考? + + ### 注意事项 + #01 请仅输出分数,不要输出其他内容。 + + ### 示例 + 输入 1: + 问题:openEuler 是什么操作系统? + 答案:openEuler 是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 片段:openEuler 是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 输出 1:100.00 + + 下面是给出的问题、答案和片段: + 问题:{question} + 答案:{answer} + 片段:{fragment} + +CHUNK_QUERY_MATCH_PROMPT: + en: | + You are a text analysis expert. Your task is to determine whether a given fragment is relevant to a question. 
+ Note: + #01 If the fragment is relevant, output YES. + #02 If the fragment is not relevant, output NO. + #03 Only output YES or NO, and do not output anything else. + + Example: + Input 1: + Fragment: openEuler is an open source operating system. + Question: What kind of operating system is openEuler? + Output 1: YES + + Input 2: + Fragment: A white horse is not a horse. + Question: What kind of operating system is openEuler? + Output 2: NO + + Here are the given fragment and question: + Fragment: {chunk} + Question: {question} + 中文: | + 你是一个文本分析专家,你的任务是根据给出的片段和问题,判断片段是否与问题相关。 + 注意: + #01 如果片段与问题相关,请输出YES; + #02 如果片段与问题不相关,请输出NO; + #03 请仅输出YES或NO,不要输出其他内容。 + + 例子: + 输入1: + 片段:openEuler是一个开源的操作系统。 + 问题:openEuler是什么操作系统? + 输出1:YES + + 输入2: + 片段:白马非马 + 问题:openEuler是什么操作系统? + 输出2:NO + + 下面是给出的片段和问题: + 片段:{chunk} + 问题:{question} + +CONTENT_TO_ABSTRACT_PROMPT: + en: | + You are a text summarization expert. Your task is to generate a new English summary based on a given text and an existing summary. + Note: + #01 Please combine the most important content from the text and the existing summary to generate the new summary. + #02 The length of the new summary must be greater than 200 words and less than 500 words. + #03 Please only output the new English summary; do not output any other content. + + Example: + Input 1: + Text: openEuler features high performance, high security, and high reliability. + Abstract: openEuler is an open source operating system designed to support cloud computing and edge computing. + Output 1: openEuler is an open source operating system designed to support cloud computing and edge computing. openEuler features high performance, high security, and high reliability. + + Below is the given text and summary: + Text: {content} + Abstract: {abstract} + 中文: | + 你是一个文本摘要专家,你的任务是根据给出的文本和已有摘要,生成一个新的中文摘要。 + 注意: + #01 请结合文本和已有摘要中最重要的内容,生成新的摘要; + #02 新的摘要长度必须大于200字且小于500字; + #03 请仅输出新的中文摘要,不要输出其他内容。 + + 例子: + 输入1: + 文本:openEuler具有高性能、高安全性和高可靠性等特点。 + 摘要:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。 + 输出1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。openEuler具有高性能、高安全性和高可靠性等特点。 + + 下面是给出的文本和摘要: + 文本:{content} + 摘要:{abstract} + +CONTENT_TO_STATEMENTS_PROMPT: + en: | + You are a text parsing expert. Your task is to extract multiple English statements from a given text and return them as a list. + + Note: + #01 Statements must be derived from key points in the text. + #02 Statements must be arranged in relative order. + #03 Each statement must be at least 20 characters long and no more than 50 characters long. + #04 The total number of statements output must not exceed three. + #05 Please output only the list of statements, not any other content. Each statement must be in English. + Example: + + Input: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. 
+ Output: [ "openEuler is an open source operating system", "openEuler is designed to support cloud computing and edge computing", "openEuler features high performance, high security, and high reliability" ] + + The following is the given text: {content} + 中文: | + 你是一个文本分解专家,你的任务是根据我给出的文本,将文本提取为多个中文陈述,陈述使用列表形式返回 + + 注意: + #01 陈述必须来源于文本中的重点内容 + #02 陈述按相对顺序排列 + #03 输出的单个陈述长度不少于20个字,不超过50个字 + #04 输出的陈述总数不超过3个 + #05 请仅输出陈述列表,不要输出其他内容,且每一条陈述都是中文。 + 例子: + + 输入:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 输出:[ "openEuler是一个开源的操作系统", "openEuler旨在为云计算和边缘计算提供支持", "openEuler具有高性能、高安全性和高可靠性等特点" ] + + 下面是给出的文本: {content} + +CONTENT_TO_TITLE_PROMPT: + en: >- + You are a title extraction expert. Your task is to generate an English title based on the given text. + Note: + #01 The title must be derived from the content of the text. + #02 The title must be no longer than 20 characters. + #03 Please output only the English title, and do not output any other content. + #04 If the given text is insufficient to generate a title, output "Unable to generate title." + Example: + Input: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Output: Overview of the openEuler operating system. + Below is the given text: {content} + 中文: >- + 你是一个标题提取专家,你的任务是根据给出的文本生成一个中文标题。 + 注意: + #01 标题必须来源于文本中的内容 + #02 标题长度不超过20个字 + #03 请仅输出中文标题,不要输出其他内容 + #04 如果给出的文本不够生成标题,请输出“无法生成标题” + 例子: + 输入:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 输出:openEuler操作系统概述 + 下面是给出的文本:{content} + +GENERATE_ANSWER_FROM_QUESTION_AND_CONTENT_PROMPT: + en: | + You are a text analysis expert. Your task is to generate an English answer based on a given question and text. + Note: + #01 The answer must be derived from the content in the text. + #02 The answer must be at least 50 words and no more than 500 words. + #03 Please only output the English answer; do not output any other content. + Example: + Input 1: + Question: What kind of operating system is openEuler? + Text: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Output 1: openEuler is an open source operating system designed to support cloud computing and edge computing. + + Input 2: + Question: How secure is openEuler? + Text: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Output 2: openEuler is highly secure. + + Below is the given question and text: + Question: {question} + Text: {content} + 中文: | + 你是一个文本分析专家,你的任务是根据给出的问题和文本生成中文答案。 + 注意: + #01 答案必须来源于文本中的内容; + #02 答案长度不少于50字且不超过500个字; + #03 请仅输出中文答案,不要输出其他内容。 + 例子: + 输入1: + 问题:openEuler是什么操作系统? + 文本:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 输出1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。 + + 输入2: + 问题:openEuler的安全性如何? + 文本:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 输出2:openEuler具有高安全性。 + + 下面是给出的问题和文本: + 问题:{question} + 文本:{content} + +GENERATE_QUESTION_FROM_CONTENT_PROMPT: + en: | + You are a text analysis expert. Your task is to generate {k} English questions based on the given text and return them as a list. + Note: + #01 Questions must be derived from the content of the text. + #02 A single question must not exceed 50 characters. + #03 Do not output duplicate questions. 
+ #04 The output questions should be diverse, covering different aspects of the text. + #05 Please only output a list of English questions, not other content. + Example: + Input: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Output: ["What is openEuler?","What fields does openEuler support?","What are the characteristics of openEuler?","How secure is openEuler?","How reliable is openEuler?"] + The following is the given text: {content} + 中文: | + 你是一个文本分析专家,你的任务是根据给出的文本生成{k}个中文问题并用列表返回。 + 注意: + #01 问题必须来源于文本中的内容; + #02 单个问题长度不超过50个字; + #03 不要输出重复的问题; + #04 输出的问题要多样,覆盖文本中的不同方面; + #05 请仅输出中文问题列表,不要输出其他内容。 + 例子: + 输入:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 输出:["openEuler是什么操作系统?","openEuler旨在为哪个领域提供支持?","openEuler具有哪些特点?","openEuler的安全性如何?","openEuler的可靠性如何?"] + 下面是给出的文本:{content} + +OCR_ENHANCED_PROMPT: + en: | + You are an expert in image OCR content summarization. Your task is to describe the image based on the context I provide, descriptions of adjacent images, a summary of the previous OCR result for the current image, and the partial OCR results (including text and relative coordinates). + + Note: + #01 The image content must be described in detail, using at least 200 and no more than 500 words. Detailed data listing is acceptable. + #02 If this diagram is a flowchart, please describe the content in the order of the flowchart. + #03 If this diagram is a table, please output the table content in Markdown format. + #04 If this diagram is an architecture diagram, please describe the content according to the hierarchy of the architecture diagram. + #05 The summarized image description must include the key information in the image; it cannot simply describe the image's location. + #06 Adjacent text in the image recognition results may be part of the same paragraph. Please merge them before summarizing. + #07 The text may be misplaced. Please correct the order before summarizing. + #08 Please only output the image summary; do not output any other content. + #09 Do not output coordinates or other information; only output a description of the relative position of each part. + #10 If the image content is empty, output "Image content is empty." + #11 If the image itself is a paragraph of text, output the text content directly. + #12 Please use English for the output. + Context: {image_related_text} + Summary of the OCR content of the previous part of the current image: {pre_part_description} + Result of the OCR of the current part of the image: {part} + 中文: | + 你是一个图片OCR内容总结专家,你的任务是根据我提供的上下文、相邻图片组描述、当前图片上一次的OCR内容总结、当前图片部分OCR的结果(包含文字和文字的相对坐标)给出图片描述。 + + 注意: + #01 必须使用大于200字小于500字详细描述这个图片的内容,可以详细列出数据。 + #02 如果这个图是流程图,请按照流程图顺序描述内容。 + #03 如果这张图是表格,请用Markdown形式输出表格内容。 + #04 如果这张图是架构图,请按照架构图层次结构描述内容。 + #05 总结的图片描述必须包含图片中的主要信息,不能只描述图片位置。 + #06 图片识别结果中相邻的文字可能是同一段落的内容,请合并后总结。 + #07 文字可能存在错位,请修正顺序后进行总结。 + #08 请仅输出图片的总结即可,不要输出其他内容。 + #09 不要输出坐标等信息,输出每个部分相对位置的描述即可。 + #10 如果图片内容为空,请输出“图片内容为空”。 + #11 如果图片本身就是一段文字,请直接输出文字内容。 + #12 请使用中文输出。 + 上下文:{image_related_text} + 当前图片上一部分的OCR内容总结:{pre_part_description} + 当前图片部分OCR的结果:{part} + +QA_TO_STATEMENTS_PROMPT: + en: | + You are a text parsing expert. Your task is to extract the answers from the questions and answers I provide into multiple English statements, returning them as a list. + + Note: + #01 The statements must be derived from the key points of the answers. + #02 The statements must be arranged in relative order. 
+ #03 The length of each statement output must not exceed 50 characters. + #04 The total number of statements output must not exceed 20. + #05 Please only output the list of English statements; do not output any other content. + + Example: + Input: Question: What is openEuler? Answer: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Output: [ "openEuler is an open source operating system", "openEuler is designed to support cloud computing and edge computing", "openEuler features high performance, high security, and high reliability" ] + + Below are the given questions and answers: + Question: {question} + Answer: {answer} + 中文: | + 你是一个文本分解专家,你的任务是根据我给出的问题和答案,将答案提取为多个中文陈述,陈述使用列表形式返回。 + + 注意: + #01 陈述必须来源于答案中的重点内容 + #02 陈述按相对顺序排列 + #03 输出的单个陈述长度不超过50个字 + #04 输出的陈述总数不超过20个 + #05 请仅输出中文陈述列表,不要输出其他内容 + + 例子: + 输入:问题:openEuler是什么操作系统? 答案:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 输出:[ "openEuler是一个开源的操作系统", "openEuler旨在为云计算和边缘计算提供支持", "openEuler具有高性能、高安全性和高可靠性等特点" ] + + 下面是给出的问题和答案: + 问题:{question} + 答案:{answer} + +QUERY_EXTEND_PROMPT: + en: | + You are a question expansion expert. Your task is to expand {k} questions based on the given question. + + Note: + #01 The content of the expanded question must be derived from the content of the original question. + #02 The expanded question length must not exceed 50 characters. + #03 Questions can be rewritten by replacing synonyms, swapping word order within the question, changing English capitalization, etc. + #04 Please only output the expanded question list, do not output other content. + + Example: + Input: What operating system is openEuler? + Output: [ "What kind of operating system is openEuler?", "What are the characteristics of the openEuler operating system?", "What are the functions of the openEuler operating system?", "What are the advantages of the openEuler operating system?" ] + + The following is the given question: {question} + 中文: | + 你是一个问题扩写专家,你的任务是根据给出的问题扩写{k}个问题。 + + 注意: + #01 扩写的问题的内容必须来源于原问题中的内容 + #02 扩写的问题长度不超过50个字 + #03 可以通过近义词替换、问题内词序交换、修改英文大小写等方式来改写问题 + #04 请仅输出扩写的问题列表,不要输出其他内容 + + 例子: + 输入:openEuler是什么操作系统? + 输出:[ "openEuler是一个什么样的操作系统?", "openEuler操作系统的特点是什么?", "openEuler操作系统有哪些功能?", "openEuler操作系统的优势是什么?" ] + + 下面是给出的问题:{question} + +STATEMENTS_TO_FRAGMENT_PROMPT: + en: | + You are a text expert. Your task is to determine whether a given statement is strongly related to the fragment. + + Note: + #01 If the statement is strongly related to the fragment or is derived from the fragment, output YES. + #02 If the content in the statement is unrelated to the fragment, output NO. + #03 If the statement is a refinement of a portion of the fragment, output YES. + #05 Only output YES or NO, and do not output anything else. + + Example: + Input 1: + Statement: openEuler is an open source operating system. + Fragment: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. + Output 1: YES + + Input 2: + Statement: A white horse is not a horse. + Fragment: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability. 
+ Output 2: NO + + Below is a given statement and fragment: + Statement: {statement} + Fragment: {fragment} + 中文: | + 你是一个文本专家,你的任务是判断给出的陈述是否与片段强相关。 + + 注意: + #01 如果陈述与片段强相关或者来自于片段,请输出YES + #02 如果陈述中的内容与片段无关,请输出NO + #03 如果陈述是片段中某部分的提炼,请输出YES + #05 请仅输出YES或NO,不要输出其他内容 + + 例子: + 输入1: + 陈述:openEuler是一个开源的操作系统。 + 片段:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 输出1:YES + + 输入2: + 陈述:白马非马 + 片段:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。 + 输出2:NO + + 下面是给出的陈述和片段: + 陈述:{statement} + 片段:{fragment} + +STATEMENTS_TO_QUESTION_PROMPT: + en: | + You are a text analysis expert. Your task is to determine whether a given statement is relevant to a question. + + Note: + #01 If the statement is relevant to the question, output YES. + #02 If the statement is not relevant to the question, output NO. + #03 Only output YES or NO, and do not output anything else. + #04 A statement's relevance to the question means that the content in the statement can answer the question or overlaps with the question in terms of content. + + Example: + Input 1: + Statement: openEuler is an open source operating system. + Question: What kind of operating system is openEuler? + Output 1: YES + + Input 2: + Statement: A white horse is not a horse. + Question: What kind of operating system is openEuler? + Output 2: NO + + Below is the given statement and question: + Statement: {statement} + Question: {question} + 中文: | + 你是一个文本分析专家,你的任务是判断给出的陈述是否与问题相关。 + + 注意: + #01 如果陈述与问题相关,请输出YES + #02 如果陈述与问题不相关,请输出NO + #03 请仅输出YES或NO,不要输出其他内容 + #04 陈述与问题相关是指,陈述中的内容可以回答问题或者与问题在内容上有交集 + + 例子: + 输入1: + 陈述:openEuler是一个开源的操作系统。 + 问题:openEuler是什么操作系统? + 输出1:YES + + 输入2: + 陈述:白马非马 + 问题:openEuler是什么操作系统? + 输出2:NO + + 下面是给出的陈述和问题: + 陈述:{statement} + 问题:{question} diff --git a/test/requirements.txt b/test/requirements.txt deleted file mode 100644 index e4b18f8b56f026c78097a3ca6e28fa6671cc9393..0000000000000000000000000000000000000000 --- a/test/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -jieba==0.42.1 -pandas==2.1.4 -pydantic==2.10.2 -langchain==0.1.16 -langchain-openai==0.1.7 -synonyms==3.23.5 \ No newline at end of file diff --git a/test/requiremnets.txt b/test/requiremnets.txt new file mode 100644 index 0000000000000000000000000000000000000000..e32be7d73bf0c41be259e29e5b441978602ba691 --- /dev/null +++ b/test/requiremnets.txt @@ -0,0 +1,11 @@ +toml==0.10.2 +pydantic==2.11.7 +urllib3==2.2.1 +requests==2.32.2 +langchain==0.3.7 +langchain-core==0.3.56 +langchain-openai==0.2.5 +tiktoken==0.9.0 +jieba==0.42.1 +numpy==1.26.4 +jieba==0.42.1 \ No newline at end of file diff --git a/test/result.xlsx b/test/result.xlsx new file mode 100644 index 0000000000000000000000000000000000000000..4a85ab395df84ef5f5511eb15477e0b2005f5557 Binary files /dev/null and b/test/result.xlsx differ diff --git a/test/tools/stopwords.txt b/test/stopwords.txt similarity index 99% rename from test/tools/stopwords.txt rename to test/stopwords.txt index 5784b4462a67442a7301abb939b8ca17fa791598..bfb5f302afa87935686501368c011a0a99de855e 100644 --- a/test/tools/stopwords.txt +++ b/test/stopwords.txt @@ -1276,7 +1276,6 @@ indeed 第三句 更 看上去 -安全 零 也好 上去 @@ -3702,7 +3701,6 @@ sup 它们的 它是 它的 -安全 完全 完成 定 diff --git a/test/test.py b/test/test.py new file mode 100644 index 0000000000000000000000000000000000000000..cde43529da7dac29d5ed7f6b838f4b36c22c8ec4 --- /dev/null +++ b/test/test.py @@ -0,0 +1,88 @@ +import argparse +import asyncio +import pandas as pd + +from token_tool import TokenTool +from pydantic import BaseModel, Field + + +class TestEntity(BaseModel): + 
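+    # Holds one evaluation record: the QA pair, the model's answer, the retrieved context,
+    # and the metric fields (pre/rec/fai/rel/lcs/leve/jac) that evaluate_metrics() fills in via TokenTool.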
"""测试实体模型""" + question: str = Field(default="", description="问题") + answer: str = Field(default="", description="答案") + chunk: str = Field(default="", description="上下文片段") + llm_answer: str = Field(default="", description="大模型答案") + related_chunk: str = Field(default="", description="相关上下文片段") + pre: float = Field(default=0.0, description="准确率") + rec: float = Field(default=0.0, description="召回率") + fai: float = Field(default=0.0, description="可信度") + rel: float = Field(default=0.0, description="相关性") + lcs: float = Field(default=0.0, description="最长公共子串") + leve: float = Field(default=0.0, description="编辑距离") + jac: float = Field(default=0.0, description="Jaccard相似度") + + +async def read_data_from_file(input_xlsx_file: str) -> list[TestEntity]: + """从文件读取测试数据""" + df = pd.read_excel(input_xlsx_file) + data = [] + for _, row in df.iterrows(): + entity = TestEntity( + question=row.get('question', ''), + answer=row.get('answer', ''), + chunk=row.get('chunk', ''), + llm_answer=row.get('llm_answer', ''), + related_chunk=row.get('related_chunk', '') + ) + data.append(entity) + return data + + +async def write_data_to_file(output_xlsx_file: str, data: list[TestEntity]) -> None: + """ + 将测试数据写入文件 + 第一个sheet写入平均分,第二个sheet写入详细数据 + """ + average_data = { + 'pre': sum(item.pre for item in data) / len(data) if data else 0, + 'rec': sum(item.rec for item in data) / len(data) if data else 0, + 'fai': sum(item.fai for item in data) / len(data) if data else 0, + 'rel': sum(item.rel for item in data) / len(data) if data else 0, + 'lcs': sum(item.lcs for item in data) / len(data) if data else 0, + 'leve': sum(item.leve for item in data) / len(data) if data else 0, + 'jac': sum(item.jac for item in data) / len(data) if data else 0, + } + average_df = pd.DataFrame([average_data]) + detailed_df = pd.DataFrame([item.model_dump() for item in data]) + with pd.ExcelWriter(output_xlsx_file) as writer: + average_df.to_excel(writer, sheet_name='average', index=False) + detailed_df.to_excel(writer, sheet_name='detailed', index=False) + + +async def evaluate_metrics(data: list[TestEntity], language: str) -> None: + """评估测试数据的各项指标""" + token_tool = TokenTool() + for item in data: + item.pre = await token_tool.cal_precision(item.question, item.llm_answer, language) + item.rec = await token_tool.cal_recall(item.question, item.related_chunk, language) + item.fai = await token_tool.cal_faithfulness(item.question, item.llm_answer, item.related_chunk, language) + item.rel = await token_tool.cal_relevance(item.question, item.llm_answer, language) + item.lcs = token_tool.cal_lcs(item.answer, item.llm_answer) + item.leve = token_tool.cal_leve(item.answer, item.llm_answer) + item.jac = token_tool.cal_jac(item.answer, item.llm_answer) + print(f"评估完成: 问题: {item.question}, 准确率: {item.pre}, 召回率: {item.rec}, 可信度: {item.fai}, 相关性: {item.rel}, 最长公共子串: {item.lcs}, 编辑距离: {item.leve}, Jaccard相似度: {item.jac}") + + +def work(input_xlsx_file: str, output_xlsx_file: str, language: str) -> None: + data = asyncio.run(read_data_from_file(input_xlsx_file)) + asyncio.run(evaluate_metrics(data, language)) + asyncio.run(write_data_to_file(output_xlsx_file, data)) + + +if __name__ == '__main__': + args = argparse.ArgumentParser() + args.add_argument('--input_xlsx_file', type=str, required=True, help='输入xlsx文件路径') + args.add_argument('--output_xlsx_file', type=str, required=True, help='输出xlsx文件路径') + args.add_argument('--language', type=str, default='中文', help='语言类型,默认中文zh,英文en') + parsed_args = args.parse_args() + work(parsed_args.input_xlsx_file, 
parsed_args.output_xlsx_file, parsed_args.language) diff --git a/test/test.xlsx b/test/test.xlsx new file mode 100644 index 0000000000000000000000000000000000000000..38396ac66b526a7647465e7468a14097a09b9586 Binary files /dev/null and b/test/test.xlsx differ diff --git a/test/test_qa.py b/test/test_qa.py deleted file mode 100644 index ab3408d5aa1b1ee5aa49c16b2c5509ec332a732a..0000000000000000000000000000000000000000 --- a/test/test_qa.py +++ /dev/null @@ -1,719 +0,0 @@ -import subprocess -import argparse -import asyncio -import json -import os -import random -import time -from pathlib import Path -import jieba -import pandas as pd - -import yaml -import requests -from typing import Optional, List -from pydantic import BaseModel, Field -from tools.config import config -from tools.llm import LLM -from tools.similar_cal_tool import Similar_cal_tool -current_dir = Path(__file__).resolve().parent - - -def login_and_get_tokens(account, password, base_url): - """ - 尝试登录并获取新的session ID和CSRF token。 - - :param login_url: 登录的URL地址 - :param account: 用户账号 - :param password: 用户密码 - :return: 包含新session ID和CSRF token的字典,或者在失败时返回None - """ - # 构造请求头部 - headers = { - 'Content-Type': 'application/x-www-form-urlencoded', - } - - # 构造请求数据 - params = { - 'account': account, - 'password': password - } - # 发送POST请求 - url = f"{base_url}/user/login" - response = requests.get(url, headers=headers, params=params) - # 检查响应状态码是否为200表示成功 - if response.status_code == 200: - # 如果登录成功,获取新的session ID和CSRF token - new_session = response.cookies.get("WD_ECSESSION") - new_csrf_token = response.cookies.get("wd_csrf_tk") - if new_session and new_csrf_token: - return response.json(), { - 'ECSESSION': new_session, - 'csrf_token': new_csrf_token - } - else: - print("Failed to get new session or CSRF token.") - return None - else: - print(f"Failed to login, status code: {response.status_code}") - return None - - -def tokenize(text): - return len(list(jieba.cut(str(text)))) - - -class DictionaryBaseModel(BaseModel): - pass - - -class ListChunkRequest(DictionaryBaseModel): - document_id: str - text: Optional[str] = None - page_number: int = 1 - page_size: int = 50 - type: Optional[list[str]] = None - - -def list_chunks(session_cookie: str, csrf_cookie: str, document_id: str, - text: Optional[str] = None, page_number: int = 1, page_size: int = 50, - base_url="http://0.0.0.0:9910") -> dict: - """ - 请求文档块列表的函数。 - - :param session_cookie: 用户会话cookie - :param csrf_cookie: CSRF保护cookie - :param document_id: 文档ID - :param text: 可选的搜索文本 - :param page_number: 页码,默认为1 - :param page_size: 每页大小,默认为10 - :param base_url: API基础URL,默认为本地测试服务器地址 - :return: JSON响应数据 - """ - # 构造请求cookies - # print(document_id) - cookies = { - "WD_ECSESSION": session_cookie, - "wd_csrf_tk": csrf_cookie - } - - # 创建请求体实例 - payload = ListChunkRequest( - document_id=document_id, - text=text, - page_number=page_number, - page_size=page_size, - ).dict() - - # 发送POST请求 - url = f"{base_url}/chunk/list" - response = requests.post(url, cookies=cookies, json=payload) - - # 一次性获取所有chunk - # print(response.json()) - page_size = response.json()['data']['total'] - - # 创建请求体实例 - payload = ListChunkRequest( - document_id=document_id, - text=text, - page_number=page_number, - page_size=page_size, - ).dict() - - # 发送POST请求 - url = f"{base_url}/chunk/list" - response = requests.post(url, cookies=cookies, json=payload) - - # 返回JSON响应数据 - return response.json() - - -def parser(): - # 创建 ArgumentParser 对象 - parser = argparse.ArgumentParser(description="Script to process document and generate 
QA pairs.") - subparser = parser.add_subparsers(dest='mode', required=True, help='Mode of operation') - - # 离线模式参数 - offline = subparser.add_parser('offline', help='Offline mode for processing documents') # noqa: F841 - offline.add_argument("-i", "--input_path", required=True, default="./document", help="Path of document names",) - # 在线模式所需添加的参数 - online = subparser.add_parser('online', help='Online mode for processing documents') - online.add_argument('-n', '--name', type=str, required=True, help='User name') - online.add_argument('-p', '--password', type=str, required=True, help='User password') - online.add_argument('-k', '--kb_id', type=str, required=True, help='KnowledgeBase ID') - online.add_argument('-u', '--url', type=str, required=True, help='URL for witChainD') - - # 添加可选参数,并设置默认值 - online.add_argument('-q', '--qa_count', type=int, default=1, - help='Number of QA pairs to generate per text block (default: 1)') - - # 添加文件名列表参数 - online.add_argument('-d', '--doc_names', nargs='+', required=False, default=[], help='List of document names') - - # 解析命令行参数 - args = parser.parse_args() - return args - - -def get_prompt_dict(): - """ - 获取prompt表 - """ - try: - with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: - prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) - return prompt_dict - except Exception as e: - print(f"open {config['PROMPT_PATH']} error {e}") - raise e - - -prompt_dict = get_prompt_dict() -llm = LLM(model_name=config['MODEL_NAME'], - openai_api_base=config['OPENAI_API_BASE'], - openai_api_key=config['OPENAI_API_KEY'], - max_tokens=config['MAX_TOKENS'], - request_timeout=60, - temperature=0.35) - -def get_random_number(l, r): - return random.randint(l, r-1) - - -class QAgenerator: - - async def qa_generate(self, chunks, file): - """ - 多线程生成问答对 - """ - start_time = time.time() - results = [] - prev_texts = [] - ans = 0 - # 使用 asyncio.gather 来并行处理每个 chunk - tasks = [] - # 获取 chunks 的长度 - num_chunks = len(chunks) - image_sum = 0 - for chunk in chunks: - chunk['count'] = 0 - # if chunk['type'] == 'image': - # chunk['count'] = chunk['count'] + 1 - # image_sum = image_sum + 1 - for i in range(args.qa_count): - x = get_random_number(min(3, num_chunks-1), num_chunks) - print(x) - chunks[x]['count'] = chunks[x]['count'] + 1 - - now_text = "" - for chunk in chunks: - now_text = now_text + chunk['text'] + '\n' - # if chunk['type'] == 'table' and len(now_text) < (config['MAX_TOKENS'] // 8): - # continue - prev_text = '\n'.join(prev_texts) - while tokenize(prev_text) > (config['MAX_TOKENS'] / 4): - prev_texts.pop(0) - prev_text = '\n'.join(prev_texts) - if chunk['count'] > 0: - tasks.append(self.generate(now_text, prev_text, results, file, chunk['count'], chunk['type'])) - prev_texts.append(now_text) - now_text = '' - ans = ans + chunk['count'] + image_sum - - # 等待所有任务完成 - await asyncio.gather(*tasks) - print('问答对案例:', results[:50]) - print("问答对生成总计用时:", time.time() - start_time) - print(f"总计生成{ans}条问答对") - return results - - async def generate(self, now_text, prev_text, results, file, qa_count, type_text): - """ - 生成问答 - """ - prev_text = prev_text[-(config['MAX_TOKENS'] // 8):] - prompt = prompt_dict.get('GENERATE_QA') - count = 0 - while count < 5: - try: - # 使用多线程处理 chat_with_llm 调用 - result_temp = await self.chat_with_llm(llm, prompt, now_text, prev_text, - qa_count, file) - - for result in result_temp: - result['text'] = prev_text + now_text - result['type_text'] = type_text - results.append(result) - count = 5 - except Exception as e: - count += 1 - print('error:', e, 
'retry times', count) - if count == 5: - results.append({'text': now_text, 'question': '无法生成问答对', - 'answer': '无法生成问答对', 'type': 'error', 'type_text': 'error'}) - - @staticmethod - async def chat_with_llm(llm, prompt, text, prev_text, qa_count, file_name) -> dict: - """ - 对于给定的文本,通过llm生成问题-答案-段落对。 - params: - - llm: LLm - - text: str - - prompt: str - return: - - qa_pairs: list[dict] - - """ - text.replace("\"", "\\\"") - user_call = (f"文本内容来自于{file_name},请以JSON格式输出{qa_count}对不同的问题-答案-领域,格式为[" - "{" - "\"question\": \" 问题 \", " - "\"answer\": \" 回答 \"," - "\"type\": \" 领域 \"" - "}\n" - "],并且必须将问题和回答中和未被转义的双引号转义,元素标签请用双引号括起来") - prompt = prompt.format(chunk=text, qa_count=qa_count, text=prev_text, file_name=file_name) - # print(prompt) - qa_pair = await llm.nostream([], prompt, user_call) - # 提取问题、答案段落对的list,字符串格式为["问题","答案","段落对"] - print(qa_pair) - # print("原文:", text) - qa_pair = json.loads(qa_pair) - return qa_pair - - -class QueryRequest(BaseModel): - question: str - kb_sn: Optional[str] = None - top_k: int = Field(5, ge=0, le=10) - fetch_source: bool = False - history: Optional[List] = [] - - -def call_get_answer(text, kb_id, session_cookie, csrf_cookie, base_url="http://0.0.0.0:9910"): - # 构造请求cookies - cookies = { - "WD_ECSESSION": session_cookie, - "wd_csrf_tk": csrf_cookie - } - - # 构造请求体 - req = QueryRequest( - question=text, - kb_sn=kb_id, - top_k=3, - fetch_source=True, - history=[] - ) - - url = f"{base_url}/kb/get_answer" - print(url) - headers = { - "Content-Type": "application/json", - "Accept": "application/json" - } - data = req.json().encode("utf-8") - - for i in range(5): - try: - response = requests.post(url, headers=headers, cookies=cookies, data=data) - - if response.status_code == 200: - result = response.json() - # print("成功获取答案") - return result - print(f"请求失败,状态码: {response.status_code}, 响应内容: {response.text}") - time.sleep(1) - except Exception as e: - print(f"请求answer失败,错误原因{e}, 重试次数:{i+1}") - time.sleep(1) - - -async def get_answers(QA, kb_id, session_cookie, csrf_cookie, base_url): - text = QA['question'] - print(f"原文:{QA['text'][:40]}...") - result = call_get_answer(text, kb_id, session_cookie, csrf_cookie, base_url) - if result is None: - return None - else: - QA['witChainD_answer'] = result['data']['answer'] - QA['witChainD_source'] = result['data']['source'] - QA['time_cost']=result['data']['time_cost'] - print(f"原文:{QA['text'][:40] + '...'}\n问题:{text}\n回答:{result['data']['answer'][:40]}\n\n") - return QA - - -async def get_QAs_answers(QAs, kb_id, session_cookie, csrf_cookie, base_url): - results = [] - tasks = [] - for QA in QAs: - tasks.append(get_answers(QA, kb_id, session_cookie, csrf_cookie, base_url)) - response = await asyncio.gather(*tasks) - for idx, result in enumerate(response): - if result is not None: - results.append(result) - return results - - -class QAScore(): - async def get_score(self, QA): - prompt = prompt_dict['SCORE_QA'] - llm_score_dict = await self.chat_with_llm(llm, prompt, QA['question'], QA['text'], QA['witChainD_source'], QA['answer'], QA['witChainD_answer']) - print(llm_score_dict) - - QA['context_relevancy'] = llm_score_dict['context_relevancy'] - QA['context_recall'] = llm_score_dict['context_recall'] - QA['faithfulness'] = llm_score_dict['faithfulness'] - QA['answer_relevancy'] = llm_score_dict['answer_relevancy'] - print(QA) - try: - lcs_score = Similar_cal_tool.longest_common_subsequence(QA['answer'], QA['witChainD_answer']) - except: - lcs_score = 0 - QA['lcs_score'] = lcs_score - try: - jac_score = 
Similar_cal_tool.jaccard_distance(QA['answer'], QA['witChainD_answer']) - except: - jac_score = 0 - QA['jac_score'] = jac_score - try: - leve_score = Similar_cal_tool.levenshtein_distance(QA['answer'], QA['witChainD_answer']) - except: - leve_score = 0 - QA['leve_score'] = leve_score - return QA - - async def get_scores(self, QAs): - tasks = [] - results = [] - for QA in QAs: - tasks.append(self.get_score(QA)) - response = await asyncio.gather(*tasks) - for idx, result in enumerate(response): - if result is not None: - results.append(result) - return results - - @staticmethod - async def chat_with_llm(llm, prompt, question, meta_chunk, chunk, answer, answer_text) -> dict: - """ - 对于给定的文本,通过llm生成问题-答案-段落对。 - params: - - llm: LLm - - text: str - - prompt: str - return: - - qa_pairs: list[dict] - - """ - required_metrics = { - "context_relevancy", - "context_recall", - "faithfulness", - "answer_relevancy", - } - for i in range(5): - try: - user_call = """请对答案打分,并以下面形式返回结果{ - \"context_relevancy\": 分数, - \"context_recall\": 分数, - \"faithfulness\": 分数, - \"answer_relevancy\": 分数 -} -注意:属性名必须使用双引号,分数为数字,保留两位小数。""" - prompt = prompt.format(question=question, meta_chunk=meta_chunk, - chunk=chunk, answer=answer, answer_text=answer_text) - # print(prompt) - score_dict = await llm.nostream([], prompt, user_call) - st = score_dict.find('{') - en = score_dict.rfind('}') - if st != -1 and en != -1: - score_dict = score_dict[st:en+1] - # print(score_dict) - score_dict = json.loads(score_dict) - # 提取问题、答案段落对的list,字符串格式为["问题","答案","段落对"] - # print(score) - present_metrics = set(score_dict.keys()) - missing_metrics = required_metrics - present_metrics - if missing_metrics: - missing = ", ".join(missing_metrics) - print(f"评分结果缺少必要指标: {missing}") - for metric in required_metrics: - if metric not in score_dict: - score_dict[metric] = 0.00 - print(score_dict) - return score_dict - except Exception as e: - continue - return { - "context_relevancy": 0, - "context_recall": 0, - "faithfulness": 0, - "answer_relevancy": 0, - } - - -def list_documents(session_cookie, csrf_cookie, kb_id, base_url="http://0.0.0.0:9910"): - # 构造请求cookies - cookies = { - "WD_ECSESSION": session_cookie, - "wd_csrf_tk": csrf_cookie - } - - # 构造请求URL - url = f"{base_url}/doc/list" - - # 构造请求体 - payload = { - "kb_id": str(kb_id), # 将uuid对象转换为字符串 - "page_number": 1, - "page_size": 50, - } - - # 发送POST请求 - response = requests.post(url, cookies=cookies, json=payload) - # print(response.text) - - # 一次性获取所有document - total = response.json()['data']['total'] - documents = [] - for i in range(1, (total + 50) // 50 + 1): - # 创建请求体实例 - print(f"page {i} gets") - payload = { - "kb_id": str(kb_id), # 将uuid对象转换为字符串 - "page_number": i, - "page_size": 50, - } - - response = requests.post(url, cookies=cookies, json=payload) - js = response.json() - now_documents = js['data']['data_list'] - documents.extend(now_documents) - # 返回响应文本 - return documents - -def get_document(dir): - documents = [] - print(os.listdir(dir)) - for file in os.listdir(dir): - if file.endswith('.xlsx'): - file_path = os.path.join(dir, file) - df = pd.read_excel(file_path) - documents.append(df.to_dict(orient='records')) - if file.endswith('.csv'): - file_path = os.path.join(dir, file) - df = pd.read_csv(file_path, ) - documents.append(df.to_dict(orient='records')) - return documents - -if __name__ == '__main__': - """ - 脚本参数包含 name, password, doc_id, qa_count, url - - name: 通过-n或者--name读入,必须 - - password: 通过-p或者--password读入,必须 - - kb_id: 通过-k或者--kb_id读入,必须 - - qa_count: 
通过-q或者--qa_count读入,非必须,默认为1,表示每个文档生成多少个问答对 - - url: 通过-u或者--url读入,必须,为witChainD的路径 - - doc_names: 通过-d或者--doc_names读入,非必须,默认为None,表示所有文档的名称 - 需要在.env中配置好LLM和witChainD相关的config,以及prompt路径 - """ - args = parser() - QAs = [] - if args.mode == 'online': - js, tmp_dict = login_and_get_tokens(args.name, args.password, args.url) - session_cookie = tmp_dict['ECSESSION'] - csrf_cookie = tmp_dict['csrf_token'] - print('login success') - documents = list_documents(session_cookie, csrf_cookie, args.kb_id, args.url) - print('get document success') - print(documents) - for document in documents: - # print('refresh tokens') - # print(json.dumps(document, indent=4, ensure_ascii=False)) - if args.doc_names != [] and document['name'] not in args.doc_names: - # args.doc_names = [] - continue - else: - args.doc_names = [] - js, tmp_dict = login_and_get_tokens(args.name, args.password, args.url) - session_cookie = tmp_dict['ECSESSION'] - csrf_cookie = tmp_dict['csrf_token'] - args.doc_id = document['id'] - args.doc_name = document['name'] - count = 0 - while count < 5: - try: - js = list_chunks(session_cookie, csrf_cookie, str(args.doc_id), base_url=args.url) - print(f'js: {js}') - count = 10 - except Exception as e: - print(f"document {args.doc_name} check failed {e} with retry {count}") - count = count + 1 - time.sleep(1) - continue - if count == 5: - print(f"document {args.doc_name} check failed") - continue - chunks = js['data']['data_list'] - new_chunks = [] - for chunk in chunks: - new_chunk = { - 'text': chunk['text'], - 'type': chunk['type'], - } - new_chunks.append(new_chunk) - chunks = new_chunks - model = QAgenerator() - try: - print('正在生成QA对...') - t_QAs = asyncio.run(model.qa_generate(chunks=chunks, file=args.doc_name)) - print("QA对生成完毕,正在获取答案...") - tt_QAs = asyncio.run(get_QAs_answers(t_QAs, args.kb_id, session_cookie, csrf_cookie, args.url)) - print(f"tt_QAs: {tt_QAs}") - print("答案获取完毕,正在计算答案正确性...") - ttt_QAs = asyncio.run(QAScore().get_scores(tt_QAs)) - print(f"ttt_QAs: {ttt_QAs}") - for QA in t_QAs: - QAs.append(QA) - df = pd.DataFrame(QAs) - df.astype(str) - print(document['name'], 'down') - print('sample:', t_QAs[0]['question'][:40]) - df.to_excel(current_dir / 'temp_answer.xlsx', index=False) - print(f'temp_Excel结果已输出到{current_dir}/temp_answer.xlsx') - except Exception as e: - import traceback - print(traceback.print_exc()) - print(f"document {args.doc_name} failed {e}") - continue - else: - # 离线模式 - # print(document_path) - t_QAs = get_document(args.input_path) - print(f"获取到{len(t_QAs)}个文档") - for item in t_QAs[0]: - single_item = { - "question": item["问题"], - "answer": item["标准答案"], - "witChainD_answer": item["llm的回答"], - "text": item["原始片段"], - "witChainD_source": item["检索片段"], - } - # print(single_item) - ttt_QAs = asyncio.run(QAScore().get_score(single_item)) - QAs.append(ttt_QAs) - # # 输出QAs到xlsx中 - # exit(0) - newQAs = [] - total = { - "context_relevancy(上下文相关性)": [], - "context_recall(召回率)": [], - "faithfulness(忠实性)": [], - "answer_relevancy(答案的相关性)": [], - "lcs_score(最大公共子串)": [], - "jac_score(杰卡德距离)": [], - "leve_score(编辑距离)": [], - "time_cost": { - "keyword_searching": [], - "text_to_vector": [], - "vector_searching": [], - "vectors_related_texts": [], - "text_expanding": [], - "llm_answer": [], - }, - } - - time_cost_metrics = list(total["time_cost"].keys()) - - for QA in QAs: - print(QA) - try: - if 'time_cost' in QA.keys(): - ReOrderedQA = { - '领域': str(QA['type']), - '问题': str(QA['question']), - '标准答案': str(QA['answer']), - 'llm的回答': str(QA['witChainD_answer']), - 
'context_relevancy(上下文相关性)': str(QA['context_relevancy']), - 'context_recall(召回率)': str(QA['context_recall']), - 'faithfulness(忠实性)': str(QA['faithfulness']), - 'answer_relevancy(答案的相关性)': str(QA['answer_relevancy']), - 'lcs_score(最大公共子串)': str(QA['lcs_score']), - 'jac_score(杰卡德距离)': str(QA['jac_score']), - 'leve_score(编辑距离)': str(QA['leve_score']), - '原始片段': str(QA['text']), - '检索片段': str(QA['witChainD_source']), - 'keyword_searching_cost(关键字搜索时间消耗)': str(QA['time_cost']['keyword_searching'])+'s', - 'query_to_vector_cost(qeury向量化时间消耗)': str(QA['time_cost']['text_to_vector'])+'s', - 'vector_searching_cost(向量化检索时间消耗)': str(QA['time_cost']['vector_searching'])+'s', - 'vectors_related_texts_cost(向量关联文档时间消耗)': str(QA['time_cost']['vectors_related_texts'])+'s', - 'text_expanding_cost(上下文关联时间消耗)': str(QA['time_cost']['text_expanding'])+'s', - 'llm_answer_cost(大模型回答时间消耗)': str(QA['time_cost']['llm_answer'])+'s' - } - else: - ReOrderedQA = { - # '领域': str(QA['type']), - '问题': str(QA['question']), - '标准答案': str(QA['answer']), - 'llm的回答': str(QA['witChainD_answer']), - 'context_relevancy(上下文相关性)': str(QA['context_relevancy']), - 'context_recall(召回率)': str(QA['context_recall']), - 'faithfulness(忠实性)': str(QA['faithfulness']), - 'answer_relevancy(答案的相关性)': str(QA['answer_relevancy']), - 'lcs_score(最大公共子串)': str(QA['lcs_score']), - 'jac_score(杰卡德距离)': str(QA['jac_score']), - 'leve_score(编辑距离)': str(QA['leve_score']), - '原始片段': str(QA['text']), - '检索片段': str(QA['witChainD_source']) - } - print(ReOrderedQA) - newQAs.append(ReOrderedQA) - - for metric in total.keys(): - if metric != "time_cost": # 跳过time_cost(特殊处理) - value = ReOrderedQA.get(metric) - if value is not None: - total[metric].append(float(value)) - - if "time_cost" in QA: - for sub_metric in time_cost_metrics: - value = QA["time_cost"].get(sub_metric) - if value is not None: - total["time_cost"][sub_metric].append(float(value)) - except Exception as e: - print(f"QA {QA} error {e}") - - # 计算平均值 - avg = {} - for metric, values in total.items(): - if metric != "time_cost": - avg[metric] = sum(values) / len(values) if values else 0.0 - else: # 处理time_cost - avg_time_cost = {} - for sub_metric, sub_values in values.items(): - avg_time_cost[sub_metric] = ( - sum(sub_values) / len(sub_values) if sub_values else 0.0 - ) - avg[metric] = avg_time_cost - - - excel_path = current_dir / 'answer.xlsx' - with pd.ExcelWriter(excel_path, engine='xlsxwriter') as writer: - # 写入第一个sheet(测试样例) - df = pd.DataFrame(newQAs).astype(str) - df.to_excel(writer, sheet_name="测试样例", index=False) - - # 写入第二个sheet(测试结果) - filtered_time_cost = {k: v for k, v in avg["time_cost"].items() if v != 0} - flat_avg = { - **{k: v for k, v in avg.items() if k != "time_cost"}, - **{f"time_cost_{k}": v for k, v in filtered_time_cost.items()}, - } - print(f"写入测试结果:{flat_avg}") - avg_df = pd.DataFrame([flat_avg]) - avg_df.to_excel(writer, sheet_name="测试结果", index=False) - - - print(f'测试样例和结果已输出到{excel_path}') diff --git a/test/token_tool.py b/test/token_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..36e067bcb1e1ebefbbd18fa3e5711f76f2b71d13 --- /dev/null +++ b/test/token_tool.py @@ -0,0 +1,506 @@ +import asyncio +import tiktoken +import jieba +from jieba.analyse import extract_tags +import yaml +import json +import re +import uuid +import numpy as np +from pydantic import BaseModel, Field +from llm import LLM +from embedding import Embedding +from config import BaseConfig + + +class Grade(BaseModel): + content_len: int = Field(..., description="内容长度") + tokens: 
int = Field(..., description="token数") + + +class TokenTool: + stop_words_path = "./stopwords.txt" + prompt_path = "./prompt.yaml" + with open(stop_words_path, 'r', encoding='utf-8') as f: + stopwords = set(line.strip() for line in f) + with open(prompt_path, 'r', encoding='utf-8') as f: + prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) + + @staticmethod + def filter_stopwords(content: str) -> str: + """ + 过滤停用词 + """ + try: + words = TokenTool.split_words(content) + filtered_words = [word for word in words if word not in TokenTool.stopwords] + return ' '.join(filtered_words) + except Exception as e: + err = f"[TokenTool] 过滤停用词失败 {e}" + print("[TokenTool] %s", err) + return content + + @staticmethod + def get_leave_tokens_from_content_len(content: str) -> int: + """ + 根据内容长度获取留存的token数 + """ + grades = [ + Grade(content_len=0, tokens=0), + Grade(content_len=10, tokens=8), + Grade(content_len=50, tokens=16), + Grade(content_len=250, tokens=32), + Grade(content_len=1250, tokens=64), + Grade(content_len=6250, tokens=128), + Grade(content_len=31250, tokens=256), + Grade(content_len=156250, tokens=512), + Grade(content_len=781250, tokens=1024), + ] + tokens = TokenTool.get_tokens(content) + if tokens >= grades[-1].tokens: + return 1024 + index = 0 + for i in range(len(grades)-1): + if grades[i].content_len <= tokens < grades[i+1].content_len: + index = i + break + leave_tokens = grades[index].tokens+(grades[index+1].tokens-grades[index].tokens)*( + tokens-grades[index].content_len)/(grades[index+1].content_len-grades[index].content_len) + return int(leave_tokens) + + @staticmethod + def get_leave_setences_from_content_len(content: str) -> int: + """ + 根据内容长度获取留存的句子数量 + """ + grades = [ + Grade(content_len=0, tokens=0), + Grade(content_len=10, tokens=4), + Grade(content_len=50, tokens=8), + Grade(content_len=250, tokens=16), + Grade(content_len=1250, tokens=32), + Grade(content_len=6250, tokens=64), + Grade(content_len=31250, tokens=128), + Grade(content_len=156250, tokens=256), + Grade(content_len=781250, tokens=512), + ] + sentences = TokenTool.content_to_sentences(content) + if len(sentences) >= grades[-1].tokens: + return 1024 + index = 0 + for i in range(len(grades)-1): + if grades[i].content_len <= len(sentences) < grades[i+1].content_len: + index = i + break + leave_sentences = grades[index].tokens+(grades[index+1].tokens-grades[index].tokens)*( + len(sentences)-grades[index].content_len)/(grades[index+1].content_len-grades[index].content_len) + return int(leave_sentences) + + @staticmethod + def get_tokens(content: str) -> int: + try: + enc = tiktoken.encoding_for_model("gpt-4") + return len(enc.encode(str(content))) + except Exception as e: + err = f"[TokenTool] 获取token失败 {e}" + print("[TokenTool] %s", err) + return 0 + + @staticmethod + def get_k_tokens_words_from_content(content: str, k: int = 16) -> list: + try: + if (TokenTool.get_tokens(content) <= k): + return content + l = 0 + r = len(content) + while l+1 < r: + mid = (l+r)//2 + if (TokenTool.get_tokens(content[:mid]) <= k): + l = mid + else: + r = mid + return content[:l] + except Exception as e: + err = f"[TokenTool] 获取k个token的词失败 {e}" + print("[TokenTool] %s", err) + return "" + + @staticmethod + def split_str_with_slide_window(content: str, slide_window_size: int) -> list: + """ + 将字符串按滑动窗口切割 + """ + result = [] + try: + while len(content) > 0: + sub_content = TokenTool.get_k_tokens_words_from_content(content, slide_window_size) + result.append(sub_content) + content = content[len(sub_content):] + return result + except 
+
+    @staticmethod
+    def split_str_with_slide_window(content: str, slide_window_size: int) -> list:
+        """
+        将字符串按滑动窗口切割
+        """
+        result = []
+        try:
+            while len(content) > 0:
+                sub_content = TokenTool.get_k_tokens_words_from_content(content, slide_window_size)
+                result.append(sub_content)
+                content = content[len(sub_content):]
+            return result
+        except Exception as e:
+            err = f"[TokenTool] 滑动窗口切割失败 {e}"
+            print("[TokenTool] %s", err)
+            return []
+
+    @staticmethod
+    def compress_tokens(content: str, k: int = None) -> str:
+        try:
+            words = TokenTool.split_words(content)
+            # 过滤掉停用词
+            filtered_words = [
+                word for word in words if word not in TokenTool.stopwords
+            ]
+            filtered_content = ''.join(filtered_words)
+            if k is not None:
+                # 如果k不为None,则获取k个token的词
+                filtered_content = TokenTool.get_k_tokens_words_from_content(filtered_content, k)
+            return filtered_content
+        except Exception as e:
+            err = f"[TokenTool] 压缩token失败 {e}"
+            print("[TokenTool] %s", err)
+            return content
+
+    @staticmethod
+    def split_words(content: str) -> list:
+        try:
+            return list(jieba.cut(str(content)))
+        except Exception as e:
+            err = f"[TokenTool] 分词失败 {e}"
+            print("[TokenTool] %s", err)
+            return []
+
+    @staticmethod
+    def get_top_k_keywords(content: str, k=10) -> list:
+        try:
+            # 使用jieba提取关键词
+            keywords = extract_tags(content, topK=k, withWeight=True)
+            return [keyword for keyword, weight in keywords]
+        except Exception as e:
+            err = f"[TokenTool] 获取关键词失败 {e}"
+            print("[TokenTool] %s", err)
+            return []
+
+    @staticmethod
+    def get_top_k_keywords_and_weights(content: str, k=10) -> list:
+        try:
+            # 使用jieba提取关键词
+            keyword_weight_list = extract_tags(content, topK=k, withWeight=True)
+            keywords = [keyword for keyword, weight in keyword_weight_list]
+            weights = [weight for keyword, weight in keyword_weight_list]
+            return keywords, weights
+        except Exception as e:
+            err = f"[TokenTool] 获取关键词失败 {e}"
+            print("[TokenTool] %s", err)
+            return [], []
+
+    @staticmethod
+    def get_top_k_keysentence(content: str, k: int = None) -> list:
+        """
+        获取前k个关键句子
+        """
+        if k is None:
+            k = TokenTool.get_leave_setences_from_content_len(content)
+        leave_tokens = TokenTool.get_leave_tokens_from_content_len(content)
+        words = TokenTool.split_words(content)
+        # 过滤掉停用词
+        filtered_words = [
+            word for word in words if word not in TokenTool.stopwords
+        ]
+        keywords = TokenTool.get_top_k_keywords(''.join(filtered_words), leave_tokens)
+        keywords = set(keywords)
+        sentences = TokenTool.content_to_sentences(content)
+        sentence_and_score_list = []
+        index = 0
+        for sentence in sentences:
+            score = 0
+            words = TokenTool.split_words(sentence)
+            for word in words:
+                if word in keywords:
+                    score += 1
+            sentence_and_score_list.append((index, sentence, score))
+            index += 1
+        # 按得分降序取前k个句子
+        sentence_and_score_list.sort(key=lambda x: x[2], reverse=True)
+        top_k_sentence_and_score_list = sentence_and_score_list[:k]
+        top_k_sentence_and_score_list.sort(key=lambda x: x[0])
+        return [sentence for index, sentence, score in top_k_sentence_and_score_list]
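+
+    # 用法示意(编辑补充的示例注释,入参内容为假设):get_top_k_keysentence 先用 jieba 关键词给句子打分,
+    # 再按原文顺序返回得分最高的 k 个句子;其中 content_to_sentences 假定为本模块其他位置提供的分句函数。
+    #   TokenTool.get_top_k_keywords("openEuler 是一个开源操作系统……", k=5)   # -> 关键词列表
+    #   TokenTool.get_top_k_keysentence(document_text, k=3)                  # -> 3 个关键句,顺序与原文一致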
+
+    @staticmethod
+    async def cal_recall(answer_1: str, answer_2: str, language: str) -> float:
+        """
+        计算recall
+        参数:
+        answer_1:答案1
+        answer_2:答案2
+        llm:大模型
+        """
+        llm = LLM(
+            openai_api_base=BaseConfig().get_config().llm.llm_endpoint,
+            openai_api_key=BaseConfig().get_config().llm.llm_api_key,
+            model_name=BaseConfig().get_config().llm.llm_model_name,
+            max_tokens=BaseConfig().get_config().llm.max_tokens,
+            temperature=BaseConfig().get_config().llm.temperature
+        )
+        try:
+            prompt_template = TokenTool.prompt_dict.get('ANSWER_TO_ANSWER_PROMPT', {})
+            prompt_template = prompt_template.get(language, '')
+            answer_1 = TokenTool.get_k_tokens_words_from_content(answer_1, llm.max_tokens//2)
+            answer_2 = TokenTool.get_k_tokens_words_from_content(answer_2, llm.max_tokens//2)
+            prompt = prompt_template.format(text_1=answer_1, text_2=answer_2)
+            sys_call = prompt
+            user_call = '请输出相似度'
+            similarity = await llm.nostream([], sys_call, user_call)
+            return eval(similarity)
+        except Exception as e:
+            err = f"[TokenTool] 计算recall失败 {e}"
+            print("[TokenTool] %s", err)
+            return -1
+
+    @staticmethod
+    async def cal_precision(question: str, content: str, language: str) -> float:
+        """
+        计算precision
+        参数:
+        question:问题
+        content:内容
+        """
+        llm = LLM(
+            openai_api_base=BaseConfig().get_config().llm.llm_endpoint,
+            openai_api_key=BaseConfig().get_config().llm.llm_api_key,
+            model_name=BaseConfig().get_config().llm.llm_model_name,
+            max_tokens=BaseConfig().get_config().llm.max_tokens,
+            temperature=BaseConfig().get_config().llm.temperature
+        )
+        try:
+            prompt_template = TokenTool.prompt_dict.get('CONTENT_TO_STATEMENTS_PROMPT', {})
+            prompt_template = prompt_template.get(language, '')
+            content = TokenTool.compress_tokens(content, llm.max_tokens)
+            sys_call = prompt_template.format(content=content)
+            user_call = '请结合文本输出陈述列表'
+            statements = await llm.nostream([], sys_call, user_call, st_str='[',
+                                            en_str=']')
+            statements = json.loads(statements)
+            if len(statements) == 0:
+                return 0
+            score = 0
+            prompt_template = TokenTool.prompt_dict.get('STATEMENTS_TO_QUESTION_PROMPT', {})
+            prompt_template = prompt_template.get(language, '')
+            for statement in statements:
+                statement = TokenTool.get_k_tokens_words_from_content(statement, llm.max_tokens)
+                prompt = prompt_template.format(statement=statement, question=question)
+                sys_call = prompt
+                user_call = '请结合文本输出YES或NO'
+                yn = await llm.nostream([], sys_call, user_call)
+                yn = yn.lower()
+                if yn == 'yes':
+                    score += 1
+            return score/len(statements)*100
+        except Exception as e:
+            err = f"[TokenTool] 计算precision失败 {e}"
+            print("[TokenTool] %s", err)
+            return -1
+
+    @staticmethod
+    async def cal_faithfulness(question: str, answer: str, content: str, language: str) -> float:
+        """
+        计算faithfulness
+        参数:
+        question:问题
+        answer:答案
+        """
+        llm = LLM(
+            openai_api_base=BaseConfig().get_config().llm.llm_endpoint,
+            openai_api_key=BaseConfig().get_config().llm.llm_api_key,
+            model_name=BaseConfig().get_config().llm.llm_model_name,
+            max_tokens=BaseConfig().get_config().llm.max_tokens,
+            temperature=BaseConfig().get_config().llm.temperature
+        )
+        try:
+            prompt_template = TokenTool.prompt_dict.get('QA_TO_STATEMENTS_PROMPT', {})
+            prompt_template = prompt_template.get(language, '')
+            question = TokenTool.get_k_tokens_words_from_content(question, llm.max_tokens//8)
+            answer = TokenTool.get_k_tokens_words_from_content(answer, llm.max_tokens//8*7)
+            prompt = prompt_template.format(question=question, answer=answer)
+            sys_call = prompt
+            user_call = '请结合问题和答案输出陈述'
+            statements = await llm.nostream([], sys_call, user_call, st_str='[',
+                                            en_str=']')
+            prompt_template = TokenTool.prompt_dict.get('STATEMENTS_TO_FRAGMENT_PROMPT', {})
+            prompt_template = prompt_template.get(language, '')
+            statements = json.loads(statements)
+            if len(statements) == 0:
+                return 0
+            score = 0
+            content = TokenTool.compress_tokens(content, llm.max_tokens//8*7)
+            for statement in statements:
+                statement = TokenTool.get_k_tokens_words_from_content(statement, llm.max_tokens//8)
+                prompt = prompt_template.format(statement=statement, fragment=content)
+                sys_call = prompt
+                user_call = '请输出YES或NO'
+                yn = await llm.nostream([], sys_call, user_call)
+                yn = yn.lower()
+                if yn == 'yes':
+                    score += 1
+            return score/len(statements)*100
+        except Exception as e:
+            err = f"[TokenTool] 计算faithfulness失败 {e}"
+            print("[TokenTool] %s", err)
+            return -1
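+
+    # 用法示意(编辑补充的示例注释,调用方式与语言键均为假设):cal_recall / cal_precision / cal_faithfulness
+    # 是基于 LLM 评审的异步指标,依赖本目录的 llm.LLM、config.BaseConfig 以及 prompt.yaml 中对应的提示词。
+    #   import asyncio
+    #   score = asyncio.run(TokenTool.cal_faithfulness(question, answer, retrieved_text, 'zh'))
+    #   # 预期返回 0-100 的分数,异常时返回 -1;'zh' 需与 prompt.yaml 中实际的语言键一致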
+
+    @staticmethod
+    def cosine_distance_numpy(vector1, vector2):
+        # 计算向量的点积
+        dot_product = np.dot(vector1, vector2)
+        # 计算向量的 L2 范数
+        norm_vector1 = np.linalg.norm(vector1)
+        norm_vector2 = np.linalg.norm(vector2)
+        # 计算余弦相似度
+        cosine_similarity = dot_product / (norm_vector1 * norm_vector2)
+        # 计算余弦距离
+        cosine_dist = 1 - cosine_similarity
+        return cosine_dist
+
+    @staticmethod
+    async def cal_relevance(question: str, answer: str, language: str) -> float:
+        """
+        计算relevance
+        参数:
+        question:问题
+        answer:答案
+        """
+        llm = LLM(
+            openai_api_base=BaseConfig().get_config().llm.llm_endpoint,
+            openai_api_key=BaseConfig().get_config().llm.llm_api_key,
+            model_name=BaseConfig().get_config().llm.llm_model_name,
+            max_tokens=BaseConfig().get_config().llm.max_tokens,
+            temperature=BaseConfig().get_config().llm.temperature
+        )
+        try:
+            prompt_template = TokenTool.prompt_dict.get('GENERATE_QUESTION_FROM_CONTENT_PROMPT', {})
+            prompt_template = prompt_template.get(language, '')
+            answer = TokenTool.get_k_tokens_words_from_content(answer, llm.max_tokens)
+            sys_call = prompt_template.format(k=5, content=answer)
+            user_call = '请结合文本输出问题列表'
+            question_vector = await Embedding.vectorize_embedding(question)
+            qs = await llm.nostream([], sys_call, user_call)
+            qs = json.loads(qs)
+            if len(qs) == 0:
+                return 0
+            score = 0
+            for q in qs:
+                q_vector = await Embedding.vectorize_embedding(q)
+                # 余弦相似度 = 1 - 余弦距离,累加相似度后再映射到 [0, 100]
+                score += 1 - TokenTool.cosine_distance_numpy(question_vector, q_vector)
+            return (score/len(qs)+1)/2*100
+        except Exception as e:
+            err = f"[TokenTool] 计算relevance失败 {e}"
+            print("[TokenTool] %s", err)
+            return -1
+
+    @staticmethod
+    def cal_lcs(str1: str, str2: str) -> float:
+        """
+        计算两个字符串的最长公共子序列长度得分
+        """
+        try:
+            words1 = TokenTool.split_words(str1)
+            words2 = TokenTool.split_words(str2)
+            new_words1 = []
+            new_words2 = []
+            for word in words1:
+                if word not in TokenTool.stopwords:
+                    new_words1.append(word)
+            for word in words2:
+                if word not in TokenTool.stopwords:
+                    new_words2.append(word)
+            if len(new_words1) == 0 and len(new_words2) == 0:
+                return 100
+            if len(new_words1) == 0 or len(new_words2) == 0:
+                return 0
+            m = len(new_words1)
+            n = len(new_words2)
+            dp = np.zeros((m+1, n+1))
+            for i in range(1, m+1):
+                for j in range(1, n+1):
+                    if new_words1[i-1] == new_words2[j-1]:
+                        dp[i][j] = dp[i-1][j-1] + 1
+                    else:
+                        dp[i][j] = max(dp[i-1][j], dp[i][j-1])
+            lcs_length = dp[m][n]
+            score = lcs_length / min(len(new_words1), len(new_words2)) * 100
+            return score
+        except Exception as e:
+            err = f"[TokenTool] 计算lcs失败 {e}"
+            print("[TokenTool] %s", err)
+            return -1
+
+    @staticmethod
+    def cal_leve(str1: str, str2: str) -> float:
+        """
+        计算两个字符串的编辑距离
+        """
+        try:
+            words1 = TokenTool.split_words(str1)
+            words2 = TokenTool.split_words(str2)
+            new_words1 = []
+            new_words2 = []
+            for word in words1:
+                if word not in TokenTool.stopwords:
+                    new_words1.append(word)
+            for word in words2:
+                if word not in TokenTool.stopwords:
+                    new_words2.append(word)
+            if len(new_words1) == 0 and len(new_words2) == 0:
+                return 100
+            if len(new_words1) == 0 or len(new_words2) == 0:
+                return 0
+            m = len(new_words1)
+            n = len(new_words2)
+            dp = np.zeros((m+1, n+1))
+            for i in range(m+1):
+                dp[i][0] = i
+            for j in range(n+1):
+                dp[0][j] = j
+            for i in range(1, m+1):
+                for j in range(1, n+1):
+                    if new_words1[i-1] == new_words2[j-1]:
+                        dp[i][j] = dp[i-1][j-1]
+                    else:
+                        dp[i][j] = min(dp[i-1][j]+1, dp[i][j-1]+1, dp[i-1][j-1]+1)
+            edit_distance = dp[m][n]
+            score = (1 - edit_distance / max(len(new_words1), len(new_words2))) * 100
+            return score
+        except Exception as e:
+            err = f"[TokenTool] 计算leve失败 {e}"
+            print("[TokenTool] %s", err)
+            return -1
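+
+    # 用法示意(编辑补充的示例注释,分值为按实现推算的示例,实际结果取决于分词与停用词表):
+    # cal_lcs / cal_leve / cal_jac 均先 jieba 分词并去停用词,再分别按最长公共子序列、
+    # 编辑距离、Jaccard 计算 0-100 的相似度分数,例如两段完全相同的文本:
+    #   TokenTool.cal_lcs("开源操作系统", "开源操作系统")    # -> 100.0
+    #   TokenTool.cal_leve("开源操作系统", "开源操作系统")   # -> 100.0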
f"[TokenTool] 计算leve失败 {e}" + print("[TokenTool] %s", err) + return -1 + + @staticmethod + def cal_jac(str1: str, str2: str) -> float: + """ + 计算两个字符串的Jaccard相似度 + """ + try: + if len(str1) == 0 and len(str2) == 0: + return 100 + words1 = TokenTool.split_words(str1) + words2 = TokenTool.split_words(str2) + new_words1 = [] + new_words2 = [] + for word in words1: + if word not in TokenTool.stopwords: + new_words1.append(word) + for word in words2: + if word not in TokenTool.stopwords: + new_words2.append(word) + if len(new_words1) == 0 or len(new_words2) == 0: + return 0 + set1 = set(new_words1) + set2 = set(new_words2) + intersection = len(set1.intersection(set2)) + union = len(set1.union(set2)) + score = intersection / union * 100 + return score + except Exception as e: + err = f"[TokenTool] 计算jac失败 {e}" + print("[TokenTool] %s", err) + return -1 diff --git a/test/tools/=1.21.6, b/test/tools/=1.21.6, deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/test/tools/config.py b/test/tools/config.py deleted file mode 100644 index e677904723fd6896f8497af2aa6f543e6564b5a8..0000000000000000000000000000000000000000 --- a/test/tools/config.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -config = { - "PROMPT_PATH": "./tools/prompt.yaml", - "MODEL_NAME": "your_model_name", # Replace with your actual model name - "OPENAI_API_BASE": "your_openai_api_base_url", - "OPENAI_API_KEY": "your_openai_api_key", - "REQUEST_TIMEOUT": 120, - "MAX_TOKENS": 8096, - "MODEL_ENH": "false", -} diff --git a/test/tools/prompt.yaml b/test/tools/prompt.yaml deleted file mode 100644 index 7a3031b9bc28e7554fc21667dbcf19a5991f3d7d..0000000000000000000000000000000000000000 --- a/test/tools/prompt.yaml +++ /dev/null @@ -1,82 +0,0 @@ -GENERATE_QA: "你是一个问答生成专家,你的任务是根据我提供的段落内容和已有的问题,生成{qa_count}个不重复的针对该段落内容的问题与回答, -并判断这个问答对的属于领域,并只输出问题、回答、领域。 - -注意: - -1. 单个回答长度必须大于30字小于120字 - -2. 问题不能出现重复 - -3. 请指定明确的场景,如'xx公司', 'xx系统', 'xx项目', ‘xx软件'等 - -4. 问题中不要使用模糊的指代词, 如'这'、'那' - -5. 划分领域的时候请忽略上下文内容,领域大概可以分为(建筑、园林、摄影、戏剧、戏曲、舞蹈、音乐、书法、绘画、雕塑、美食、营养、健身、运动、旅游、地理、气象、海洋、地质、生态、天文、化学、物理、生物、数学、统计、逻辑、人工智能、大数据、云计算、网络、通信、自动化、机械、电子、材料、能源、化工、纺织、服装、美容、美发、礼仪、公关、广告、营销、管理、金融、证券、保险、期货、税务、审计、会计、法律实务、知识产权) - -6. 问题必须与段落内容有逻辑关系 - -7. 问题与回答在不重复的前提下,应当尽可能多地包含段落内容 - -8. 输出的格式为: -[ - -{{ - \"question\": \" 问题 \", - \"answer\": \" 回答 \", - \"type\": \" 领域 \" -}} - -, - -{{ - \"question\": \" 问题 \", - \"answer\": \" 回答 \", - \"type\": \" 领域 \" -}} - -] - -10. 
-10. 不要输出多余内容
-
-下面是给出的段落内容:
-
-{chunk}
-
-下面是段落的上下文内容:
-
-{text}
-
-下面是段落的来源文件
-{file_name}
-"
-SCORE_QA: "你是一个打分专家,你的任务是根据我提供的问题、原始片段和检索到的片段以及标准答案和答案,判断答案在下面四项指标的分数,每个指标要精确到小数点后面2位,且每次需要进行客观评价
-
-1.context_relevancy 解释:(上下文相关性,越高表示检索到的片段中无用的信息越少 0-100)
-2.context_recall 解释:(召回率,越高表示检索出来的片段与标准答案越相关 0-100)
-3.faithfulness 解释:(忠实性,越高表示答案的生成使用了越多检索出来的片段0-100)
-4.answer_relevancy 解释:(答案与问题的相关性 0-100)
-
-注意:
-请以下面格式输出
-{{
-\"context_relevancy\": 分数,
-\"context_recall\": 分数,
-\"faithfulness\": 分数,
-\"answer_relevancy\": 分数
-}}
-
-下面是问题:
-{question}
-
-下面是原始片段:
-{meta_chunk}
-
-下面是检索到的片段:
-{chunk}
-
-下面是标准答案:
-{answer}
-
-下面是答案:
-{answer_text}
-"
diff --git a/test/tools/similar_cal_tool.py b/test/tools/similar_cal_tool.py
deleted file mode 100644
index 56319a42ae22b151ec3735d81c0607063509278d..0000000000000000000000000000000000000000
--- a/test/tools/similar_cal_tool.py
+++ /dev/null
@@ -1,158 +0,0 @@
-import jieba
-import jieba.analyse
-import synonyms
-
-class Similar_cal_tool:
-    with open('./tools/stopwords.txt', 'r', encoding='utf-8') as f:
-        stopwords = set(f.read().splitlines())
-
-    @staticmethod
-    def normalized_scores(scores):
-        min_score = None
-        max_score = None
-        for score in scores:
-            if min_score is None:
-                min_score = score
-            else:
-                min_score = min(min_score, score)
-            if max_score is None:
-                max_score = score
-            else:
-                max_score = max(max_score, score)
-        if min_score == max_score:
-            for i in range(len(scores)):
-                scores[i] = 1
-        else:
-            for i in range(len(scores)):
-                scores[i] = (scores[i]-min_score)/(max_score-min_score)
-        return scores
-
-    @staticmethod
-    def filter_stop_words(text):
-        words = jieba.lcut(text)
-        filtered_words = [word for word in words if word not in Similar_cal_tool.stopwords]
-        text = ''.join(filtered_words)
-        return text
-
-    @staticmethod
-    def extract_keywords_sorted(text, topK=10):
-        keywords = jieba.analyse.textrank(text, topK=topK, withWeight=False)
-        return keywords
-
-    @staticmethod
-    def get_synonyms_score_dict(word):
-        try:
-            syns, scores = synonyms.nearby(word)
-            scores = Similar_cal_tool.normalized_scores(scores)
-            syns_scores_dict = {}
-            for syn, score in tuple(syns, scores):
-                syns_scores_dict[syn] = score
-            return syns_scores_dict
-        except:
-            return {word: 1}
-
-    @staticmethod
-    def text_to_keywords(text):
-        words = jieba.lcut(text)
-        if len(set(words)) <64:
-            return words
-        topK = 5
-        lv = 64
-        while lv < len(words):
-            topK *= 2
-            lv *= 2
-        keywords_sorted = Similar_cal_tool.extract_keywords_sorted(text, topK)
-        keywords_sorted_set = set(keywords_sorted)
-        new_words = []
-        for word in words:
-            if word in keywords_sorted_set:
-                new_words.append(word)
-        return new_words
-    @staticmethod
-    def cal_syns_word_score(word, syns_scores_dict):
-        if word not in syns_scores_dict:
-            return 0
-        return syns_scores_dict[word]
-    @staticmethod
-    def longest_common_subsequence(str1, str2):
-        words1 = Similar_cal_tool.text_to_keywords(str1)
-        words2 = Similar_cal_tool.text_to_keywords(str2)
-        m, n = len(words1), len(words2)
-        if m == 0 and n == 0:
-            return 1
-        if m == 0:
-            return 0
-        if n == 0:
-            return 0
-        dp = [[0]*(n+1) for _ in range(m+1)]
-        syns_scores_dicts_1 = []
-        syns_scores_dicts_2 = []
-        for word in words1:
-            syns_scores_dicts_1.append(Similar_cal_tool.get_synonyms_score_dict(word))
-        for word in words2:
-            syns_scores_dicts_2.append(Similar_cal_tool.get_synonyms_score_dict(word))
-
-        for i in range(1, m+1):
-            for j in range(1, n+1):
-                dp[i][j] = max(dp[i-1][j], dp[i][j-1])
-                dp[i][j] = dp[i-1][j-1] + (Similar_cal_tool.cal_syns_word_score(words1[i-1], syns_scores_dicts_2[j-1]
-                )+Similar_cal_tool.cal_syns_word_score(words2[j-1], syns_scores_dicts_1[i-1]))
-
-        return dp[m][n]/(2*min(m,n))
-
-    def jaccard_distance(str1, str2):
-        words1 = set(Similar_cal_tool.text_to_keywords(str1))
-        words2 = set(Similar_cal_tool.text_to_keywords(str2))
-        m, n = len(words1), len(words2)
-        if m == 0 and n == 0:
-            return 1
-        if m == 0:
-            return 0
-        if n == 0:
-            return 0
-        syns_scores_dict_1 = {}
-        syns_scores_dict_2 = {}
-        for word in words1:
-            tmp_dict=Similar_cal_tool.get_synonyms_score_dict(word)
-            for key,val in tmp_dict.items():
-                syns_scores_dict_1[key]=max(syns_scores_dict_1.get(key,0),val)
-        for word in words2:
-            tmp_dict=Similar_cal_tool.get_synonyms_score_dict(word)
-            for key,val in tmp_dict.items():
-                syns_scores_dict_2[key]=max(syns_scores_dict_2.get(key,0),val)
-        sum=0
-        for word in words1:
-            sum+=Similar_cal_tool.cal_syns_word_score(word,syns_scores_dict_2)
-        for word in words2:
-            sum+=Similar_cal_tool.cal_syns_word_score(word,syns_scores_dict_2)
-        return sum/(len(words1)+len(words2))
-    def levenshtein_distance(str1, str2):
-        words1 = Similar_cal_tool.text_to_keywords(str1)
-        words2 = Similar_cal_tool.text_to_keywords(str2)
-        m, n = len(words1), len(words2)
-        if m == 0 and n == 0:
-            return 1
-        if m == 0:
-            return 0
-        if n == 0:
-            return 0
-        dp = [[0]*(n+1) for _ in range(m+1)]
-        syns_scores_dicts_1 = []
-        syns_scores_dicts_2 = []
-        for word in words1:
-            syns_scores_dicts_1.append(Similar_cal_tool.get_synonyms_score_dict(word))
-        for word in words2:
-            syns_scores_dicts_2.append(Similar_cal_tool.get_synonyms_score_dict(word))
-        dp = [[0 for _ in range(n + 1)] for _ in range(m + 1)]
-
-        for i in range(m + 1):
-            dp[i][0] = i
-        for j in range(n + 1):
-            dp[0][j] = j
-
-        for i in range(1, m + 1):
-            for j in range(1, n + 1):
-                dp[i][j] = 1 + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1])
-                dp[i][j] = min(dp[i][j],dp[i - 1][j - 1]+1-((Similar_cal_tool.cal_syns_word_score(words1[i-1], syns_scores_dicts_2[j-1]
-                )+Similar_cal_tool.cal_syns_word_score(words2[j-1], syns_scores_dicts_1[i-1])))/2)
-        return 1-dp[m][n]/(m+n)
diff --git "a/test/witchainD\346\265\213\350\257\225\346\214\207\345\257\274.docm" "b/test/witchainD\346\265\213\350\257\225\346\214\207\345\257\274.docm"
deleted file mode 100644
index dd3b2489e2a5a895fa19daacfd79623f6d78e4e1..0000000000000000000000000000000000000000
Binary files "a/test/witchainD\346\265\213\350\257\225\346\214\207\345\257\274.docm" and /dev/null differ