diff --git a/data_chain/apps/service/model_service.py b/data_chain/apps/service/model_service.py index 920006200cc7982d45f9a20f0bf2f261c1e415ec..816940fd88aa1afffe883f682523e9db56c22ac1 100644 --- a/data_chain/apps/service/model_service.py +++ b/data_chain/apps/service/model_service.py @@ -42,6 +42,8 @@ async def get_model_by_kb_id(kb_id): kb_entity = await KnowledgeBaseManager.select_by_id(kb_id) if kb_entity is not None: model_entity = await ModelManager.select_by_user_id(kb_entity.user_id) + else: + return None return ModelConvertor.convert_entity_to_dto(model_entity) @@ -89,11 +91,11 @@ async def update_model(user_id, update_dict): encrypted_config=json.dumps(encrypted_config), max_tokens=update_dict['max_tokens'] ) - model_entity=await ModelManager.insert(model_entity) + model_entity = await ModelManager.insert(model_entity) else: update_dict['encrypted_openai_api_key'] = encrypted_openai_api_key update_dict['encrypted_config'] = json.dumps(encrypted_config) - model_entity=await ModelManager.update_by_user_id(user_id, update_dict) + model_entity = await ModelManager.update_by_user_id(user_id, update_dict) if model_entity is None: raise ModelException("Model update failed") return ModelConvertor.convert_entity_to_dto(model_entity)