diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 501783556..e1c021a44 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -55,6 +55,25 @@ class ProviderModelWithStatusEntity(ProviderModel): status: ModelStatus load_balancing_enabled: bool = False + def raise_for_status(self) -> None: + """ + Check model status and raise ValueError if not active. + + :raises ValueError: When model status is not active, with a descriptive message + """ + if self.status == ModelStatus.ACTIVE: + return + + error_messages = { + ModelStatus.NO_CONFIGURE: "Model is not configured", + ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded", + ModelStatus.NO_PERMISSION: "No permission to use this model", + ModelStatus.DISABLED: "Model is disabled", + } + + if self.status in error_messages: + raise ValueError(error_messages[self.status]) + class ModelWithProviderEntity(ProviderModelWithStatusEntity): """ diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index 231743bf2..06fdb089d 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -41,45 +41,53 @@ class Extensible: extensions = [] position_map: dict[str, int] = {} - # get the path of the current class - current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py") - current_dir_path = os.path.dirname(current_path) + # Get the package name from the module path + package_name = ".".join(cls.__module__.split(".")[:-1]) - # traverse subdirectories - for subdir_name in os.listdir(current_dir_path): - if subdir_name.startswith("__"): - continue + try: + # Get package directory path + package_spec = importlib.util.find_spec(package_name) + if not package_spec or not package_spec.origin: + raise ImportError(f"Could not find package {package_name}") - subdir_path = os.path.join(current_dir_path, subdir_name) - extension_name = subdir_name - if os.path.isdir(subdir_path): + package_dir = os.path.dirname(package_spec.origin) + + # Traverse subdirectories + for subdir_name in os.listdir(package_dir): + if subdir_name.startswith("__"): + continue + + subdir_path = os.path.join(package_dir, subdir_name) + if not os.path.isdir(subdir_path): + continue + + extension_name = subdir_name file_names = os.listdir(subdir_path) - # is builtin extension, builtin extension - # in the front-end page and business logic, there are special treatments. + # Check for extension module file + if (extension_name + ".py") not in file_names: + logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") + continue + + # Check for builtin flag and position builtin = False - # default position is 0 can not be None for sort_to_dict_by_position_map position = 0 if "__builtin__" in file_names: builtin = True - builtin_file_path = os.path.join(subdir_path, "__builtin__") if os.path.exists(builtin_file_path): position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip()) position_map[extension_name] = position - if (extension_name + ".py") not in file_names: - logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") - continue - - # Dynamic loading {subdir_name}.py file and find the subclass of Extensible - py_path = os.path.join(subdir_path, extension_name + ".py") - spec = importlib.util.spec_from_file_location(extension_name, py_path) + # Import the extension module + module_name = f"{package_name}.{extension_name}.{extension_name}" + spec = importlib.util.find_spec(module_name) if not spec or not spec.loader: - raise Exception(f"Failed to load module {extension_name} from {py_path}") + raise ImportError(f"Failed to load module {module_name}") mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) + # Find extension class extension_class = None for name, obj in vars(mod).items(): if isinstance(obj, type) and issubclass(obj, cls) and obj != cls: @@ -87,21 +95,21 @@ class Extensible: break if not extension_class: - logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.") + logging.warning(f"Missing subclass of {cls.__name__} in {module_name}, Skip.") continue + # Load schema if not builtin json_data: dict[str, Any] = {} if not builtin: - if "schema.json" not in file_names: + json_path = os.path.join(subdir_path, "schema.json") + if not os.path.exists(json_path): logging.warning(f"Missing schema.json file in {subdir_path}, Skip.") continue - json_path = os.path.join(subdir_path, "schema.json") - json_data = {} - if os.path.exists(json_path): - with open(json_path, encoding="utf-8") as f: - json_data = json.load(f) + with open(json_path, encoding="utf-8") as f: + json_data = json.load(f) + # Create extension extensions.append( ModuleExtension( extension_class=extension_class, @@ -113,6 +121,11 @@ class Extensible: ) ) + except Exception as e: + logging.exception("Error scanning extensions") + raise + + # Sort extensions by position sorted_extensions = sort_to_dict_by_position_map( position_map=position_map, data=extensions, name_func=lambda x: x.name ) diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 373ef2bbe..568149cc3 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -160,6 +160,10 @@ class ProviderModel(BaseModel): deprecated: bool = False model_config = ConfigDict(protected_namespaces=()) + @property + def support_structure_output(self) -> bool: + return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features + class ParameterRule(BaseModel): """ diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 757020017..488a39467 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -3,7 +3,9 @@ from collections import defaultdict from json import JSONDecodeError from typing import Any, Optional, cast +from sqlalchemy import select from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session from configs import dify_config from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity @@ -393,19 +395,13 @@ class ProviderManager: @staticmethod def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]: - """ - Get all provider records of the workspace. - - :param tenant_id: workspace id - :return: - """ - providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all() - provider_name_to_provider_records_dict = defaultdict(list) - for provider in providers: - # TODO: Use provider name with prefix after the data migration - provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider) - + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True) + providers = session.scalars(stmt) + for provider in providers: + # Use provider name with prefix after the data migration + provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider) return provider_name_to_provider_records_dict @staticmethod @@ -416,17 +412,12 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - # Get all provider model records of the workspace - provider_models = ( - db.session.query(ProviderModel) - .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) - .all() - ) - provider_name_to_provider_model_records_dict = defaultdict(list) - for provider_model in provider_models: - provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model) - + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) + provider_models = session.scalars(stmt) + for provider_model in provider_models: + provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model) return provider_name_to_provider_model_records_dict @staticmethod @@ -437,17 +428,14 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - preferred_provider_types = ( - db.session.query(TenantPreferredModelProvider) - .filter(TenantPreferredModelProvider.tenant_id == tenant_id) - .all() - ) - - provider_name_to_preferred_provider_type_records_dict = { - preferred_provider_type.provider_name: preferred_provider_type - for preferred_provider_type in preferred_provider_types - } - + provider_name_to_preferred_provider_type_records_dict = {} + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id) + preferred_provider_types = session.scalars(stmt) + provider_name_to_preferred_provider_type_records_dict = { + preferred_provider_type.provider_name: preferred_provider_type + for preferred_provider_type in preferred_provider_types + } return provider_name_to_preferred_provider_type_records_dict @staticmethod @@ -458,18 +446,14 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - provider_model_settings = ( - db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all() - ) - provider_name_to_provider_model_settings_dict = defaultdict(list) - for provider_model_setting in provider_model_settings: - ( + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id) + provider_model_settings = session.scalars(stmt) + for provider_model_setting in provider_model_settings: provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append( provider_model_setting ) - ) - return provider_name_to_provider_model_settings_dict @staticmethod @@ -492,15 +476,14 @@ class ProviderManager: if not model_load_balancing_enabled: return {} - provider_load_balancing_configs = ( - db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all() - ) - provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) - for provider_load_balancing_config in provider_load_balancing_configs: - provider_name_to_provider_load_balancing_model_configs_dict[ - provider_load_balancing_config.provider_name - ].append(provider_load_balancing_config) + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id) + provider_load_balancing_configs = session.scalars(stmt) + for provider_load_balancing_config in provider_load_balancing_configs: + provider_name_to_provider_load_balancing_model_configs_dict[ + provider_load_balancing_config.provider_name + ].append(provider_load_balancing_config) return provider_name_to_provider_load_balancing_model_configs_dict @@ -626,10 +609,9 @@ class ProviderManager: if not cached_provider_credentials: try: # fix origin data - if ( - custom_provider_record.encrypted_config - and not custom_provider_record.encrypted_config.startswith("{") - ): + if custom_provider_record.encrypted_config is None: + raise ValueError("No credentials found") + if not custom_provider_record.encrypted_config.startswith("{"): provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} else: provider_credentials = json.loads(custom_provider_record.encrypted_config) @@ -733,7 +715,7 @@ class ProviderManager: return SystemConfiguration(enabled=False) # Convert provider_records to dict - quota_type_to_provider_records_dict = {} + quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {} for provider_record in provider_records: if provider_record.provider_type != ProviderType.SYSTEM.value: continue @@ -758,6 +740,11 @@ class ProviderManager: else: provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type] + if provider_record.quota_used is None: + raise ValueError("quota_used is None") + if provider_record.quota_limit is None: + raise ValueError("quota_limit is None") + quota_configuration = QuotaConfiguration( quota_type=provider_quota.quota_type, quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, @@ -791,10 +778,9 @@ class ProviderManager: cached_provider_credentials = provider_credentials_cache.get() if not cached_provider_credentials: - try: - provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config) - except JSONDecodeError: - provider_credentials = {} + provider_credentials: dict[str, Any] = {} + if provider_records and provider_records[0].encrypted_config: + provider_credentials = json.loads(provider_records[0].encrypted_config) # Get provider credential secret variables provider_credential_secret_variables = self._extract_secret_variables( diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 486b4b01a..36d068880 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -66,7 +66,8 @@ class LLMNodeData(BaseNodeData): context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) structured_output: dict | None = None - structured_output_enabled: bool = False + # We used 'structured_output_enabled' in the past, but it's not a good name. + structured_output_switch_on: bool = Field(False, alias="structured_output_enabled") @field_validator("prompt_config", mode="before") @classmethod @@ -74,3 +75,7 @@ class LLMNodeData(BaseNodeData): if v is None: return PromptConfig() return v + + @property + def structured_output_enabled(self) -> bool: + return self.structured_output_switch_on and self.structured_output is not None diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index df8f614db..2795fe05a 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -12,9 +12,7 @@ from sqlalchemy.orm import Session from configs import dify_config from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.file import FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage from core.memory.token_buffer_memory import TokenBufferMemory @@ -74,7 +72,6 @@ from core.workflow.nodes.event import ( from core.workflow.utils.structured_output.entities import ( ResponseFormat, SpecialModelType, - SupportStructuredOutputStatus, ) from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT from core.workflow.utils.variable_template_parser import VariableTemplateParser @@ -277,7 +274,7 @@ class LLMNode(BaseNode[LLMNodeData]): llm_usage=usage, ) ) - except LLMNodeError as e: + except ValueError as e: yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -527,65 +524,53 @@ class LLMNode(BaseNode[LLMNodeData]): def _fetch_model_config( self, node_data_model: ModelConfig ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: - model_name = node_data_model.name - provider_name = node_data_model.provider + if not node_data_model.mode: + raise LLMModeRequiredError("LLM mode is required.") - model_manager = ModelManager() - model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name + model = ModelManager().get_model_instance( + tenant_id=self.tenant_id, + model_type=ModelType.LLM, + provider=node_data_model.provider, + model=node_data_model.name, ) - provider_model_bundle = model_instance.provider_model_bundle - model_type_instance = model_instance.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - - model_credentials = model_instance.credentials + model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance) # check model - provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_name, model_type=ModelType.LLM + provider_model = model.provider_model_bundle.configuration.get_provider_model( + model=node_data_model.name, model_type=ModelType.LLM ) if provider_model is None: - raise ModelNotExistError(f"Model {model_name} not exist.") - - if provider_model.status == ModelStatus.NO_CONFIGURE: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - elif provider_model.status == ModelStatus.NO_PERMISSION: - raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") - elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: - raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + provider_model.raise_for_status() # model config - completion_params = node_data_model.completion_params - stop = [] - if "stop" in completion_params: - stop = completion_params["stop"] - del completion_params["stop"] - - # get model mode - model_mode = node_data_model.mode - if not model_mode: - raise LLMModeRequiredError("LLM mode is required.") - - model_schema = model_type_instance.get_model_schema(model_name, model_credentials) + stop: list[str] = [] + if "stop" in node_data_model.completion_params: + stop = node_data_model.completion_params.pop("stop") + model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials) if not model_schema: - raise ModelNotExistError(f"Model {model_name} not exist.") - support_structured_output = self._check_model_structured_output_support() - if support_structured_output == SupportStructuredOutputStatus.SUPPORTED: - completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules) - elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED: - # Set appropriate response format based on model capabilities - self._set_response_format(completion_params, model_schema.parameter_rules) - return model_instance, ModelConfigWithCredentialsEntity( - provider=provider_name, - model=model_name, + raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + + if self.node_data.structured_output_enabled: + if model_schema.support_structure_output: + node_data_model.completion_params = self._handle_native_json_schema( + node_data_model.completion_params, model_schema.parameter_rules + ) + else: + # Set appropriate response format based on model capabilities + self._set_response_format(node_data_model.completion_params, model_schema.parameter_rules) + + return model, ModelConfigWithCredentialsEntity( + provider=node_data_model.provider, + model=node_data_model.name, model_schema=model_schema, - mode=model_mode, - provider_model_bundle=provider_model_bundle, - credentials=model_credentials, - parameters=completion_params, + mode=node_data_model.mode, + provider_model_bundle=model.provider_model_bundle, + credentials=model.credentials, + parameters=node_data_model.completion_params, stop=stop, ) @@ -786,13 +771,25 @@ class LLMNode(BaseNode[LLMNodeData]): "No prompt found in the LLM configuration. " "Please ensure a prompt is properly configured before proceeding." ) - support_structured_output = self._check_model_structured_output_support() - if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED: - filtered_prompt_messages = self._handle_prompt_based_schema( - prompt_messages=filtered_prompt_messages, - ) - stop = model_config.stop - return filtered_prompt_messages, stop + + model = ModelManager().get_model_instance( + tenant_id=self.tenant_id, + model_type=ModelType.LLM, + provider=self.node_data.model.provider, + model=self.node_data.model.name, + ) + model_schema = model.model_type_instance.get_model_schema( + model=self.node_data.model.name, + credentials=model.credentials, + ) + if not model_schema: + raise ModelNotExistError(f"Model {self.node_data.model.name} not exist.") + if self.node_data.structured_output_enabled: + if not model_schema.support_structure_output: + filtered_prompt_messages = self._handle_prompt_based_schema( + prompt_messages=filtered_prompt_messages, + ) + return filtered_prompt_messages, model_config.stop def _parse_structured_output(self, result_text: str) -> dict[str, Any]: structured_output: dict[str, Any] = {} @@ -1185,32 +1182,6 @@ class LLMNode(BaseNode[LLMNodeData]): except json.JSONDecodeError: raise LLMNodeError("structured_output_schema is not valid JSON format") - def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus: - """ - Check if the current model supports structured output. - - Returns: - SupportStructuredOutput: The support status of structured output - """ - # Early return if structured output is disabled - if ( - not isinstance(self.node_data, LLMNodeData) - or not self.node_data.structured_output_enabled - or not self.node_data.structured_output - ): - return SupportStructuredOutputStatus.DISABLED - # Get model schema and check if it exists - model_schema = self._fetch_model_schema(self.node_data.model.provider) - if not model_schema: - return SupportStructuredOutputStatus.DISABLED - - # Check if model supports structured output feature - return ( - SupportStructuredOutputStatus.SUPPORTED - if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features) - else SupportStructuredOutputStatus.UNSUPPORTED - ) - def _save_multimodal_output_and_convert_result_to_markdown( self, contents: str | list[PromptMessageContentUnionTypes] | None, diff --git a/api/core/workflow/utils/structured_output/entities.py b/api/core/workflow/utils/structured_output/entities.py index 7954acbae..6491042bf 100644 --- a/api/core/workflow/utils/structured_output/entities.py +++ b/api/core/workflow/utils/structured_output/entities.py @@ -14,11 +14,3 @@ class SpecialModelType(StrEnum): GEMINI = "gemini" OLLAMA = "ollama" - - -class SupportStructuredOutputStatus(StrEnum): - """Constants for structured output support status""" - - SUPPORTED = "supported" - UNSUPPORTED = "unsupported" - DISABLED = "disabled" diff --git a/api/models/provider.py b/api/models/provider.py index 497cbefc6..1e25f0c90 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,6 +1,9 @@ +from datetime import datetime from enum import Enum +from typing import Optional -from sqlalchemy import func +from sqlalchemy import func, text +from sqlalchemy.orm import Mapped, mapped_column from .base import Base from .engine import db @@ -51,20 +54,24 @@ class Provider(Base): ), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying")) - encrypted_config = db.Column(db.Text, nullable=True) - is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - last_used = db.Column(db.DateTime, nullable=True) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + provider_type: Mapped[str] = mapped_column( + db.String(40), nullable=False, server_default=text("'custom'::character varying") + ) + encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) + is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) + last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) - quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying")) - quota_limit = db.Column(db.BigInteger, nullable=True) - quota_used = db.Column(db.BigInteger, default=0) + quota_type: Mapped[Optional[str]] = mapped_column( + db.String(40), nullable=True, server_default=text("''::character varying") + ) + quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True) + quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def __repr__(self): return ( @@ -104,15 +111,15 @@ class ProviderModel(Base): ), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - model_name = db.Column(db.String(255), nullable=False) - model_type = db.Column(db.String(40), nullable=False) - encrypted_config = db.Column(db.Text, nullable=True) - is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) + is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TenantDefaultModel(Base): @@ -122,13 +129,13 @@ class TenantDefaultModel(Base): db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - model_name = db.Column(db.String(255), nullable=False) - model_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TenantPreferredModelProvider(Base): @@ -138,12 +145,12 @@ class TenantPreferredModelProvider(Base): db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - preferred_provider_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + preferred_provider_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderOrder(Base): @@ -153,22 +160,24 @@ class ProviderOrder(Base): db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - account_id = db.Column(StringUUID, nullable=False) - payment_product_id = db.Column(db.String(191), nullable=False) - payment_id = db.Column(db.String(191)) - transaction_id = db.Column(db.String(191)) - quantity = db.Column(db.Integer, nullable=False, server_default=db.text("1")) - currency = db.Column(db.String(40)) - total_amount = db.Column(db.Integer) - payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying")) - paid_at = db.Column(db.DateTime) - pay_failed_at = db.Column(db.DateTime) - refunded_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + payment_product_id: Mapped[str] = mapped_column(db.String(191), nullable=False) + payment_id: Mapped[Optional[str]] = mapped_column(db.String(191)) + transaction_id: Mapped[Optional[str]] = mapped_column(db.String(191)) + quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1")) + currency: Mapped[Optional[str]] = mapped_column(db.String(40)) + total_amount: Mapped[Optional[int]] = mapped_column(db.Integer) + payment_status: Mapped[str] = mapped_column( + db.String(40), nullable=False, server_default=text("'wait_pay'::character varying") + ) + paid_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + pay_failed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + refunded_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderModelSetting(Base): @@ -182,15 +191,15 @@ class ProviderModelSetting(Base): db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - model_name = db.Column(db.String(255), nullable=False) - model_type = db.Column(db.String(40), nullable=False) - enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) + load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class LoadBalancingModelConfig(Base): @@ -204,13 +213,13 @@ class LoadBalancingModelConfig(Base): db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - model_name = db.Column(db.String(255), nullable=False) - model_type = db.Column(db.String(40), nullable=False) - name = db.Column(db.String(255), nullable=False) - encrypted_config = db.Column(db.Text, nullable=True) - enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + name: Mapped[str] = mapped_column(db.String(255), nullable=False) + encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) + enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 5fbee266b..6aa48b1cb 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -3,11 +3,16 @@ import os import time import uuid from collections.abc import Generator -from unittest.mock import MagicMock +from decimal import Decimal +from unittest.mock import MagicMock, patch import pytest +from app_factory import create_app +from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.model_runtime.entities.message_entities import AssistantPromptMessage from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.enums import SystemVariableKey @@ -19,13 +24,27 @@ from core.workflow.nodes.llm.node import LLMNode from extensions.ext_database import db from models.enums import UserFrom from models.workflow import WorkflowType -from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config """FOR MOCK FIXTURES, DO NOT REMOVE""" from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock +@pytest.fixture(scope="session") +def app(): + # Set up storage configuration + os.environ["STORAGE_TYPE"] = "opendal" + os.environ["OPENDAL_SCHEME"] = "fs" + os.environ["OPENDAL_FS_ROOT"] = "storage" + + # Ensure storage directory exists + os.makedirs("storage", exist_ok=True) + + app = create_app() + dify_config.LOGIN_DISABLED = True + return app + + def init_llm_node(config: dict) -> LLMNode: graph_config = { "edges": [ @@ -40,13 +59,19 @@ def init_llm_node(config: dict) -> LLMNode: graph = Graph.init(graph_config=graph_config) + # Use proper UUIDs for database compatibility + tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" + app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c" + workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d" + user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e" + init_params = GraphInitParams( - tenant_id="1", - app_id="1", + tenant_id=tenant_id, + app_id=app_id, workflow_type=WorkflowType.WORKFLOW, - workflow_id="1", + workflow_id=workflow_id, graph_config=graph_config, - user_id="1", + user_id=user_id, user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, call_depth=0, @@ -77,115 +102,197 @@ def init_llm_node(config: dict) -> LLMNode: return node -def test_execute_llm(setup_model_mock): - node = init_llm_node( - config={ - "id": "llm", - "data": { - "title": "123", - "type": "llm", - "model": { - "provider": "langgenius/openai/openai", - "name": "gpt-3.5-turbo", - "mode": "chat", - "completion_params": {}, +def test_execute_llm(app): + with app.app_context(): + node = init_llm_node( + config={ + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": { + "provider": "langgenius/openai/openai", + "name": "gpt-3.5-turbo", + "mode": "chat", + "completion_params": {}, + }, + "prompt_template": [ + { + "role": "system", + "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.", + }, + {"role": "user", "text": "{{#sys.query#}}"}, + ], + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, }, - "prompt_template": [ - {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."}, - {"role": "user", "text": "{{#sys.query#}}"}, - ], - "memory": None, - "context": {"enabled": False}, - "vision": {"enabled": False}, }, - }, - ) + ) - credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} + credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} - # Mock db.session.close() - db.session.close = MagicMock() + # Create a proper LLM result with real entities + mock_usage = LLMUsage( + prompt_tokens=30, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("1000"), + prompt_price=Decimal("0.00003"), + completion_tokens=20, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("1000"), + completion_price=Decimal("0.00004"), + total_tokens=50, + total_price=Decimal("0.00007"), + currency="USD", + latency=0.5, + ) - node._fetch_model_config = get_mocked_fetch_model_config( - provider="langgenius/openai/openai", - model="gpt-3.5-turbo", - mode="chat", - credentials=credentials, - ) + mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.") - # execute node - result = node._run() - assert isinstance(result, Generator) + mock_llm_result = LLMResult( + model="gpt-3.5-turbo", + prompt_messages=[], + message=mock_message, + usage=mock_usage, + ) - for item in result: - if isinstance(item, RunCompletedEvent): - assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.process_data is not None - assert item.run_result.outputs is not None - assert item.run_result.outputs.get("text") is not None - assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 + # Create a simple mock model instance that doesn't call real providers + mock_model_instance = MagicMock() + mock_model_instance.invoke_llm.return_value = mock_llm_result + + # Create a simple mock model config with required attributes + mock_model_config = MagicMock() + mock_model_config.mode = "chat" + mock_model_config.provider = "langgenius/openai/openai" + mock_model_config.model = "gpt-3.5-turbo" + mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" + + # Mock the _fetch_model_config method + def mock_fetch_model_config_func(_node_data_model): + return mock_model_instance, mock_model_config + + # Also mock ModelManager.get_model_instance to avoid database calls + def mock_get_model_instance(_self, **kwargs): + return mock_model_instance + + with ( + patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), + patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), + ): + # execute node + result = node._run() + assert isinstance(result, Generator) + + for item in result: + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.process_data is not None + assert item.run_result.outputs is not None + assert item.run_result.outputs.get("text") is not None + assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) -def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_model_mock): +def test_execute_llm_with_jinja2(app, setup_code_executor_mock): """ Test execute LLM node with jinja2 """ - node = init_llm_node( - config={ - "id": "llm", - "data": { - "title": "123", - "type": "llm", - "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, - "prompt_config": { - "jinja2_variables": [ - {"variable": "sys_query", "value_selector": ["sys", "query"]}, - {"variable": "output", "value_selector": ["abc", "output"]}, - ] + with app.app_context(): + node = init_llm_node( + config={ + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "prompt_config": { + "jinja2_variables": [ + {"variable": "sys_query", "value_selector": ["sys", "query"]}, + {"variable": "output", "value_selector": ["abc", "output"]}, + ] + }, + "prompt_template": [ + { + "role": "system", + "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", + "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", + "edition_type": "jinja2", + }, + { + "role": "user", + "text": "{{#sys.query#}}", + "jinja2_text": "{{sys_query}}", + "edition_type": "basic", + }, + ], + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, }, - "prompt_template": [ - { - "role": "system", - "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", - "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", - "edition_type": "jinja2", - }, - { - "role": "user", - "text": "{{#sys.query#}}", - "jinja2_text": "{{sys_query}}", - "edition_type": "basic", - }, - ], - "memory": None, - "context": {"enabled": False}, - "vision": {"enabled": False}, }, - }, - ) + ) - credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} + # Mock db.session.close() + db.session.close = MagicMock() - # Mock db.session.close() - db.session.close = MagicMock() + # Create a proper LLM result with real entities + mock_usage = LLMUsage( + prompt_tokens=30, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("1000"), + prompt_price=Decimal("0.00003"), + completion_tokens=20, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("1000"), + completion_price=Decimal("0.00004"), + total_tokens=50, + total_price=Decimal("0.00007"), + currency="USD", + latency=0.5, + ) - node._fetch_model_config = get_mocked_fetch_model_config( - provider="langgenius/openai/openai", - model="gpt-3.5-turbo", - mode="chat", - credentials=credentials, - ) + mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?") - # execute node - result = node._run() + mock_llm_result = LLMResult( + model="gpt-3.5-turbo", + prompt_messages=[], + message=mock_message, + usage=mock_usage, + ) - for item in result: - if isinstance(item, RunCompletedEvent): - assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.process_data is not None - assert "sunny" in json.dumps(item.run_result.process_data) - assert "what's the weather today?" in json.dumps(item.run_result.process_data) + # Create a simple mock model instance that doesn't call real providers + mock_model_instance = MagicMock() + mock_model_instance.invoke_llm.return_value = mock_llm_result + + # Create a simple mock model config with required attributes + mock_model_config = MagicMock() + mock_model_config.mode = "chat" + mock_model_config.provider = "openai" + mock_model_config.model = "gpt-3.5-turbo" + mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" + + # Mock the _fetch_model_config method + def mock_fetch_model_config_func(_node_data_model): + return mock_model_instance, mock_model_config + + # Also mock ModelManager.get_model_instance to avoid database calls + def mock_get_model_instance(_self, **kwargs): + return mock_model_instance + + with ( + patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), + patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), + ): + # execute node + result = node._run() + + for item in result: + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.process_data is not None + assert "sunny" in json.dumps(item.run_result.process_data) + assert "what's the weather today?" in json.dumps(item.run_result.process_data) def test_extract_json():