From d35561d293cf54e3d24a8274fec5bf95b94a0351 Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 18 Apr 2025 14:47:38 +0800 Subject: [PATCH 1/4] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=87=BD=E6=95=B0?= =?UTF-8?q?=E5=90=8D=E7=A7=B0bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- chat2db/app/router/database.py | 2 +- chat2db/app/service/sql_generate_service.py | 2 +- chat2db/common/.env.example | 35 +++++++++++---------- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/chat2db/app/router/database.py b/chat2db/app/router/database.py index 2906412..6a9a735 100644 --- a/chat2db/app/router/database.py +++ b/chat2db/app/router/database.py @@ -173,7 +173,7 @@ async def generate_sql_from_database(request: DatabaseSqlGenerateRequest): else: table_id_list = None results = {} - sql_list = await SqlGenerateService.generate_sql_base_on_exmpale( + sql_list = await SqlGenerateService.generate_sql_base_on_example( database_id=database_id, question=question, table_id_list=table_id_list, use_llm_enhancements=use_llm_enhancements) try: diff --git a/chat2db/app/service/sql_generate_service.py b/chat2db/app/service/sql_generate_service.py index 9a1f9ec..0b2ce0f 100644 --- a/chat2db/app/service/sql_generate_service.py +++ b/chat2db/app/service/sql_generate_service.py @@ -253,7 +253,7 @@ class SqlGenerateService(): table_id = table_info['table_id'] column_info_list = data_frame.get('column_info_list', '') note = await SqlGenerateService.merge_table_and_column_info(table_info, column_info_list) - sql_example = await SqlGenerateService.megre_sql_example(data_frame.get('sql_example_list', [])) + sql_example = await SqlGenerateService.merge_sql_example(data_frame.get('sql_example_list', [])) try: prompt = prompt.format( database_url=database_url, note=note, k=len(data_frame.get('sql_example_list', [])), diff --git a/chat2db/common/.env.example b/chat2db/common/.env.example index 9011268..818b356 100644 --- a/chat2db/common/.env.example +++ b/chat2db/common/.env.example @@ -1,26 +1,27 @@ # FastAPI -UVICORN_IP= -UVICORN_PORT= -SSL_CERTFILE= -SSL_KEYFILE= -SSL_ENABLE= +UVICORN_IP=0.0.0.0 +UVICORN_PORT=9015 +# SSL_CERTFILE= +# SSL_KEYFILE= +# SSL_ENABLE= # Postgres -DATABASE_URL= +DATABASE_URL=postgresql+psycopg2://postgres:123456@0.0.0.0:5444/postgres # QWEN -LLM_KEY= -LLM_URL= -LLM_MAX_TOKENS= -LLM_MODEL= +LLM_KEY=sk-gcdlwtzbzgloaogjdpkvaumftdcbxjufqadbgxwecwdasnaw +LLM_URL=https://api.siliconflow.cn/v1 +LLM_MODEL=Qwen/Qwen2.5-32B-Instruct +LLM_MAX_TOKENS=4096 + # Vectorize -EMBEDDING_TYPE= -EMBEDDING_API_KEY= -EMBEDDING_ENDPOINT= -EMBEDDING_MODEL_NAME= +EMBEDDING_TYPE = 'openai' +EMBEDDING_API_KEY = 'sk-123456' +EMBEDDING_ENDPOINT = 'http://1.94.145.11:8000/v1/embeddings' +EMBEDDING_MODEL_NAME = 'bge-m3:Q4_K_S' # security -HALF_KEY1= -HALF_KEY2= -HALF_KEY3= \ No newline at end of file +HALF_KEY1='123456' +HALF_KEY2='123456' +HALF_KEY3='123456' \ No newline at end of file -- Gitee From a5d81c42e99409e2e53bd4de6d94daa391ba29b9 Mon Sep 17 00:00:00 2001 From: zxstty Date: Sat, 19 Apr 2025 16:16:01 +0800 Subject: [PATCH 2/4] =?UTF-8?q?=E5=A2=9E=E5=8A=A0chat2db=E6=A1=88=E4=BE=8B?= =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- chat2db/app/router/database.py | 8 +- chat2db/common/init_sql_example.py | 124 ++++++++++++++++++----- chat2db/common/table_name.yaml | 10 ++ chat2db/common/table_name_id.yaml | 10 -- chat2db/manager/database_info_manager.py | 18 +++- chat2db/model/request.py | 3 +- run.sh | 1 + 7 files changed, 137 insertions(+), 37 deletions(-) create mode 100644 chat2db/common/table_name.yaml delete mode 100644 chat2db/common/table_name_id.yaml diff --git a/chat2db/app/router/database.py b/chat2db/app/router/database.py index 6a9a735..c934eff 100644 --- a/chat2db/app/router/database.py +++ b/chat2db/app/router/database.py @@ -47,7 +47,7 @@ async def add_database_info(request: DatabaseAddRequest): return ResponseData( code=status.HTTP_422_UNPROCESSABLE_ENTITY, message="数据库连接添加失败,当前存在重复数据库配置", - result={} + result={'database_id': database_id} ) return ResponseData( code=status.HTTP_200_OK, @@ -59,7 +59,11 @@ async def add_database_info(request: DatabaseAddRequest): @router.post("/del", response_model=ResponseData) async def del_database_info(request: DatabaseDelRequest): database_id = request.database_id - flag = await DatabaseInfoManager.del_database_by_id(database_id) + database_url = request.database_url + if database_id: + flag = await DatabaseInfoManager.del_database_by_id(database_id) + else: + flag= await DatabaseInfoManager.del_database_by_url(database_url) if not flag: return ResponseData( code=status.HTTP_422_UNPROCESSABLE_ENTITY, diff --git a/chat2db/common/init_sql_example.py b/chat2db/common/init_sql_example.py index 4288dfa..800f171 100644 --- a/chat2db/common/init_sql_example.py +++ b/chat2db/common/init_sql_example.py @@ -1,29 +1,107 @@ import yaml +from fastapi import status import requests -chat2db_url='http://0.0.0.0:9015' -with open('table_name_id.yaml') as f: - table_name_id=yaml.load(f,Loader=yaml.SafeLoader) -with open('table_name_sql_exmple.yaml') as f: - table_name_sql_example_list=yaml.load(f,Loader=yaml.SafeLoader) +import uuid +from typing import Optional +from pydantic import BaseModel, Field +from chat2db.config.config import config +ip = config['UVICORN_IP'] +port = config['UVICORN_PORT'] +base_url = f'http://{ip}:{port}' +database_url = config['DATABASE_URL'] + + +class DatabaseDelRequest(BaseModel): + database_id: Optional[str] = Field(default=None, description="数据库id") + database_url: Optional[str] = Field(default=None, description="数据库url") + + +def del_database_url(base_url, database_url): + server_url = f'{base_url}/database/del' + try: + request_data = DatabaseDelRequest(database_url=database_url).dict() + response = requests.post(server_url, json=request_data) + if response.json()['code'] != status.HTTP_200_OK: + print(response.json()['message']) + except Exception as e: + print(f"删除数据库配置失败: {e}") + exit(0) + return None + + +class DatabaseAddRequest(BaseModel): + database_url: str + + +def add_database_url(base_url, database_url): + server_url = f'{base_url}/database/add' + try: + request_data = DatabaseAddRequest(database_url=database_url).dict() + + response = requests.post(server_url, json=request_data) + response.raise_for_status() + if response.json()['code'] != status.HTTP_200_OK: + raise Exception(response.json()['message']) + except Exception as e: + print(f"增加数据库配置失败: {e}") + exit(0) + return response.json()['result']['database_id'] + + +class TableAddRequest(BaseModel): + database_id: str + table_name: str + + +def add_table(base_url, database_id, table_name): + server_url = f'{base_url}/table/add' + try: + request_data = TableAddRequest(database_id=database_id, table_name=table_name).dict() + response = requests.post(server_url, json=request_data) + response.raise_for_status() + if response.json()['code'] != status.HTTP_200_OK: + raise Exception(response.json()['message']) + except Exception as e: + print(f"增加表配置失败: {e}") + return + return response.json()['result']['table_id'] + + +class SqlExampleAddRequest(BaseModel): + table_id: str + question: str + sql: str + + +def add_sql_example(base_url, table_id, question, sql): + server_url = f'{base_url}/sql/example/add' + try: + request_data = SqlExampleAddRequest(table_id=table_id, question=question, sql=sql).dict() + response = requests.post(server_url, json=request_data) + if response.json()['code'] != status.HTTP_200_OK: + raise Exception(response.json()['message']) + except Exception as e: + print(f"增加sql案例失败: {e}") + return + return response.json()['result']['sql_example_id'] + + +database_id = del_database_url(base_url, database_url) +database_id = add_database_url(base_url, database_url) +with open('./chat2db/common/table_name.yaml') as f: + table_name_list = yaml.load(f, Loader=yaml.SafeLoader) +table_name_id = {} +for table_name in table_name_list: + table_id = add_table(base_url, database_id, table_name) + if table_id: + table_name_id[table_name] = table_id +with open('./chat2db/common/table_name_sql_exmple.yaml') as f: + table_name_sql_example_list = yaml.load(f, Loader=yaml.SafeLoader) for table_name_sql_example in table_name_sql_example_list: - table_name=table_name_sql_example['table_name'] + table_name = table_name_sql_example['table_name'] if table_name not in table_name_id: continue - table_id=table_name_id[table_name] - sql_example_list=table_name_sql_example['sql_example_list'] + table_id = table_name_id[table_name] + sql_example_list = table_name_sql_example['sql_example_list'] for sql_example in sql_example_list: - request_data = { - "table_id": str(table_id), - "question": sql_example['question'], - "sql": sql_example['sql'] - } - url = f"{chat2db_url}/sql/example/add" # 请替换为实际的 API 域名 - - try: - response = requests.post(url, json=request_data) - if response.status_code!=200: - print(f'添加sql案例失败{response.text}') - else: - print(f'添加sql案例成功{response.text}') - except Exception as e: - print(f'添加sql案例失败由于{e}') \ No newline at end of file + add_sql_example(base_url, table_id, sql_example['question'], sql_example['sql']) diff --git a/chat2db/common/table_name.yaml b/chat2db/common/table_name.yaml new file mode 100644 index 0000000..553cf1b --- /dev/null +++ b/chat2db/common/table_name.yaml @@ -0,0 +1,10 @@ +- oe_community_openeuler_version +- oe_community_organization_structure +- oe_compatibility_card +- oe_compatibility_commercial_software +- oe_compatibility_cve_database +- oe_compatibility_oepkgs +- oe_compatibility_osv +- oe_compatibility_overall_unit +- oe_compatibility_security_notice +- oe_compatibility_solution diff --git a/chat2db/common/table_name_id.yaml b/chat2db/common/table_name_id.yaml deleted file mode 100644 index 22d04d5..0000000 --- a/chat2db/common/table_name_id.yaml +++ /dev/null @@ -1,10 +0,0 @@ -oe_community_openeuler_version: 5ffcd194-b895-4d3c-a3ad-c4af4a270d6a -oe_community_organization_structure: 89cd2fd0-a5ca-4ee8-8bfe-44da86a32208 -oe_compatibility_card: fb4dbdda-88c5-482a-bec8-28d51669284e -oe_compatibility_commercial_software: 86ab7dad-4848-48da-8667-72be6c99780f -oe_compatibility_cve_database: 984d1c82-c6d5-4d3d-93d9-8d5bc11254ba -oe_compatibility_oepkgs: bb2698c7-f715-487f-95c4-1061a9c33851 -oe_compatibility_osv: 8c9f6608-e2e2-475d-8f67-b3bdab3e9234 -oe_compatibility_overall_unit: 82b9521a-c924-4c52-aedc-229bca5ea4c0 -oe_compatibility_security_notice: f0fa7cc5-4e5d-4c69-b202-39d522f18383 -oe_compatibility_solution: 7bdaf15c-af9f-4cb8-adec-e7e45f1de6ca diff --git a/chat2db/manager/database_info_manager.py b/chat2db/manager/database_info_manager.py index ff28c9d..cc234fb 100644 --- a/chat2db/manager/database_info_manager.py +++ b/chat2db/manager/database_info_manager.py @@ -37,7 +37,23 @@ class DatabaseInfoManager(): return False session.commit() return True - + + @staticmethod + async def del_database_by_url(database_url): + with PostgresDB.get_session() as session: + hashmac = hashlib.sha256(database_url.encode('utf-8')).hexdigest() + database_info_entry = session.query(DatabaseInfo).filter(DatabaseInfo.hashmac == hashmac).first() + if database_info_entry: + database_info_to_delete = session.query(DatabaseInfo).filter(DatabaseInfo.id == database_info_entry.id).first() + if database_info_to_delete: + session.delete(database_info_to_delete) + else: + return False + else: + return False + session.commit() + return True + @staticmethod async def get_database_url_by_id(id): with PostgresDB.get_session() as session: diff --git a/chat2db/model/request.py b/chat2db/model/request.py index 6f72e85..6d8c955 100644 --- a/chat2db/model/request.py +++ b/chat2db/model/request.py @@ -14,7 +14,8 @@ class DatabaseAddRequest(BaseModel): class DatabaseDelRequest(BaseModel): - database_id: uuid.UUID + database_id: Optional[uuid.UUID] = Field(default=None, description="数据库id") + database_url: Optional[str] = Field(default=None, description="数据库url") class DatabaseSqlGenerateRequest(BaseModel): database_url: str diff --git a/run.sh b/run.sh index b8f5d01..6564a85 100644 --- a/run.sh +++ b/run.sh @@ -1,5 +1,6 @@ #!/usr/bin/env sh java -jar tika-server-standard-2.9.2.jar & +python3 /rag-service/chat2db/common/init_sql_example.py python3 /rag-service/chat2db/app/app.py & python3 /rag-service/data_chain/apps/app.py & -- Gitee From 076c3be452eb6ad5e2988ca26a42336f9c4e3846 Mon Sep 17 00:00:00 2001 From: zxstty Date: Sat, 19 Apr 2025 17:13:35 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E7=AE=80=E5=8C=96database=5Ftype=E8=AF=86?= =?UTF-8?q?=E5=88=AB=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- chat2db/app/router/database.py | 20 +++++++------------- chat2db/app/router/sql_generate.py | 15 +++++---------- chat2db/app/router/table.py | 4 +--- chat2db/app/service/diff_database_service.py | 18 ++++++++++++++++++ chat2db/app/service/keyword_service.py | 8 ++------ chat2db/app/service/sql_generate_service.py | 13 +++---------- 6 files changed, 36 insertions(+), 42 deletions(-) diff --git a/chat2db/app/router/database.py b/chat2db/app/router/database.py index c934eff..37aacca 100644 --- a/chat2db/app/router/database.py +++ b/chat2db/app/router/database.py @@ -26,15 +26,13 @@ router = APIRouter( @router.post("/add", response_model=ResponseData) async def add_database_info(request: DatabaseAddRequest): database_url = request.database_url - if 'mysql' not in database_url and 'postgres' not in database_url: + database_type = DiffDatabaseService.get_database_type_from_url(database_url) + if not DiffDatabaseService.is_database_type_allow(database_type): return ResponseData( code=status.HTTP_422_UNPROCESSABLE_ENTITY, message="不支持当前数据库", result={} ) - database_type = 'postgres' - if 'mysql' in database_url: - database_type = 'mysql' flag = await DiffDatabaseService.get_database_service(database_type).test_database_connection(database_url) if not flag: return ResponseData( @@ -63,7 +61,7 @@ async def del_database_info(request: DatabaseDelRequest): if database_id: flag = await DatabaseInfoManager.del_database_by_id(database_id) else: - flag= await DatabaseInfoManager.del_database_by_url(database_url) + flag = await DatabaseInfoManager.del_database_by_url(database_url) if not flag: return ResponseData( code=status.HTTP_422_UNPROCESSABLE_ENTITY, @@ -90,21 +88,19 @@ async def query_database_info(): @router.get("/list", response_model=ResponseData) async def list_table_in_database(database_id: uuid.UUID, table_filter: str = ''): database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) + database_type = DiffDatabaseService.get_database_type_from_url(database_url) if database_url is None: return ResponseData( code=status.HTTP_422_UNPROCESSABLE_ENTITY, message="查询数据库内表格配置失败,数据库配置不存在", result={} ) - if 'mysql' not in database_url and 'postgres' not in database_url: + if not DiffDatabaseService.is_database_type_allow(database_type): return ResponseData( code=status.HTTP_422_UNPROCESSABLE_ENTITY, message="不支持当前数据库", result={} ) - database_type = 'postgres' - if 'mysql' in database_url: - database_type = 'mysql' flag = await DiffDatabaseService.get_database_service(database_type).test_database_connection(database_url) if not flag: return ResponseData( @@ -130,15 +126,13 @@ async def generate_sql_from_database(request: DatabaseSqlGenerateRequest): table_name_list = request.table_name_list question = request.question use_llm_enhancements = request.use_llm_enhancements - if 'mysql' not in database_url and 'postgres' not in database_url: + database_type = DiffDatabaseService.get_database_type_from_url(database_url) + if not DiffDatabaseService.is_database_type_allow(database_type): return ResponseData( code=status.HTTP_422_UNPROCESSABLE_ENTITY, message="不支持当前数据库", result={} ) - database_type = 'postgres' - if 'mysql' in database_url: - database_type = 'mysql' flag = await DiffDatabaseService.get_database_service(database_type).test_database_connection(database_url) if not flag: return ResponseData( diff --git a/chat2db/app/router/sql_generate.py b/chat2db/app/router/sql_generate.py index 144506d..8a5c7a0 100644 --- a/chat2db/app/router/sql_generate.py +++ b/chat2db/app/router/sql_generate.py @@ -2,7 +2,6 @@ import logging from fastapi import APIRouter, status -import json import sys from chat2db.manager.database_info_manager import DatabaseInfoManager @@ -54,15 +53,13 @@ async def repair_sql(request: SqlRepairRequest): database_id = request.database_id table_id = request.table_id database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) + database_type = DiffDatabaseService.get_database_type_from_url(database_url) if database_url is None: return ResponseData( code=status.HTTP_422_UNPROCESSABLE_ENTITY, message="当前数据库配置不存在", result={} ) - database_type = 'postgres' - if 'mysql' in database_url: - database_type = 'mysql' table_info = await TableInfoManager.get_table_info_by_table_id(table_id) if table_info is None: return ResponseData( @@ -93,8 +90,8 @@ async def repair_sql(request: SqlRepairRequest): code=status.HTTP_200_OK, message="sql修复成功", result={'database_id': database_id, - 'table_id': table_id, - 'sql': sql} + 'table_id': table_id, + 'sql': sql} ) @@ -109,9 +106,7 @@ async def execute_sql(request: SqlExcuteRequest): message="当前数据库配置不存在", result={} ) - database_type = 'postgres' - if 'mysql' in database_url: - database_type = 'mysql' + database_type = DiffDatabaseService.get_database_type_from_url(database_url) try: results = await DiffDatabaseService.database_map[database_type].try_excute(database_url, sql) except Exception as e: @@ -120,7 +115,7 @@ async def execute_sql(request: SqlExcuteRequest): return ResponseData( code=status.HTTP_500_INTERNAL_SERVER_ERROR, message="sql执行失败", - result={'Error':str(e)} + result={'Error': str(e)} ) return ResponseData( code=status.HTTP_200_OK, diff --git a/chat2db/app/router/table.py b/chat2db/app/router/table.py index c47f269..33ca4f9 100644 --- a/chat2db/app/router/table.py +++ b/chat2db/app/router/table.py @@ -31,9 +31,7 @@ async def add_database_info(request: TableAddRequest): message="当前数据库配置不存在", result={} ) - database_type = 'postgres' - if 'mysql' in database_url: - database_type = 'mysql' + database_type = DiffDatabaseService.get_database_type_from_url(database_url) flag = await DiffDatabaseService.get_database_service(database_type).test_database_connection(database_url) if not flag: return ResponseData( diff --git a/chat2db/app/service/diff_database_service.py b/chat2db/app/service/diff_database_service.py index 0fe77b0..22289b1 100644 --- a/chat2db/app/service/diff_database_service.py +++ b/chat2db/app/service/diff_database_service.py @@ -1,10 +1,28 @@ +import re +from urllib.parse import urlparse from chat2db.app.base.mysql import Mysql from chat2db.app.base.postgres import Postgres class DiffDatabaseService(): + database_types = ["mysql", "postgres", "opengauss"] database_map = {"mysql": Mysql, "postgres": Postgres} @staticmethod def get_database_service(database_type): + if database_type not in DiffDatabaseService.database_types: + raise f"不支持当前数据库类型{database_type}" return DiffDatabaseService.database_map[database_type] + + @staticmethod + def get_database_type_from_url(database_url): + result = urlparse(database_url) + try: + database_type = result.scheme.split('+')[0] + except Exception as e: + raise e + return database_type + + @staticmethod + def is_database_type_allow(database_type): + return database_type in DiffDatabaseService.database_types \ No newline at end of file diff --git a/chat2db/app/service/keyword_service.py b/chat2db/app/service/keyword_service.py index f2efb7d..685c341 100644 --- a/chat2db/app/service/keyword_service.py +++ b/chat2db/app/service/keyword_service.py @@ -58,9 +58,7 @@ class KeywordManager(): async def add(self, database_id, table_id, column_name): database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) - database_type = 'postgres' - if 'mysql' in database_url: - database_type = 'mysql' + database_type = DiffDatabaseService.get_database_type_from_url(database_url) table_info = await TableInfoManager.get_table_info_by_table_id(table_id) table_name = table_info['table_name'] tmp_dict = await DiffDatabaseService.get_database_service( @@ -108,9 +106,7 @@ class KeywordManager(): results = [] if database_id in self.keyword_asset_dict.keys(): database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) - database_type = 'postgres' - if 'mysql' in database_url: - database_type = 'mysql' + database_type = DiffDatabaseService.get_database_type_from_url(database_url) for table_id in self.keyword_asset_dict[database_id].keys(): if table_id_list is None or table_id in table_id_list: table_info = self.keyword_asset_dict[database_id][table_id]['table_info'] diff --git a/chat2db/app/service/sql_generate_service.py b/chat2db/app/service/sql_generate_service.py index 0b2ce0f..f20f977 100644 --- a/chat2db/app/service/sql_generate_service.py +++ b/chat2db/app/service/sql_generate_service.py @@ -120,9 +120,7 @@ class SqlGenerateService(): except Exception as e: logging.error(f'数据库{database_id}信息获取失败由于{e}') return [] - database_type = 'postgres' - if 'mysql' in database_url: - database_type = 'mysql' + database_type = DiffDatabaseService.get_database_type_from_url(database_url) del database_url try: question_vector = await Vectorize.vectorize_embedding(question) @@ -233,9 +231,7 @@ class SqlGenerateService(): return {} if database_url is None: raise Exception('数据库配置不存在') - database_type = 'postgres' - if 'mysql' in database_url: - database_type = 'mysql' + database_type = DiffDatabaseService.get_database_type_from_url(database_url) data_frame_list = await SqlGenerateService.find_most_similar_sql_example(database_id, table_id_list, question, use_llm_enhancements) try: with open('./chat2db/templetes/prompt.yaml', 'r', encoding='utf-8') as f: @@ -280,10 +276,7 @@ class SqlGenerateService(): @staticmethod async def generate_sql_base_on_data(database_url, table_name, sql_var=False): database_type = None - if 'postgres' in database_url: - database_type = 'postgres' - if 'mysql' in database_url: - database_type = 'mysql' + database_type = DiffDatabaseService.get_database_type_from_url(database_url) flag = await DiffDatabaseService.get_database_service(database_type).test_database_connection(database_url) if not flag: return None -- Gitee From cbc202990409d01cebb1dc9e79c9508b12fbd17a Mon Sep 17 00:00:00 2001 From: zxstty Date: Sat, 19 Apr 2025 17:46:19 +0800 Subject: [PATCH 4/4] =?UTF-8?q?=E4=BF=AE=E5=A4=8Ddatatype=E8=A7=A3?= =?UTF-8?q?=E6=9E=90=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- chat2db/app/service/diff_database_service.py | 8 ++++---- chat2db/database/postgres.py | 8 +++++++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/chat2db/app/service/diff_database_service.py b/chat2db/app/service/diff_database_service.py index 22289b1..bb9f979 100644 --- a/chat2db/app/service/diff_database_service.py +++ b/chat2db/app/service/diff_database_service.py @@ -5,8 +5,8 @@ from chat2db.app.base.postgres import Postgres class DiffDatabaseService(): - database_types = ["mysql", "postgres", "opengauss"] - database_map = {"mysql": Mysql, "postgres": Postgres} + database_types = ["mysql", "postgresql", "opengauss"] + database_map = {"mysql": Mysql, "postgresql": Postgres, "opengauss": Postgres} @staticmethod def get_database_service(database_type): @@ -21,8 +21,8 @@ class DiffDatabaseService(): database_type = result.scheme.split('+')[0] except Exception as e: raise e - return database_type + return database_type.lower() @staticmethod def is_database_type_allow(database_type): - return database_type in DiffDatabaseService.database_types \ No newline at end of file + return database_type in DiffDatabaseService.database_types diff --git a/chat2db/database/postgres.py b/chat2db/database/postgres.py index 230a410..c3b8bf8 100644 --- a/chat2db/database/postgres.py +++ b/chat2db/database/postgres.py @@ -2,7 +2,7 @@ import logging from uuid import uuid4 from pgvector.sqlalchemy import Vector from sqlalchemy.orm import sessionmaker, declarative_base -from sqlalchemy import TIMESTAMP, UUID, Column, String, Boolean, ForeignKey, create_engine, func, Index +from sqlalchemy import TIMESTAMP, UUID, Column, String, Boolean, ForeignKey, create_engine, func, Index import sys from chat2db.config.config import config @@ -91,6 +91,12 @@ class PostgresDB: pool_pre_ping=True) Base.metadata.create_all(cls.engine) + if 'opengauss' in config['DATABASE_URL']: + from sqlalchemy import event + from opengauss_sqlalchemy.register_async import register_vector + @event.listens_for(cls.engine.sync_engine, "connect") + def connect(dbapi_connection, connection_record): + dbapi_connection.run_async(register_vector) return cls._engine @classmethod -- Gitee