From 857a1fed6bf26aa130421d7a4e999bb976640886 Mon Sep 17 00:00:00 2001 From: Hongyu Shi Date: Tue, 16 Sep 2025 11:52:39 +0800 Subject: [PATCH 1/4] =?UTF-8?q?feat(validator):=20=E6=A3=80=E6=B5=8B?= =?UTF-8?q?=E5=A4=A7=E6=A8=A1=E5=9E=8B=E6=94=AF=E6=8C=81=E7=9A=84=E5=B7=A5?= =?UTF-8?q?=E5=85=B7=E8=B0=83=E7=94=A8=E7=B1=BB=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Hongyu Shi --- src/tool/validators.py | 354 ++++++++++++++---- .../deployment/test_validate_llm_config.py | 21 +- 2 files changed, 303 insertions(+), 72 deletions(-) diff --git a/src/tool/validators.py b/src/tool/validators.py index 9536002..ab41f00 100644 --- a/src/tool/validators.py +++ b/src/tool/validators.py @@ -4,6 +4,7 @@ 提供实际 API 调用验证配置的有效性。 """ +import json from typing import Any import httpx @@ -61,8 +62,13 @@ class APIValidator: await client.close() return False, chat_msg, {} - # 测试 function_call 支持 - func_valid, func_msg, func_info = await self._test_function_call(client, model, max_tokens, temperature) + # 测试 function_call 支持并检测类型 + func_valid, func_msg, func_type = await self._detect_function_call_type( + client, + model, + max_tokens, + temperature, + ) await client.close() @@ -75,14 +81,18 @@ class APIValidator: else: success_msg = "LLM 配置验证成功" if func_valid: - success_msg += " - 支持 function_call" + success_msg += f" - 支持 function_call,类型: {func_type}" else: success_msg += f" - 不支持 function_call: {func_msg}" - return True, success_msg, { - "supports_function_call": func_valid, - "function_call_info": func_info, - } + return ( + True, + success_msg, + { + "supports_function_call": func_valid, + "type": func_type, + }, + ) async def validate_embedding_config( self, @@ -124,11 +134,15 @@ class APIValidator: if response.data and len(response.data) > 0: embedding = response.data[0].embedding dimension = len(embedding) - return True, f"Embedding 配置验证成功 - 维度: {dimension}", { - "model": model, - "dimension": dimension, - "sample_embedding_length": len(embedding), - } + return ( + True, + f"Embedding 配置验证成功 - 维度: {dimension}", + { + "model": model, + "dimension": dimension, + "sample_embedding_length": len(embedding), + }, + ) return False, "Embedding 响应为空", {} @@ -161,28 +175,102 @@ class APIValidator: return False, "对话响应为空" - async def _test_function_call( + async def _detect_function_call_type( self, client: AsyncOpenAI, model: str, max_tokens: int | None = None, temperature: float | None = None, - ) -> tuple[bool, str, dict[str, Any]]: - """测试 function_call 支持""" + ) -> tuple[bool, str, str]: + """ + 检测并测试不同类型的 function_call 支持 + + 按照以下顺序尝试: + 1. OpenAI 标准 function_call 格式 + 2. OpenAI tools 格式 + 3. vLLM 特有格式 + 4. 
Ollama 特有格式 + + Returns: + tuple[bool, str, str]: (是否支持, 详细消息, 格式类型) + + """ + # 尝试 OpenAI tools 格式 + tools_valid, tools_msg = await self._test_tools_format( + client, + model, + max_tokens, + temperature, + ) + if tools_valid: + return True, tools_msg, "function_call" + + # 尝试 structured_output 格式 + structured_valid, structured_msg = await self._test_structured_output( + client, + model, + max_tokens, + temperature, + ) + if structured_valid: + return True, structured_msg, "structured_output" + + # 尝试 json_mode 格式 + json_mode_valid, json_mode_msg = await self._test_json_mode( + client, + model, + max_tokens, + temperature, + ) + if json_mode_valid: + return True, json_mode_msg, "json_mode" + + # 尝试 vLLM 格式 + vllm_valid, vllm_msg = await self._test_vllm_function_call( + client, + model, + max_tokens, + temperature, + ) + if vllm_valid: + return True, vllm_msg, "vllm" + + # 尝试 Ollama 格式 + ollama_valid, ollama_msg = await self._test_ollama_function_call( + client, + model, + max_tokens, + temperature, + ) + if ollama_valid: + return True, ollama_msg, "ollama" + + return False, "不支持任何 function_call 格式", "none" + + async def _test_tools_format( + self, + client: AsyncOpenAI, + model: str, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> tuple[bool, str]: + """测试新版 tools 格式的 function calling""" try: - # 定义一个简单的测试函数 - test_function = { - "name": "get_current_time", - "description": "获取当前时间", - "parameters": {"type": "object", "properties": {}, "required": []}, + test_tool = { + "type": "function", + "function": { + "name": "get_current_time", + "description": "获取当前时间", + "parameters": {"type": "object", "properties": {}, "required": []}, + }, } # 构建请求参数 call_kwargs = { "model": model, "messages": [{"role": "user", "content": "请调用函数获取当前时间"}], - "functions": [test_function], # type: ignore[arg-type] - "function_call": "auto", + "tools": [test_tool], # type: ignore[arg-type] + "tool_choice": "auto", "max_tokens": max_tokens if max_tokens is not None else 50, } @@ -192,75 +280,215 @@ class APIValidator: response = await client.chat.completions.create(**call_kwargs) except (AuthenticationError, APIError, OpenAIError) as e: - # 如果 functions 参数不支持,尝试 tools 格式 - if "functions" in str(e).lower() or "function_call" in str(e).lower(): - return await self._test_tools_format(client, model, max_tokens, temperature) - return False, f"function_call 测试失败: {e!s}", {"supports_functions": False} + return False, f"tools 格式测试失败: {e!s}" else: if response.choices and len(response.choices) > 0: choice = response.choices[0] - if hasattr(choice.message, "function_call") and choice.message.function_call: - return True, "支持 function_call", { - "function_name": choice.message.function_call.name, - "supports_functions": True, - } + if hasattr(choice.message, "tool_calls") and choice.message.tool_calls: + return True, "支持 tools 格式的 function_call" + + return False, "不支持 function_call 功能" + + async def _test_structured_output( + self, + client: AsyncOpenAI, + model: str, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> tuple[bool, str]: + """测试 structured_output 格式的 JSON 输出""" + try: + test_schema = { + "type": "object", + "properties": { + "status": {"type": "string"}, + "timestamp": {"type": "string"}, + }, + "required": ["status"], + "additionalProperties": False, + } - # 尝试 tools 格式(OpenAI API 新版本) - return await self._test_tools_format(client, model, max_tokens, temperature) + call_kwargs = { + "model": model, + "messages": [{"role": "user", "content": "请返回状态信息的JSON"}], 
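+            # OpenAI Structured Outputs style: constrain the reply to the JSON Schema below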
+ "response_format": { + "type": "json_schema", + "json_schema": { + "name": "status_response", + "description": "Status response in JSON format", + "schema": test_schema, + "strict": True, + }, + }, + "max_tokens": max_tokens if max_tokens is not None else 50, + } - return False, "function_call 响应为空", {"supports_functions": False} + if temperature is not None: + call_kwargs["temperature"] = temperature - async def _test_tools_format( + response = await client.chat.completions.create(**call_kwargs) + except (AuthenticationError, APIError, OpenAIError) as e: + return False, f"structured_output 格式测试失败: {e!s}" + else: + if response.choices and len(response.choices) > 0: + choice = response.choices[0] + if hasattr(choice.message, "content") and choice.message.content: + try: + json.loads(choice.message.content) + except (json.JSONDecodeError, ValueError): + return False, "structured_output 响应不是有效 JSON" + else: + return True, "支持 structured_output 格式" + + return False, "structured_output 响应为空" + + async def _test_json_mode( self, client: AsyncOpenAI, model: str, max_tokens: int | None = None, temperature: float | None = None, - ) -> tuple[bool, str, dict[str, Any]]: - """测试新版 tools 格式的 function calling""" + ) -> tuple[bool, str]: + """测试 json_mode 格式的 JSON 输出""" try: - test_tool = { - "type": "function", - "function": { - "name": "get_current_time", - "description": "获取当前时间", - "parameters": {"type": "object", "properties": {}, "required": []}, + call_kwargs = { + "model": model, + "messages": [ + {"role": "system", "content": "你必须返回有效的JSON格式"}, + {"role": "user", "content": "请返回包含status字段的JSON对象"}, + ], + "response_format": {"type": "json_object"}, + "max_tokens": max_tokens if max_tokens is not None else 50, + } + + if temperature is not None: + call_kwargs["temperature"] = temperature + + response = await client.chat.completions.create(**call_kwargs) + except (AuthenticationError, APIError, OpenAIError) as e: + return False, f"json_mode 格式测试失败: {e!s}" + else: + if response.choices and len(response.choices) > 0: + choice = response.choices[0] + if hasattr(choice.message, "content") and choice.message.content: + try: + json.loads(choice.message.content) + except (json.JSONDecodeError, ValueError): + return False, "json_mode 响应不是有效 JSON" + else: + return True, "支持 json_mode 格式" + + return False, "json_mode 响应为空" + + async def _test_vllm_function_call( + self, + client: AsyncOpenAI, + model: str, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> tuple[bool, str]: + """测试 vLLM 特有的 guided_json 格式""" + try: + # vLLM 支持 guided_json 参数来强制 JSON 输出 + test_schema = { + "type": "object", + "properties": { + "status": {"type": "string"}, + "message": {"type": "string"}, }, + "required": ["status"], } - # 构建请求参数 call_kwargs = { "model": model, - "messages": [{"role": "user", "content": "请调用函数获取当前时间"}], - "tools": [test_tool], # type: ignore[arg-type] - "tool_choice": "auto", + "messages": [{"role": "user", "content": "请返回状态信息的JSON"}], "max_tokens": max_tokens if max_tokens is not None else 50, + "extra_body": {"guided_json": test_schema}, } - # 只有当 temperature 不为 None 时才添加到参数中 if temperature is not None: call_kwargs["temperature"] = temperature response = await client.chat.completions.create(**call_kwargs) + + if response.choices and len(response.choices) > 0: + choice = response.choices[0] + content = getattr(choice.message, "content", "") + + if content: + try: + json.loads(content) + except (json.JSONDecodeError, ValueError): + # 如果不是有效 JSON,可能不支持 guided_json + pass + else: + 
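+                        # The reply parsed as strict JSON, so the backend honored guided_json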
return True, "支持 vLLM guided_json 格式" + + # 检查是否包含结构化输出的迹象 + if content and any(keyword in content.lower() for keyword in ["json", "{", "}"]): + return True, "支持 vLLM 结构化输出(部分支持)" + except (AuthenticationError, APIError, OpenAIError) as e: - return False, f"tools 格式测试失败: {e!s}", {"supports_functions": False} + error_str = str(e).lower() + if any(keyword in error_str for keyword in ["extra_body", "guided_json", "not supported"]): + return False, f"不支持 vLLM guided_json 格式: {e!s}" + raise + else: + return False, "vLLM guided_json 响应无效" + + async def _test_ollama_function_call( + self, + client: AsyncOpenAI, + model: str, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> tuple[bool, str]: + """测试 Ollama 特有的 function calling 格式""" + try: + # Ollama 对 function calling 的支持可能有限 + # 通常通过特殊的 prompt 格式来实现 + + # 尝试使用结构化 prompt 来测试 function calling + structured_prompt = """ +你是一个助手,可以调用函数。当需要调用函数时,请按以下格式回复: +FUNCTION_CALL: get_current_time() + +现在请调用 get_current_time 函数获取当前时间。 +""" + + call_kwargs = { + "model": model, + "messages": [{"role": "user", "content": structured_prompt}], + "max_tokens": max_tokens if max_tokens is not None else 100, + } + + if temperature is not None: + call_kwargs["temperature"] = temperature + + response = await client.chat.completions.create(**call_kwargs) + if response.choices and len(response.choices) > 0: choice = response.choices[0] - if hasattr(choice.message, "tool_calls") and choice.message.tool_calls: - tool_call = choice.message.tool_calls[0] - # 安全地访问 function 属性 - function_name = "" - function_obj = getattr(tool_call, "function", None) - if function_obj and hasattr(function_obj, "name"): - function_name = function_obj.name - return True, "支持 tools 格式的 function_call", { - "function_name": function_name, - "supports_functions": True, - "format": "tools", - } - - return False, "不支持 function_call 功能", {"supports_functions": False} + content = getattr(choice.message, "content", "") + + # 检查 Ollama 可能的函数调用响应格式 + if content and any( + keyword in content + for keyword in [ + "FUNCTION_CALL:", + "get_current_time", + "function", + "call", + ] + ): + return True, "支持 Ollama function_call 格式" + + except (AuthenticationError, APIError, OpenAIError) as e: + return False, f"不支持 Ollama function_call 格式: {e!s}" + + else: + return False, "Ollama function_call 响应无效" async def validate_oi_connection(base_url: str, access_token: str) -> tuple[bool, str]: # noqa: PLR0911 diff --git a/tests/app/deployment/test_validate_llm_config.py b/tests/app/deployment/test_validate_llm_config.py index 77f2e2e..2ee0228 100644 --- a/tests/app/deployment/test_validate_llm_config.py +++ b/tests/app/deployment/test_validate_llm_config.py @@ -2,13 +2,19 @@ API 配置验证功能演示 简单演示如何使用新的验证功能。 -使用方法: source .venv/bin/activate && PYTHONPATH=src python tests/app/deployment/test_validate_llm_config.py +使用方法: source .venv/bin/activate && python tests/app/deployment/test_validate_llm_config.py """ import asyncio import sys from typing import Any +# 添加 src 目录到 Python 路径 +sys.path.insert(0, "src") + +# 为了避免循环导入,我们需要在导入 app.deployment.models 之前 +# 先确保 tool.validators 可以被正确导入,但不触发 tool.__init__.py 中的其他导入 +import tool.validators # noqa: F401 # 直接导入 validators,避免通过 tool.__init__.py from app.deployment.models import DeploymentConfig, EmbeddingConfig, LLMConfig @@ -24,16 +30,13 @@ def _output_llm_validation_info(llm_info: dict[str, Any]) -> None: if llm_info.get("supports_function_call"): _output(" 🔧 Function Call: ✅ 支持") - if "function_call_info" in llm_info: - format_type = 
llm_info["function_call_info"].get("format", "unknown") + # 显示检测到的类型 + if "type" in llm_info: + format_type = llm_info["type"] _output(f" 📋 支持格式: {format_type}") else: _output(" 🔧 Function Call: ❌ 不支持") - if "available_models" in llm_info: - models = llm_info["available_models"][:3] - _output(f" 📦 可用模型示例: {', '.join(models)}") - def _output_embedding_validation_info(embed_info: dict[str, Any]) -> None: """输出 Embedding 验证信息""" @@ -53,7 +56,7 @@ async def main() -> None: llm=LLMConfig( endpoint="http://127.0.0.1:1234/v1", api_key="lm-studio", - model="qwen/qwen3-30b-a3b-2507", + model="qwen/qwen3-coder-30b", max_tokens=4096, temperature=0.7, request_timeout=30, @@ -123,7 +126,7 @@ async def main() -> None: if __name__ == "__main__": _output("🚀 开始演示...") _output("💡 运行方法: ") - _output("💡 source .venv/bin/activate && PYTHONPATH=src python tests/app/deployment/test_validate_llm_config.py") + _output("💡 source .venv/bin/activate && python tests/app/deployment/test_validate_llm_config.py") _output() asyncio.run(main()) -- Gitee From d8e6af5c1b45a07026766a76ea4c59d3335196f8 Mon Sep 17 00:00:00 2001 From: Hongyu Shi Date: Tue, 16 Sep 2025 14:12:26 +0800 Subject: [PATCH 2/4] =?UTF-8?q?feat:=20=E8=87=AA=E5=8A=A8=E8=AF=86?= =?UTF-8?q?=E5=88=AB=E5=B9=B6=E5=A1=AB=E5=85=85=20function=20call=20?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Hongyu Shi --- src/app/deployment/models.py | 7 +++++++ src/app/deployment/service.py | 2 +- src/tool/oi_llm_config.py | 10 +++++++++- src/tool/validators.py | 2 +- 4 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/app/deployment/models.py b/src/app/deployment/models.py index 1de0b5f..24b15bf 100644 --- a/src/app/deployment/models.py +++ b/src/app/deployment/models.py @@ -77,6 +77,9 @@ class DeploymentConfig: enable_web: bool = False enable_rag: bool = False + # 检测到的后端类型(从 API 验证中获得) + detected_backend_type: str = "function_call" # 默认值 + def validate(self) -> tuple[bool, list[str]]: """ 验证配置的有效性 @@ -124,6 +127,10 @@ class DeploymentConfig: self.llm.request_timeout, ) + # 如果验证成功,保存检测到的后端类型 + if llm_valid and llm_info.get("supports_function_call", False): + self.detected_backend_type = llm_info.get("detected_function_call_type", "function_call") + return llm_valid, llm_msg, llm_info async def validate_embedding_connectivity(self) -> tuple[bool, str, dict]: diff --git a/src/app/deployment/service.py b/src/app/deployment/service.py index 59f4547..5b12819 100644 --- a/src/app/deployment/service.py +++ b/src/app/deployment/service.py @@ -115,7 +115,7 @@ class DeploymentResourceManager: # 更新 function_call 配置 if "function_call" in toml_data: - toml_data["function_call"]["backend"] = "function_call" + toml_data["function_call"]["backend"] = config.detected_backend_type toml_data["function_call"]["endpoint"] = config.llm.endpoint toml_data["function_call"]["api_key"] = config.llm.api_key toml_data["function_call"]["model"] = config.llm.model diff --git a/src/tool/oi_llm_config.py b/src/tool/oi_llm_config.py index 7e51713..4da7edd 100644 --- a/src/tool/oi_llm_config.py +++ b/src/tool/oi_llm_config.py @@ -46,6 +46,7 @@ class LLMSystemConfig: llm: LLMConfig = field(default_factory=LLMConfig) embedding: EmbeddingConfig = field(default_factory=EmbeddingConfig) + detected_function_call_type: str = field(default="function_call") @classmethod def check_prerequisites(cls) -> tuple[bool, list[str]]: @@ -207,7 +208,7 @@ class LLMSystemConfig: return False, "LLM API 端点不能为空", {} 
validator = APIValidator() - return await validator.validate_llm_config( + is_valid, message, info = await validator.validate_llm_config( self.llm.endpoint, self.llm.api_key, self.llm.model, @@ -216,6 +217,12 @@ class LLMSystemConfig: self.llm.temperature, # 传递温度参数 ) + # 保存检测到的 function call 类型 + if is_valid and info.get("supports_function_call", False): + self.detected_function_call_type = info.get("detected_function_call_type", "function_call") + + return is_valid, message, info + async def validate_embedding_connectivity(self) -> tuple[bool, str, dict]: """ 验证 Embedding API 连接性 @@ -353,6 +360,7 @@ class LLMSystemConfig: data["function_call"] = {} data["function_call"].update( { + "backend": self.detected_function_call_type, "endpoint": self.llm.endpoint, "key": self.llm.api_key, "model": self.llm.model, diff --git a/src/tool/validators.py b/src/tool/validators.py index ab41f00..f06b4fc 100644 --- a/src/tool/validators.py +++ b/src/tool/validators.py @@ -90,7 +90,7 @@ class APIValidator: success_msg, { "supports_function_call": func_valid, - "type": func_type, + "detected_function_call_type": func_type, }, ) -- Gitee From f81331f3aecf1f3746cc3286529311c2b220c4ed Mon Sep 17 00:00:00 2001 From: Hongyu Shi Date: Tue, 16 Sep 2025 15:03:35 +0800 Subject: [PATCH 3/4] =?UTF-8?q?feat:=20=E8=87=AA=E5=8A=A8=E6=A3=80?= =?UTF-8?q?=E6=B5=8B=20Embedding=20=E6=8E=A5=E5=8F=A3=E7=B1=BB=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Hongyu Shi --- src/app/deployment/models.py | 8 ++- src/tool/oi_llm_config.py | 10 ++- src/tool/validators.py | 131 +++++++++++++++++++++++++++-------- 3 files changed, 118 insertions(+), 31 deletions(-) diff --git a/src/app/deployment/models.py b/src/app/deployment/models.py index 24b15bf..4fe94c3 100644 --- a/src/app/deployment/models.py +++ b/src/app/deployment/models.py @@ -49,7 +49,7 @@ class EmbeddingConfig: 包含嵌入模型的配置信息。 """ - type: str = "openai" + type: str = "" # 可选值: "openai", "mindie" endpoint: str = "" api_key: str = "" model: str = "" @@ -156,6 +156,12 @@ class DeploymentConfig: self.llm.request_timeout, # 使用相同的超时设置 ) + # 如果验证成功,保存检测到的 embedding 类型 + if embed_valid and embed_info.get("type"): + detected_type = embed_info.get("type") + if detected_type in ("openai", "mindie"): + self.embedding.type = detected_type + return embed_valid, embed_msg, embed_info def _validate_basic_fields(self) -> list[str]: diff --git a/src/tool/oi_llm_config.py b/src/tool/oi_llm_config.py index 4da7edd..d058d57 100644 --- a/src/tool/oi_llm_config.py +++ b/src/tool/oi_llm_config.py @@ -235,13 +235,21 @@ class LLMSystemConfig: return False, "Embedding API 端点不能为空", {} validator = APIValidator() - return await validator.validate_embedding_config( + is_valid, message, info = await validator.validate_embedding_config( self.embedding.endpoint, self.embedding.api_key, self.embedding.model, 300, # 使用默认超时时间 300 秒 ) + # 如果验证成功,保存检测到的 embedding 类型 + if is_valid and info.get("type"): + detected_type = info.get("type") + if detected_type in ("openai", "mindie"): + self.embedding.type = detected_type + + return is_valid, message, info + def _load_from_toml(self) -> None: """ 从 TOML 文件加载配置 diff --git a/src/tool/validators.py b/src/tool/validators.py index f06b4fc..b261e42 100644 --- a/src/tool/validators.py +++ b/src/tool/validators.py @@ -114,37 +114,29 @@ class APIValidator: tuple[bool, str, dict]: (是否验证成功, 错误/成功消息, 额外信息) """ - self.logger.info("开始验证 Embedding 配置 - 端点: %s, 模型: %s", endpoint, model) + self.logger.info("开始验证 
Embedding 配置 - 端点: %s", endpoint) - try: - client = AsyncOpenAI(api_key=api_key, base_url=endpoint, timeout=timeout) - - # 测试 embedding 功能 - test_text = "这是一个测试文本" - response = await client.embeddings.create(input=test_text, model=model) - - await client.close() - except TimeoutError: - return False, f"连接超时 - 无法在 {timeout} 秒内连接到 {endpoint}", {} - except (AuthenticationError, APIError, OpenAIError) as e: - error_msg = f"Embedding 配置验证失败: {e!s}" - self.logger.exception(error_msg) - return False, error_msg, {} - else: - if response.data and len(response.data) > 0: - embedding = response.data[0].embedding - dimension = len(embedding) - return ( - True, - f"Embedding 配置验证成功 - 维度: {dimension}", - { - "model": model, - "dimension": dimension, - "sample_embedding_length": len(embedding), - }, - ) + # 首先尝试 OpenAI 格式 + openai_success, openai_msg, openai_info = await self._validate_openai_embedding( + endpoint, + api_key, + model, + timeout, + ) + if openai_success: + return True, openai_msg, openai_info + + # 如果 OpenAI 格式失败,尝试 MindIE 格式 + mindie_success, mindie_msg, mindie_info = await self._validate_mindie_embedding( + endpoint, + api_key, + timeout, + ) + if mindie_success: + return True, mindie_msg, mindie_info - return False, "Embedding 响应为空", {} + # 两种格式都失败 + return False, "无法连接到 Embedding 模型服务。", {} async def _test_basic_chat( self, @@ -490,6 +482,87 @@ FUNCTION_CALL: get_current_time() else: return False, "Ollama function_call 响应无效" + async def _validate_openai_embedding( + self, + endpoint: str, + api_key: str, + model: str, + timeout: int = 30, # noqa: ASYNC109 + ) -> tuple[bool, str, dict[str, Any]]: + """验证 OpenAI 格式的 embedding 配置""" + try: + client = AsyncOpenAI(api_key=api_key, base_url=endpoint, timeout=timeout) + + # 测试 embedding 功能 + test_text = "这是一个测试文本" + response = await client.embeddings.create(input=test_text, model=model) + + await client.close() + except TimeoutError: + return False, f"连接超时 - 无法在 {timeout} 秒内连接到 {endpoint}", {} + except (AuthenticationError, APIError, OpenAIError) as e: + error_msg = f"OpenAI Embedding 配置验证失败: {e!s}" + self.logger.exception(error_msg) + return False, error_msg, {} + else: + if response.data and len(response.data) > 0: + embedding = response.data[0].embedding + dimension = len(embedding) + return ( + True, + f"OpenAI Embedding 配置验证成功 - 维度: {dimension}", + { + "type": "openai", + "dimension": dimension, + "sample_embedding_length": len(embedding), + }, + ) + + return False, "OpenAI Embedding 响应为空", {} + + async def _validate_mindie_embedding( + self, + endpoint: str, + api_key: str, + timeout: int = 30, # noqa: ASYNC109 + ) -> tuple[bool, str, dict[str, Any]]: + """验证 MindIE (TEI) 格式的 embedding 配置""" + try: + embed_endpoint = endpoint.rstrip("/") + "/embed" + headers = {"Content-Type": "application/json"} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + data = {"inputs": "这是一个测试文本", "normalize": True} + + async with httpx.AsyncClient(timeout=timeout) as client: + response = await client.post(embed_endpoint, json=data, headers=headers) + + if response.status_code == HTTP_OK: + json_response = response.json() + if isinstance(json_response, list) and len(json_response) > 0: + embedding = json_response[0] + if isinstance(embedding, list) and len(embedding) > 0: + dimension = len(embedding) + return ( + True, + f"MindIE Embedding 配置验证成功 - 维度: {dimension}", + { + "type": "mindie", + "dimension": dimension, + "sample_embedding_length": len(embedding), + }, + ) + + return False, "MindIE Embedding 响应格式不正确", {} + + except 
httpx.TimeoutException: + return False, f"连接超时 - 无法在 {timeout} 秒内连接到 {endpoint}", {} + except (httpx.RequestError, httpx.HTTPStatusError) as e: + error_msg = f"MindIE Embedding 配置验证失败: {e!s}" + self.logger.exception(error_msg) + return False, error_msg, {} + async def validate_oi_connection(base_url: str, access_token: str) -> tuple[bool, str]: # noqa: PLR0911 """ -- Gitee From d80e082917580ed595d191fbbf55116ba05f6218 Mon Sep 17 00:00:00 2001 From: Hongyu Shi Date: Tue, 16 Sep 2025 17:47:29 +0800 Subject: [PATCH 4/4] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E9=AA=8C?= =?UTF-8?q?=E8=AF=81=E7=8A=B6=E6=80=81=E8=B7=9F=E8=B8=AA=EF=BC=8C=E6=9B=B4?= =?UTF-8?q?=E6=96=B0=20LLM=20=E5=92=8C=20Embedding=20=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E9=AA=8C=E8=AF=81=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Hongyu Shi --- src/app/deployment/models.py | 2 +- src/app/deployment/ui.py | 288 ++++++++++++++++++++--------------- src/tool/oi_llm_config.py | 140 +++++++++++++++-- src/tool/validators.py | 10 +- 4 files changed, 300 insertions(+), 140 deletions(-) diff --git a/src/app/deployment/models.py b/src/app/deployment/models.py index 4fe94c3..29d8098 100644 --- a/src/app/deployment/models.py +++ b/src/app/deployment/models.py @@ -96,7 +96,7 @@ class DeploymentConfig: # 验证 LLM 字段 errors.extend(self._validate_llm_fields()) - # 验证 Embedding 字段(根据部署模式决定是否必须) + # 验证 Embedding 字段 errors.extend(self._validate_embedding_fields()) # 验证数值范围 diff --git a/src/app/deployment/ui.py b/src/app/deployment/ui.py index e764120..829c03e 100644 --- a/src/app/deployment/ui.py +++ b/src/app/deployment/ui.py @@ -8,6 +8,7 @@ from __future__ import annotations import asyncio import contextlib +from enum import Enum from typing import TYPE_CHECKING from rich.errors import MarkupError @@ -34,6 +35,16 @@ from .service import DeploymentService FULL_PROGRESS = 100 +class ValidationStatus(Enum): + """验证状态枚举""" + + PENDING = "pending" + VALIDATING = "validating" + VALID = "valid" + INVALID = "invalid" + NOT_REQUIRED = "not_required" + + class DeploymentConfigScreen(ModalScreen[bool]): """ 部署配置屏幕 @@ -111,9 +122,14 @@ class DeploymentConfigScreen(ModalScreen[bool]): """初始化部署配置屏幕""" super().__init__() self.config = DeploymentConfig() + self._llm_validation_task: asyncio.Task[None] | None = None self._embedding_validation_task: asyncio.Task[None] | None = None + # 验证状态跟踪 + self.llm_validation_status: ValidationStatus = ValidationStatus.PENDING + self.embedding_validation_status: ValidationStatus = ValidationStatus.PENDING + def compose(self) -> ComposeResult: """组合界面组件""" with Container(classes="config-container"): @@ -133,6 +149,48 @@ class DeploymentConfigScreen(ModalScreen[bool]): yield Button("开始部署", id="deploy", variant="success") yield Button("取消", id="cancel", variant="error") + async def on_mount(self) -> None: + """界面挂载时初始化状态""" + # 根据配置初始化 UI 状态 + self._sync_ui_from_config() + # 初始化验证状态 + self._initialize_validation_status() + + def _sync_ui_from_config(self) -> None: + """根据配置同步 UI 状态""" + # 根据配置更新部署模式按钮显示 + try: + btn = self.query_one("#deployment_mode_btn", Button) + desc = self.query_one("#deployment_mode_desc", Static) + + if self.config.deployment_mode == "full": + btn.label = "全量部署" + desc.update("全量部署:部署框架服务 + Web 界面 + RAG 组件,自动初始化 Agent。") + else: + btn.label = "轻量部署" + desc.update("轻量部署:仅部署框架服务,自动初始化 Agent。") + + except (ValueError, AttributeError): + # 如果 UI 组件还没初始化完成,忽略错误 + pass + + def _initialize_validation_status(self) -> None: + """初始化验证状态""" + # 
初始化 Embedding 验证状态 + if self._is_embedding_required(): + self.embedding_validation_status = ValidationStatus.PENDING + else: + self.embedding_validation_status = ValidationStatus.NOT_REQUIRED + # 如果不需要验证 Embedding,显示相应状态 + try: + embedding_status = self.query_one("#embedding_validation_status", Static) + embedding_status.update("[dim]不需要验证[/dim]") + except (ValueError, AttributeError): + pass + + # 更新部署按钮状态 + self._update_deploy_button_state() + def _compose_basic_config(self) -> ComposeResult: """组合基础配置组件""" with Vertical(): @@ -195,7 +253,7 @@ class DeploymentConfigScreen(ModalScreen[bool]): yield Static("未验证", id="llm_validation_status", classes="form-input") with Horizontal(classes="form-row"): - yield Label("最大 Token 数:", classes="form-label") + yield Label("最大输出令牌数:", classes="form-label") yield Input( value="8192", id="llm_max_tokens", @@ -271,68 +329,6 @@ class DeploymentConfigScreen(ModalScreen[bool]): ) return - # LLM function call 能力验证 - llm_valid, llm_message, llm_info = await self.config.validate_llm_connectivity() - if not llm_valid: - await self.app.push_screen( - ErrorMessageScreen( - "LLM 配置验证失败", - [f"LLM 配置无效:{llm_message}"], - ), - ) - return - - # 检查 LLM 是否支持 function call - if not llm_info.get("supports_function_call", False): - await self.app.push_screen( - ErrorMessageScreen( - "LLM 功能不满足要求", - [ - "所选的 LLM 模型不支持 function call 功能,无法继续部署。", - "请选择支持 function call 的模型(如 OpenAI GPT 系列、通义千问、DeepSeek 等)。", - ], - ), - ) - return - - # 轻量部署模式下的 Embedding 验证 - if self.config.deployment_mode == "light": - # 检查是否填写了 Embedding 配置 - has_embedding = any( - [ - self.config.embedding.endpoint.strip(), - self.config.embedding.api_key.strip(), - self.config.embedding.model.strip(), - ], - ) - - if has_embedding: - # 如果填了 Embedding 配置,需要验证连通性 - embed_valid, embed_message, _ = await self.config.validate_embedding_connectivity() - if not embed_valid: - await self.app.push_screen( - ErrorMessageScreen( - "Embedding 配置验证失败", - [ - f"Embedding 配置无效:{embed_message}", - "轻量部署模式下,如果填写了 Embedding 配置,必须确保配置正确。", - "您可以选择清空 Embedding 配置字段来跳过此验证。", - ], - ), - ) - return - else: - # 全量部署模式下,必须验证 Embedding - embed_valid, embed_message, _ = await self.config.validate_embedding_connectivity() - if not embed_valid: - await self.app.push_screen( - ErrorMessageScreen( - "Embedding 配置验证失败", - [f"Embedding 配置无效:{embed_message}"], - ), - ) - return - # 所有验证通过,开始部署 await self.app.push_screen(DeploymentProgressScreen(self.config)) @@ -344,25 +340,23 @@ class DeploymentConfigScreen(ModalScreen[bool]): @on(Button.Pressed, "#deployment_mode_btn") def on_deployment_mode_btn_pressed(self) -> None: - """切换部署模式按钮:在轻量和全量之间切换,更新按钮文本和描述。""" - try: - btn = self.query_one("#deployment_mode_btn", Button) - desc = self.query_one("#deployment_mode_desc", Static) + """切换部署模式按钮:在轻量和全量之间切换""" + # 基于当前配置状态切换,而不是按钮文本 + if self.config.deployment_mode == "light": + # 切换到全量部署 + self.config.deployment_mode = "full" + else: + # 切换到轻量部署 + self.config.deployment_mode = "light" - # 如果当前为轻量,则切换到全量 - if btn.label and "轻量" in str(btn.label): - btn.label = "全量部署" - desc.update("全量部署:部署框架服务、Web 界面和 RAG 组件,需手动配置 Agent。") - # 更新 Embedding 配置提示 - self._update_embedding_hint(is_light_mode=False) - else: - btn.label = "轻量部署" - desc.update("轻量部署:仅部署框架服务,自动初始化 Agent。") - # 更新 Embedding 配置提示 - self._update_embedding_hint(is_light_mode=True) - except (AttributeError, ValueError): - # 查询失败或属性错误时忽略 - return + # 同步 UI 状态 + self._sync_ui_from_config() + + # 更新 Embedding 配置提示 + self._update_embedding_hint(is_light_mode=(self.config.deployment_mode == 
"light")) + + # 重新初始化验证状态(部署模式变化可能影响 Embedding 验证需求) + self._initialize_validation_status() def _update_embedding_hint(self, *, is_light_mode: bool) -> None: """更新 Embedding 配置提示信息""" @@ -383,11 +377,17 @@ class DeploymentConfigScreen(ModalScreen[bool]): @on(Input.Changed, "#llm_endpoint, #llm_api_key, #llm_model") async def on_llm_field_changed(self, event: Input.Changed) -> None: """处理 LLM 字段变化,检查是否需要自动验证""" + # 重置 LLM 验证状态 + self.llm_validation_status = ValidationStatus.PENDING + # 取消之前的验证任务 if self._llm_validation_task and not self._llm_validation_task.done(): self._llm_validation_task.cancel() - # 检查是否所有核心字段都已填写 + # 更新部署按钮状态 + self._update_deploy_button_state() + + # 检查是否需要验证 if self._should_validate_llm(): # 延迟 1 秒后进行验证,避免用户快速输入时频繁触发 self._llm_validation_task = asyncio.create_task(self._delayed_llm_validation()) @@ -395,11 +395,20 @@ class DeploymentConfigScreen(ModalScreen[bool]): @on(Input.Changed, "#embedding_endpoint, #embedding_api_key, #embedding_model") async def on_embedding_field_changed(self, event: Input.Changed) -> None: """处理 Embedding 字段变化,检查是否需要自动验证""" + # 重置 Embedding 验证状态 + if self._is_embedding_required(): + self.embedding_validation_status = ValidationStatus.PENDING + else: + self.embedding_validation_status = ValidationStatus.NOT_REQUIRED + # 取消之前的验证任务 if self._embedding_validation_task and not self._embedding_validation_task.done(): self._embedding_validation_task.cancel() - # 检查是否所有核心字段都已填写 + # 更新部署按钮状态 + self._update_deploy_button_state() + + # 检查是否需要验证 if self._should_validate_embedding(): # 延迟 1 秒后进行验证,避免用户快速输入时频繁触发 self._embedding_validation_task = asyncio.create_task(self._delayed_embedding_validation()) @@ -407,36 +416,14 @@ class DeploymentConfigScreen(ModalScreen[bool]): def _should_validate_llm(self) -> bool: """检查是否应该验证 LLM 配置""" try: - endpoint = self.query_one("#llm_endpoint", Input).value.strip() - api_key = self.query_one("#llm_api_key", Input).value.strip() - model = self.query_one("#llm_model", Input).value.strip() - return bool(endpoint and api_key and model) + return bool(self.query_one("#llm_endpoint", Input).value.strip()) except (AttributeError, ValueError): return False def _should_validate_embedding(self) -> bool: """检查是否应该验证 Embedding 配置""" try: - endpoint = self.query_one("#embedding_endpoint", Input).value.strip() - api_key = self.query_one("#embedding_api_key", Input).value.strip() - model = self.query_one("#embedding_model", Input).value.strip() - - # 检查部署模式 - try: - btn = self.query_one("#deployment_mode_btn", Button) - label = str(btn.label) if btn.label is not None else "" - is_light_mode = "轻量" in label - except (AttributeError, ValueError): - is_light_mode = True # 默认为轻量模式 - - # 轻量模式下,只有在用户填写了 Embedding 字段时才验证 - if is_light_mode: - has_embedding_config = bool(endpoint or api_key or model) - return has_embedding_config and bool(endpoint and api_key and model) - - # 全量模式下,必须验证 - return bool(endpoint and api_key and model) - + return bool(self.query_one("#embedding_endpoint", Input).value.strip()) except (AttributeError, ValueError): return False @@ -456,11 +443,62 @@ class DeploymentConfigScreen(ModalScreen[bool]): except asyncio.CancelledError: pass + def _is_embedding_required(self) -> bool: + """检查是否需要验证 Embedding 配置""" + try: + # 从配置模型获取部署模式,更准确可靠 + is_full_mode = self.config.deployment_mode == "full" + + # 全量部署模式下,Embedding 是必需的 + if is_full_mode: + return True + + # 轻量部署模式下,如果用户填写了 Embedding 配置,则需要验证 + endpoint = self.query_one("#embedding_endpoint", Input).value.strip() + api_key = self.query_one("#embedding_api_key", 
Input).value.strip() + model = self.query_one("#embedding_model", Input).value.strip() + return bool(endpoint or api_key or model) + + except (AttributeError, ValueError): + return False + + def _update_deploy_button_state(self) -> None: + """根据验证状态更新部署按钮状态""" + try: + deploy_button = self.query_one("#deploy", Button) + + # 检查 LLM 验证状态 + if self.llm_validation_status in ( + ValidationStatus.PENDING, + ValidationStatus.VALIDATING, + ValidationStatus.INVALID, + ): + deploy_button.disabled = True + return + + # 检查 Embedding 验证状态 + if self._is_embedding_required() and self.embedding_validation_status in ( + ValidationStatus.PENDING, + ValidationStatus.VALIDATING, + ValidationStatus.INVALID, + ): + deploy_button.disabled = True + return + + # 所有必要的验证都通过,启用部署按钮 + deploy_button.disabled = False + + except (ValueError, AttributeError): + # 如果出现异常,为安全起见禁用部署按钮 + pass + async def _validate_llm_config(self) -> None: """验证 LLM 配置""" # 更新状态为验证中 + self.llm_validation_status = ValidationStatus.VALIDATING status_widget = self.query_one("#llm_validation_status", Static) status_widget.update("[yellow]验证中...[/yellow]") + self._update_deploy_button_state() # 收集当前 LLM 配置 self._collect_llm_config() @@ -471,30 +509,39 @@ class DeploymentConfigScreen(ModalScreen[bool]): # 更新验证状态 if is_valid: - # 检查是否支持 function_call + # 检查是否支持工具调用 supports_function_call = info.get("supports_function_call", False) if supports_function_call: + self.llm_validation_status = ValidationStatus.VALID status_widget.update(f"[green]✓ {message}[/green]") - self.notify("LLM 验证成功,支持 function_call 功能", severity="information") + self.notify("LLM 验证成功,支持工具调用功能", severity="information") else: - status_widget.update("[red]✗ 不支持 function_call[/red]") + self.llm_validation_status = ValidationStatus.INVALID + status_widget.update("[red]✗ 不支持工具调用[/red]") self.notify( - "LLM 验证失败:模型不支持 function_call 功能,无法用于部署。请选择支持 function_call 的模型。", + "LLM 验证失败:模型不支持工具调用功能,无法用于部署。请选择支持工具调用的模型。", severity="error", ) else: + self.llm_validation_status = ValidationStatus.INVALID status_widget.update(f"[red]✗ {message}[/red]") self.notify(f"LLM 验证失败: {message}", severity="error") except (OSError, ValueError, TypeError) as e: + self.llm_validation_status = ValidationStatus.INVALID status_widget.update(f"[red]✗ 验证异常: {e}[/red]") self.notify(f"LLM 验证过程中出现异常: {e}", severity="error") + # 更新部署按钮状态 + self._update_deploy_button_state() + async def _validate_embedding_config(self) -> None: """验证 Embedding 配置""" # 更新状态为验证中 + self.embedding_validation_status = ValidationStatus.VALIDATING status_widget = self.query_one("#embedding_validation_status", Static) status_widget.update("[yellow]验证中...[/yellow]") + self._update_deploy_button_state() # 收集当前 Embedding 配置 self._collect_embedding_config() @@ -505,18 +552,24 @@ class DeploymentConfigScreen(ModalScreen[bool]): # 更新验证状态 if is_valid: + self.embedding_validation_status = ValidationStatus.VALID status_widget.update(f"[green]✓ {message}[/green]") # 显示维度信息 dimension = info.get("dimension", "未知") self.notify(f"Embedding 验证成功,向量维度: {dimension}", severity="information") else: + self.embedding_validation_status = ValidationStatus.INVALID status_widget.update(f"[red]✗ {message}[/red]") self.notify(f"Embedding 验证失败: {message}", severity="error") except (OSError, ValueError, TypeError) as e: + self.embedding_validation_status = ValidationStatus.INVALID status_widget.update(f"[red]✗ 验证异常: {e}[/red]") self.notify(f"Embedding 验证过程中出现异常: {e}", severity="error") + # 更新部署按钮状态 + self._update_deploy_button_state() + def _collect_llm_config(self) -> 
None: """收集 LLM 配置""" try: @@ -566,18 +619,11 @@ class DeploymentConfigScreen(ModalScreen[bool]): model=self.query_one("#embedding_model", Input).value.strip(), ) - # 部署选项 - 从切换按钮读取当前模式 - try: - btn = self.query_one("#deployment_mode_btn", Button) - label = str(btn.label) if btn.label is not None else "" - if "全量" in label: - self.config.deployment_mode = "full" - else: - self.config.deployment_mode = "light" - except (AttributeError, ValueError): + # 部署选项 + if not hasattr(self.config, "deployment_mode") or not self.config.deployment_mode: self.config.deployment_mode = "light" - # 根据部署模式自动设置组件启用状态 + # 根据部署模式最终设置组件启用状态 if self.config.deployment_mode == "full": self.config.enable_web = True self.config.enable_rag = True diff --git a/src/tool/oi_llm_config.py b/src/tool/oi_llm_config.py index d058d57..57766c1 100644 --- a/src/tool/oi_llm_config.py +++ b/src/tool/oi_llm_config.py @@ -12,6 +12,7 @@ import os import subprocess import sys from dataclasses import dataclass, field +from enum import Enum from pathlib import Path import toml @@ -28,6 +29,16 @@ from tool.validators import APIValidator logger = get_logger(__name__) +class ValidationStatus(Enum): + """验证状态枚举""" + + PENDING = "pending" + VALIDATING = "validating" + VALID = "valid" + INVALID = "invalid" + NOT_REQUIRED = "not_required" + + @dataclass class LLMSystemConfig: """ @@ -529,6 +540,10 @@ class LLMConfigScreen(ModalScreen[bool]): self._embedding_validation_task: asyncio.Task[None] | None = None self._background_tasks: set[asyncio.Task] = set() + # 验证状态跟踪 + self.llm_validation_status: ValidationStatus = ValidationStatus.PENDING + self.embedding_validation_status: ValidationStatus = ValidationStatus.PENDING + def compose(self) -> ComposeResult: """组合界面组件""" with Container(classes="config-container"): @@ -554,6 +569,9 @@ class LLMConfigScreen(ModalScreen[bool]): # 更新界面显示的值 self._update_form_values() + # 初始化验证状态和保存按钮状态 + self._initialize_validation_status() + except FileNotFoundError: logger.exception("核心配置文件缺失") self.notify("错误:核心配置文件不存在,请检查系统安装", severity="error") @@ -590,10 +608,16 @@ class LLMConfigScreen(ModalScreen[bool]): @on(Input.Changed, "#llm_endpoint, #llm_api_key, #llm_model, #llm_max_tokens, #llm_temperature") async def on_llm_field_changed(self, event: Input.Changed) -> None: """处理 LLM 字段变化,检查是否需要自动验证""" + # 重置 LLM 验证状态 + self.llm_validation_status = ValidationStatus.PENDING + # 取消之前的验证任务 if self._llm_validation_task and not self._llm_validation_task.done(): self._llm_validation_task.cancel() + # 更新保存按钮状态 + self._update_save_button_state() + # 检查是否所有核心字段都已填写 if self._should_validate_llm(): # 延迟验证,避免用户输入时频繁验证 @@ -602,11 +626,20 @@ class LLMConfigScreen(ModalScreen[bool]): @on(Input.Changed, "#embedding_endpoint, #embedding_api_key, #embedding_model") async def on_embedding_field_changed(self, event: Input.Changed) -> None: """处理 Embedding 字段变化,检查是否需要自动验证""" + # 重置 Embedding 验证状态 + if self._is_embedding_required(): + self.embedding_validation_status = ValidationStatus.PENDING + else: + self.embedding_validation_status = ValidationStatus.NOT_REQUIRED + # 取消之前的验证任务 if self._embedding_validation_task and not self._embedding_validation_task.done(): self._embedding_validation_task.cancel() - # 检查是否所有核心字段都已填写 + # 更新保存按钮状态 + self._update_save_button_state() + + # 检查是否需要验证 Embedding if self._should_validate_embedding(): # 延迟验证,避免用户输入时频繁验证 self._embedding_validation_task = asyncio.create_task(self._delayed_embedding_validation()) @@ -645,7 +678,7 @@ class LLMConfigScreen(ModalScreen[bool]): ) with Horizontal(classes="form-row"): - 
yield Label("最大令牌数:", classes="form-label") + yield Label("最大输出令牌数:", classes="form-label") yield Input( value=str(self.config.llm.max_tokens), placeholder="8192", @@ -723,25 +756,38 @@ class LLMConfigScreen(ModalScreen[bool]): self.query_one("#embedding_api_key", Input).value = self.config.embedding.api_key self.query_one("#embedding_model", Input).value = self.config.embedding.model - except (OSError, ValueError, AttributeError): + except (ValueError, AttributeError): + # 如果获取失败,记录警告并使用默认值 logger.warning("更新表单值时出现警告") + def _initialize_validation_status(self) -> None: + """初始化验证状态""" + # 初始化 Embedding 验证状态 + if self._is_embedding_required(): + self.embedding_validation_status = ValidationStatus.PENDING + else: + self.embedding_validation_status = ValidationStatus.NOT_REQUIRED + # 如果不需要验证 Embedding,显示相应状态 + try: + embedding_status = self.query_one("#embedding_validation_status", Static) + embedding_status.update("[dim]不需要验证[/dim]") + except (ValueError, AttributeError): + pass + + # 更新保存按钮状态 + self._update_save_button_state() + def _should_validate_llm(self) -> bool: """检查是否应该验证 LLM 配置""" try: - endpoint = self.query_one("#llm_endpoint", Input).value.strip() - # 只要有端点就可以验证,API Key 和模型名称可能是可选的 - return bool(endpoint) + return bool(self.query_one("#llm_endpoint", Input).value.strip()) except (ValueError, AttributeError): return False def _should_validate_embedding(self) -> bool: """检查是否应该验证 Embedding 配置""" try: - endpoint = self.query_one("#embedding_endpoint", Input).value.strip() - api_key = self.query_one("#embedding_api_key", Input).value.strip() - model = self.query_one("#embedding_model", Input).value.strip() - return bool(endpoint and api_key and model) + return bool(self.query_one("#embedding_endpoint", Input).value.strip()) except (ValueError, AttributeError): return False @@ -764,8 +810,10 @@ class LLMConfigScreen(ModalScreen[bool]): async def _validate_llm_config(self) -> None: """验证 LLM 配置""" # 更新状态为验证中 + self.llm_validation_status = ValidationStatus.VALIDATING status_widget = self.query_one("#llm_validation_status", Static) status_widget.update("[yellow]验证中...[/yellow]") + self._update_save_button_state() # 收集当前 LLM 配置 self._collect_llm_config() @@ -776,19 +824,27 @@ class LLMConfigScreen(ModalScreen[bool]): # 更新验证状态 if is_valid: + self.llm_validation_status = ValidationStatus.VALID status_widget.update(f"[green]✓ {message}[/green]") else: + self.llm_validation_status = ValidationStatus.INVALID status_widget.update(f"[red]✗ {message}[/red]") except (ValueError, AttributeError, OSError) as e: + self.llm_validation_status = ValidationStatus.INVALID status_widget.update(f"[red]✗ 验证异常: {e}[/red]") self.notify(f"LLM 验证过程中出现异常: {e}", severity="error") + # 更新保存按钮状态 + self._update_save_button_state() + async def _validate_embedding_config(self) -> None: """验证 Embedding 配置""" # 更新状态为验证中 + self.embedding_validation_status = ValidationStatus.VALIDATING status_widget = self.query_one("#embedding_validation_status", Static) status_widget.update("[yellow]验证中...[/yellow]") + self._update_save_button_state() # 收集当前 Embedding 配置 self._collect_embedding_config() @@ -799,14 +855,20 @@ class LLMConfigScreen(ModalScreen[bool]): # 更新验证状态 if is_valid: + self.embedding_validation_status = ValidationStatus.VALID status_widget.update(f"[green]✓ {message}[/green]") else: + self.embedding_validation_status = ValidationStatus.INVALID status_widget.update(f"[red]✗ {message}[/red]") except (ValueError, AttributeError, OSError) as e: + self.embedding_validation_status = ValidationStatus.INVALID 
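+            # The INVALID status set above makes _update_save_button_state() keep Save disabled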
status_widget.update(f"[red]✗ 验证异常: {e}[/red]") self.notify(f"Embedding 验证过程中出现异常: {e}", severity="error") + # 更新保存按钮状态 + self._update_save_button_state() + def _collect_llm_config(self) -> None: """收集 LLM 配置""" try: @@ -849,6 +911,51 @@ class LLMConfigScreen(ModalScreen[bool]): # 如果获取失败,记录警告并使用默认值 logger.warning("获取 Embedding 配置失败,使用默认值") + def _is_embedding_required(self) -> bool: + """检查是否需要验证 Embedding 配置""" + # 如果 RAG 环境文件存在,则需要验证 Embedding + if self.config.RAG_ENV_PATH.exists(): + return True + + # 如果用户填写了 Embedding 配置,则需要验证 + try: + endpoint = self.query_one("#embedding_endpoint", Input).value.strip() + api_key = self.query_one("#embedding_api_key", Input).value.strip() + model = self.query_one("#embedding_model", Input).value.strip() + return bool(endpoint or api_key or model) + except (ValueError, AttributeError): + return False + + def _update_save_button_state(self) -> None: + """根据验证状态更新保存按钮状态""" + try: + save_button = self.query_one("#save", Button) + + # 检查 LLM 验证状态 + if self.llm_validation_status in ( + ValidationStatus.PENDING, + ValidationStatus.VALIDATING, + ValidationStatus.INVALID, + ): + save_button.disabled = True + return + + # 检查 Embedding 验证状态 + if self._is_embedding_required() and self.embedding_validation_status in ( + ValidationStatus.PENDING, + ValidationStatus.VALIDATING, + ValidationStatus.INVALID, + ): + save_button.disabled = True + return + + # 所有必要的验证都通过,启用保存按钮 + save_button.disabled = False + + except (ValueError, AttributeError): + # 如果出现异常,为安全起见禁用保存按钮 + pass + async def _collect_and_save_config(self) -> bool: """收集用户配置并保存""" try: @@ -887,13 +994,19 @@ class LLMConfigApp(App[bool]): CSS_PATH = str(Path(__file__).parent.parent / "app" / "css" / "styles.tcss") TITLE = "openEuler Intelligence LLM 配置工具" + def __init__(self) -> None: + """初始化应用""" + super().__init__() + self.config_result: bool | None = None + def on_mount(self) -> None: """应用启动时显示配置屏幕""" self.push_screen(LLMConfigScreen(), self._handle_screen_result) def _handle_screen_result(self, result: bool | None) -> None: # noqa: FBT001 """处理配置屏幕结果""" - self.exit(return_code=0 if result else 1) + self.config_result = result + self.exit() def llm_config() -> None: @@ -915,9 +1028,10 @@ def llm_config() -> None: # 启动 TUI 应用 app = LLMConfigApp() - result = app.run() + app.run() - if result == 0: + # 检查应用内部存储的结果 + if app.config_result: sys.stdout.write("✓ LLM 配置更新完成\n") else: sys.stdout.write("配置更新已取消\n") diff --git a/src/tool/validators.py b/src/tool/validators.py index b261e42..4f1fe0a 100644 --- a/src/tool/validators.py +++ b/src/tool/validators.py @@ -81,9 +81,9 @@ class APIValidator: else: success_msg = "LLM 配置验证成功" if func_valid: - success_msg += f" - 支持 function_call,类型: {func_type}" + success_msg += f" - 支持工具调用,类型: {func_type}" else: - success_msg += f" - 不支持 function_call: {func_msg}" + success_msg += " - 不支持工具调用" return ( True, @@ -159,8 +159,8 @@ class APIValidator: call_kwargs["temperature"] = temperature response = await client.chat.completions.create(**call_kwargs) - except (AuthenticationError, APIError, OpenAIError) as e: - return False, f"基本对话测试失败: {e!s}" + except (AuthenticationError, APIError, OpenAIError): + return False, "基本对话测试失败" else: if response.choices and len(response.choices) > 0: return True, "基本对话功能正常" @@ -279,7 +279,7 @@ class APIValidator: if hasattr(choice.message, "tool_calls") and choice.message.tool_calls: return True, "支持 tools 格式的 function_call" - return False, "不支持 function_call 功能" + return False, "不支持工具调用功能" async def _test_structured_output( self, -- Gitee
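
A minimal usage sketch of the validation flow added by this series. It mirrors
the demo in tests/app/deployment/test_validate_llm_config.py; the endpoint,
API key, and model values below are placeholders, and it assumes src/ is on
the Python path:

    import asyncio

    from app.deployment.models import DeploymentConfig, EmbeddingConfig, LLMConfig


    async def demo() -> None:
        config = DeploymentConfig(
            llm=LLMConfig(
                endpoint="http://127.0.0.1:1234/v1",  # placeholder endpoint
                api_key="lm-studio",                  # placeholder key
                model="qwen/qwen3-coder-30b",         # placeholder model
                max_tokens=4096,
                temperature=0.7,
                request_timeout=30,
            ),
            embedding=EmbeddingConfig(
                endpoint="http://127.0.0.1:8080",     # placeholder endpoint
                api_key="",
                model="bge-m3",                       # placeholder model
            ),
        )

        ok, msg, info = await config.validate_llm_connectivity()
        print(msg)
        if ok and info.get("supports_function_call"):
            # One of "function_call", "structured_output", "json_mode",
            # "vllm", "ollama"; patch 2 writes it to [function_call].backend.
            print("backend:", config.detected_backend_type)

        ok, msg, info = await config.validate_embedding_connectivity()
        print(msg)
        if ok:
            # "openai" or "mindie", auto-filled into config.embedding.type by patch 3
            print("embedding:", config.embedding.type, "dim:", info.get("dimension"))


    asyncio.run(demo())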