From 0067b16d1eedea27b9825f0094ac8370afa43609 Mon Sep 17 00:00:00 2001 From: yihong Date: Thu, 21 Nov 2024 10:34:43 +0800 Subject: [PATCH] fix: refactor all 'or []' and 'or {}' logic to make code more clear (#10883) Signed-off-by: yihong0618 --- api/core/agent/base_agent_runner.py | 15 ++++----------- .../app/task_pipeline/workflow_cycle_manage.py | 4 ++-- .../model_providers/cohere/llm/llm.py | 4 ++-- .../model_providers/openai/llm/llm.py | 4 ++-- .../openllm/llm/openllm_generate.py | 5 +++-- api/core/tools/tool/tool.py | 2 +- api/core/tools/tool_engine.py | 2 +- api/core/tools/utils/configuration.py | 2 +- api/services/app_service.py | 2 +- api/services/tools/tools_transform_service.py | 2 +- api/services/website_service.py | 4 ++-- 11 files changed, 20 insertions(+), 26 deletions(-) diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 860ec5de0..2f5e7c779 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -114,16 +114,9 @@ class BaseAgentRunner(AppRunner): # check if model supports stream tool call llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) - if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []): - self.stream_tool_call = True - else: - self.stream_tool_call = False - - # check if model supports vision - if model_schema and ModelFeature.VISION in (model_schema.features or []): - self.files = application_generate_entity.files - else: - self.files = [] + features = model_schema.features if model_schema and model_schema.features else [] + self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features + self.files = application_generate_entity.files if ModelFeature.VISION in features else [] self.query = None self._current_thoughts: list[PromptMessage] = [] @@ -250,7 +243,7 @@ class BaseAgentRunner(AppRunner): update prompt message tool """ # try to get tool runtime parameters - tool_runtime_parameters = tool.get_runtime_parameters() or [] + tool_runtime_parameters = tool.get_runtime_parameters() for parameter in tool_runtime_parameters: if parameter.form != ToolParameter.ToolParameterForm.LLM: diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 042339969..46b860927 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -381,7 +381,7 @@ class WorkflowCycleManage: id=workflow_run.id, workflow_id=workflow_run.workflow_id, sequence_number=workflow_run.sequence_number, - inputs=workflow_run.inputs_dict or {}, + inputs=workflow_run.inputs_dict, created_at=int(workflow_run.created_at.timestamp()), ), ) @@ -428,7 +428,7 @@ class WorkflowCycleManage: created_by=created_by, created_at=int(workflow_run.created_at.timestamp()), finished_at=int(workflow_run.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}), + files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict), ), ) diff --git a/api/core/model_runtime/model_providers/cohere/llm/llm.py b/api/core/model_runtime/model_providers/cohere/llm/llm.py index 3863ad330..f230157a3 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/llm.py +++ b/api/core/model_runtime/model_providers/cohere/llm/llm.py @@ -691,8 +691,8 @@ class CohereLargeLanguageModel(LargeLanguageModel): base_model_schema = cast(AIModelEntity, base_model_schema) base_model_schema_features = base_model_schema.features or [] - base_model_schema_model_properties = base_model_schema.model_properties or {} - base_model_schema_parameters_rules = base_model_schema.parameter_rules or [] + base_model_schema_model_properties = base_model_schema.model_properties + base_model_schema_parameters_rules = base_model_schema.parameter_rules entity = AIModelEntity( model=model, diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index f16f81c12..aea884e00 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -1130,8 +1130,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): base_model_schema = model_map[base_model] base_model_schema_features = base_model_schema.features or [] - base_model_schema_model_properties = base_model_schema.model_properties or {} - base_model_schema_parameters_rules = base_model_schema.parameter_rules or [] + base_model_schema_model_properties = base_model_schema.model_properties + base_model_schema_parameters_rules = base_model_schema.parameter_rules entity = AIModelEntity( model=model, diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py index 351dcced1..2789a9250 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py +++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py @@ -37,13 +37,14 @@ class OpenLLMGenerateMessage: class OpenLLMGenerate: def generate( self, + *, server_url: str, model_name: str, stream: bool, model_parameters: dict[str, Any], - stop: list[str], + stop: list[str] | None = None, prompt_messages: list[OpenLLMGenerateMessage], - user: str, + user: str | None = None, ) -> Union[Generator[OpenLLMGenerateMessage, None, None], OpenLLMGenerateMessage]: if not server_url: raise InvalidAuthenticationError("Invalid server URL") diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 6cb6e18b6..f17a26dfb 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -261,7 +261,7 @@ class Tool(BaseModel, ABC): """ parameters = self.parameters or [] parameters = parameters.copy() - user_parameters = self.get_runtime_parameters() or [] + user_parameters = self.get_runtime_parameters() user_parameters = user_parameters.copy() # override parameters diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 9e290c365..01a1fe330 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -55,7 +55,7 @@ class ToolEngine: # check if this tool has only one parameter parameters = [ parameter - for parameter in tool.get_runtime_parameters() or [] + for parameter in tool.get_runtime_parameters() if parameter.form == ToolParameter.ToolParameterForm.LLM ] if parameters and len(parameters) == 1: diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 83600d21c..8b5e27f53 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -127,7 +127,7 @@ class ToolParameterConfigurationManager(BaseModel): # get tool parameters tool_parameters = self.tool_runtime.parameters or [] # get tool runtime parameters - runtime_parameters = self.tool_runtime.get_runtime_parameters() or [] + runtime_parameters = self.tool_runtime.get_runtime_parameters() # override parameters current_parameters = tool_parameters.copy() for runtime_parameter in runtime_parameters: diff --git a/api/services/app_service.py b/api/services/app_service.py index 620d0ac27..af2b77d63 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -341,7 +341,7 @@ class AppService: if not app_model_config: return meta - agent_config = app_model_config.agent_mode_dict or {} + agent_config = app_model_config.agent_mode_dict # get all tools tools = agent_config.get("tools", []) diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 1befa1153..a4aa870dc 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -242,7 +242,7 @@ class ToolTransformService: # get tool parameters parameters = tool.parameters or [] # get tool runtime parameters - runtime_parameters = tool.get_runtime_parameters() or [] + runtime_parameters = tool.get_runtime_parameters() # override parameters current_parameters = parameters.copy() for runtime_parameter in runtime_parameters: diff --git a/api/services/website_service.py b/api/services/website_service.py index 13cc9c679..230f5d781 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -51,8 +51,8 @@ class WebsiteService: excludes = options.get("excludes").split(",") if options.get("excludes") else [] params = { "crawlerOptions": { - "includes": includes or [], - "excludes": excludes or [], + "includes": includes, + "excludes": excludes, "generateImgAltText": True, "limit": options.get("limit", 1), "returnOnlyUrls": False,