refactor: Improve model status handling and structured output (#20586)
Signed-off-by: -LAN- <laipz8200@outlook.com>
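Replace scattered model-status if/elif checks with a single ProviderModelWithStatusEntity.raise_for_status() helper and surface structured-output support as a support_structure_output property. Migrate ProviderManager lookups to SQLAlchemy 2.0-style select() in short-lived sessions, type the provider models with Mapped[...] columns, rename the LLM node's structured_output_enabled field to structured_output_switch_on behind a backward-compatible alias, and rework the LLM node integration tests to mock the model layer instead of calling real providers.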
@@ -55,6 +55,25 @@ class ProviderModelWithStatusEntity(ProviderModel):
     status: ModelStatus
     load_balancing_enabled: bool = False
 
+    def raise_for_status(self) -> None:
+        """
+        Check model status and raise ValueError if not active.
+
+        :raises ValueError: When model status is not active, with a descriptive message
+        """
+        if self.status == ModelStatus.ACTIVE:
+            return
+
+        error_messages = {
+            ModelStatus.NO_CONFIGURE: "Model is not configured",
+            ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded",
+            ModelStatus.NO_PERMISSION: "No permission to use this model",
+            ModelStatus.DISABLED: "Model is disabled",
+        }
+
+        if self.status in error_messages:
+            raise ValueError(error_messages[self.status])
+
 
 class ModelWithProviderEntity(ProviderModelWithStatusEntity):
     """
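The new raise_for_status() lets call sites replace their per-status if/elif ladders with a single call (see the _fetch_model_config hunk further down). A minimal, self-contained sketch of the contract, using stub classes rather than the Dify entities:

```python
# Sketch only: a stub mirroring ProviderModelWithStatusEntity's behavior.
from enum import Enum


class ModelStatus(Enum):
    ACTIVE = "active"
    NO_CONFIGURE = "no-configure"
    QUOTA_EXCEEDED = "quota-exceeded"
    NO_PERMISSION = "no-permission"
    DISABLED = "disabled"


class ProviderModelStub:
    def __init__(self, status: ModelStatus):
        self.status = status

    def raise_for_status(self) -> None:
        # Fast path: nothing to do for an active model.
        if self.status == ModelStatus.ACTIVE:
            return
        error_messages = {
            ModelStatus.NO_CONFIGURE: "Model is not configured",
            ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded",
            ModelStatus.NO_PERMISSION: "No permission to use this model",
            ModelStatus.DISABLED: "Model is disabled",
        }
        if self.status in error_messages:
            raise ValueError(error_messages[self.status])


try:
    ProviderModelStub(ModelStatus.QUOTA_EXCEEDED).raise_for_status()
except ValueError as e:
    print(e)  # -> Model quota has been exceeded
```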
@@ -41,45 +41,53 @@ class Extensible:
         extensions = []
         position_map: dict[str, int] = {}
 
-        # get the path of the current class
-        current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
-        current_dir_path = os.path.dirname(current_path)
+        # Get the package name from the module path
+        package_name = ".".join(cls.__module__.split(".")[:-1])
 
-        # traverse subdirectories
-        for subdir_name in os.listdir(current_dir_path):
-            if subdir_name.startswith("__"):
-                continue
+        try:
+            # Get package directory path
+            package_spec = importlib.util.find_spec(package_name)
+            if not package_spec or not package_spec.origin:
+                raise ImportError(f"Could not find package {package_name}")
 
-            subdir_path = os.path.join(current_dir_path, subdir_name)
-            extension_name = subdir_name
-            if os.path.isdir(subdir_path):
+            package_dir = os.path.dirname(package_spec.origin)
+
+            # Traverse subdirectories
+            for subdir_name in os.listdir(package_dir):
+                if subdir_name.startswith("__"):
+                    continue
+
+                subdir_path = os.path.join(package_dir, subdir_name)
+                if not os.path.isdir(subdir_path):
+                    continue
+
+                extension_name = subdir_name
                 file_names = os.listdir(subdir_path)
 
-                # is builtin extension, builtin extension
-                # in the front-end page and business logic, there are special treatments.
+                # Check for extension module file
+                if (extension_name + ".py") not in file_names:
+                    logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
+                    continue
+
+                # Check for builtin flag and position
                 builtin = False
-                # default position is 0 can not be None for sort_to_dict_by_position_map
                 position = 0
                 if "__builtin__" in file_names:
                     builtin = True
 
                     builtin_file_path = os.path.join(subdir_path, "__builtin__")
                     if os.path.exists(builtin_file_path):
                         position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
                     position_map[extension_name] = position
 
-                if (extension_name + ".py") not in file_names:
-                    logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
-                    continue
-
-                # Dynamic loading {subdir_name}.py file and find the subclass of Extensible
-                py_path = os.path.join(subdir_path, extension_name + ".py")
-                spec = importlib.util.spec_from_file_location(extension_name, py_path)
+                # Import the extension module
+                module_name = f"{package_name}.{extension_name}.{extension_name}"
+                spec = importlib.util.find_spec(module_name)
                 if not spec or not spec.loader:
-                    raise Exception(f"Failed to load module {extension_name} from {py_path}")
+                    raise ImportError(f"Failed to load module {module_name}")
                 mod = importlib.util.module_from_spec(spec)
                 spec.loader.exec_module(mod)
 
+                # Find extension class
                 extension_class = None
                 for name, obj in vars(mod).items():
                     if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
@@ -87,21 +95,21 @@ class Extensible:
                         break
 
                 if not extension_class:
-                    logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
+                    logging.warning(f"Missing subclass of {cls.__name__} in {module_name}, Skip.")
                     continue
 
+                # Load schema if not builtin
                 json_data: dict[str, Any] = {}
                 if not builtin:
-                    if "schema.json" not in file_names:
+                    json_path = os.path.join(subdir_path, "schema.json")
+                    if not os.path.exists(json_path):
                         logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
                         continue
 
-                    json_path = os.path.join(subdir_path, "schema.json")
-                    json_data = {}
-                    if os.path.exists(json_path):
-                        with open(json_path, encoding="utf-8") as f:
-                            json_data = json.load(f)
+                    with open(json_path, encoding="utf-8") as f:
+                        json_data = json.load(f)
 
+                # Create extension
                 extensions.append(
                     ModuleExtension(
                         extension_class=extension_class,
@@ -113,6 +121,11 @@ class Extensible:
                     )
                 )
 
+        except Exception as e:
+            logging.exception("Error scanning extensions")
+            raise
+
+        # Sort extensions by position
         sorted_extensions = sort_to_dict_by_position_map(
             position_map=position_map, data=extensions, name_func=lambda x: x.name
         )
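The loading strategy changes from path-based spec_from_file_location to name-based find_spec, so extensions resolve through the normal import system rather than through a hand-built file path. A short, runnable sketch of the new lookup; the stdlib module below stands in for the diff's f"{package_name}.{extension_name}.{extension_name}" name:

```python
# Sketch of name-based module resolution, as used in the refactor above.
import importlib.util

module_name = "unittest.mock"  # stand-in for package.extension.extension
spec = importlib.util.find_spec(module_name)
if not spec or not spec.loader:
    raise ImportError(f"Failed to load module {module_name}")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
print(mod.MagicMock)  # the loaded module behaves like a normal import
```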
@@ -160,6 +160,10 @@ class ProviderModel(BaseModel):
     deprecated: bool = False
     model_config = ConfigDict(protected_namespaces=())
 
+    @property
+    def support_structure_output(self) -> bool:
+        return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features
+
 
 class ParameterRule(BaseModel):
     """
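A stand-in sketch of the new property's contract: it is True only when the features list exists and contains STRUCTURED_OUTPUT. The stub below only mirrors the shape of ProviderModel, it is not the Dify class:

```python
# Sketch only: stub mirroring the support_structure_output property.
from enum import Enum
from typing import Optional

from pydantic import BaseModel


class ModelFeature(Enum):
    STRUCTURED_OUTPUT = "structured-output"


class ProviderModelStub(BaseModel):
    features: Optional[list[ModelFeature]] = None

    @property
    def support_structure_output(self) -> bool:
        return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features


print(ProviderModelStub().support_structure_output)  # False: no features list
print(ProviderModelStub(features=[ModelFeature.STRUCTURED_OUTPUT]).support_structure_output)  # True
```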
@@ -3,7 +3,9 @@ from collections import defaultdict
 from json import JSONDecodeError
 from typing import Any, Optional, cast
 
+from sqlalchemy import select
 from sqlalchemy.exc import IntegrityError
+from sqlalchemy.orm import Session
 
 from configs import dify_config
 from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
@@ -393,19 +395,13 @@ class ProviderManager:
 
     @staticmethod
     def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
-        """
-        Get all provider records of the workspace.
-
-        :param tenant_id: workspace id
-        :return:
-        """
-        providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all()
-
         provider_name_to_provider_records_dict = defaultdict(list)
-        for provider in providers:
-            # TODO: Use provider name with prefix after the data migration
-            provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
+        with Session(db.engine, expire_on_commit=False) as session:
+            stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
+            providers = session.scalars(stmt)
+            for provider in providers:
+                # Use provider name with prefix after the data migration
+                provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
         return provider_name_to_provider_records_dict
 
     @staticmethod
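This helper and the three that follow share one pattern: the legacy db.session.query(...).all() call becomes a 2.0-style select() executed in a short-lived Session, with expire_on_commit=False so the loaded ORM objects stay readable after the session closes. A generic sketch of that shape; the helper name and signature here are illustrative, not Dify's:

```python
# Sketch of the recurring Session/select pattern in the hunks around here.
from collections import defaultdict

from sqlalchemy import select
from sqlalchemy.orm import Session


def load_by_provider(engine, model_cls, tenant_id: str) -> dict:
    grouped = defaultdict(list)
    # expire_on_commit=False keeps attribute access valid after close.
    with Session(engine, expire_on_commit=False) as session:
        stmt = select(model_cls).where(model_cls.tenant_id == tenant_id)
        for record in session.scalars(stmt):
            grouped[record.provider_name].append(record)
    return grouped
```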
@@ -416,17 +412,12 @@ class ProviderManager:
         :param tenant_id: workspace id
         :return:
         """
-        # Get all provider model records of the workspace
-        provider_models = (
-            db.session.query(ProviderModel)
-            .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
-            .all()
-        )
-
         provider_name_to_provider_model_records_dict = defaultdict(list)
-        for provider_model in provider_models:
-            provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)
+        with Session(db.engine, expire_on_commit=False) as session:
+            stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
+            provider_models = session.scalars(stmt)
+            for provider_model in provider_models:
+                provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)
         return provider_name_to_provider_model_records_dict
 
     @staticmethod
@@ -437,17 +428,14 @@ class ProviderManager:
         :param tenant_id: workspace id
         :return:
         """
-        preferred_provider_types = (
-            db.session.query(TenantPreferredModelProvider)
-            .filter(TenantPreferredModelProvider.tenant_id == tenant_id)
-            .all()
-        )
-
-        provider_name_to_preferred_provider_type_records_dict = {
-            preferred_provider_type.provider_name: preferred_provider_type
-            for preferred_provider_type in preferred_provider_types
-        }
-
+        provider_name_to_preferred_provider_type_records_dict = {}
+        with Session(db.engine, expire_on_commit=False) as session:
+            stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id)
+            preferred_provider_types = session.scalars(stmt)
+            provider_name_to_preferred_provider_type_records_dict = {
+                preferred_provider_type.provider_name: preferred_provider_type
+                for preferred_provider_type in preferred_provider_types
+            }
         return provider_name_to_preferred_provider_type_records_dict
 
     @staticmethod
@@ -458,18 +446,14 @@ class ProviderManager:
         :param tenant_id: workspace id
         :return:
         """
-        provider_model_settings = (
-            db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all()
-        )
-
         provider_name_to_provider_model_settings_dict = defaultdict(list)
-        for provider_model_setting in provider_model_settings:
-            (
+        with Session(db.engine, expire_on_commit=False) as session:
+            stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id)
+            provider_model_settings = session.scalars(stmt)
+            for provider_model_setting in provider_model_settings:
                 provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append(
                     provider_model_setting
                 )
-            )
-
         return provider_name_to_provider_model_settings_dict
 
     @staticmethod
@@ -492,15 +476,14 @@ class ProviderManager:
         if not model_load_balancing_enabled:
             return {}
 
-        provider_load_balancing_configs = (
-            db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all()
-        )
-
         provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
-        for provider_load_balancing_config in provider_load_balancing_configs:
-            provider_name_to_provider_load_balancing_model_configs_dict[
-                provider_load_balancing_config.provider_name
-            ].append(provider_load_balancing_config)
+        with Session(db.engine, expire_on_commit=False) as session:
+            stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id)
+            provider_load_balancing_configs = session.scalars(stmt)
+            for provider_load_balancing_config in provider_load_balancing_configs:
+                provider_name_to_provider_load_balancing_model_configs_dict[
+                    provider_load_balancing_config.provider_name
+                ].append(provider_load_balancing_config)
 
         return provider_name_to_provider_load_balancing_model_configs_dict
 
@@ -626,10 +609,9 @@ class ProviderManager:
         if not cached_provider_credentials:
             try:
                 # fix origin data
-                if (
-                    custom_provider_record.encrypted_config
-                    and not custom_provider_record.encrypted_config.startswith("{")
-                ):
+                if custom_provider_record.encrypted_config is None:
+                    raise ValueError("No credentials found")
+                if not custom_provider_record.encrypted_config.startswith("{"):
                     provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
                 else:
                     provider_credentials = json.loads(custom_provider_record.encrypted_config)
@@ -733,7 +715,7 @@ class ProviderManager:
             return SystemConfiguration(enabled=False)
 
         # Convert provider_records to dict
-        quota_type_to_provider_records_dict = {}
+        quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {}
         for provider_record in provider_records:
             if provider_record.provider_type != ProviderType.SYSTEM.value:
                 continue
@@ -758,6 +740,11 @@ class ProviderManager:
             else:
                 provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type]
 
+                if provider_record.quota_used is None:
+                    raise ValueError("quota_used is None")
+                if provider_record.quota_limit is None:
+                    raise ValueError("quota_limit is None")
+
             quota_configuration = QuotaConfiguration(
                 quota_type=provider_quota.quota_type,
                 quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
@@ -791,10 +778,9 @@ class ProviderManager:
         cached_provider_credentials = provider_credentials_cache.get()
 
         if not cached_provider_credentials:
-            try:
-                provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config)
-            except JSONDecodeError:
-                provider_credentials = {}
+            provider_credentials: dict[str, Any] = {}
+            if provider_records and provider_records[0].encrypted_config:
+                provider_credentials = json.loads(provider_records[0].encrypted_config)
 
             # Get provider credential secret variables
             provider_credential_secret_variables = self._extract_secret_variables(
@@ -66,7 +66,8 @@ class LLMNodeData(BaseNodeData):
     context: ContextConfig
     vision: VisionConfig = Field(default_factory=VisionConfig)
     structured_output: dict | None = None
-    structured_output_enabled: bool = False
+    # We used 'structured_output_enabled' in the past, but it's not a good name.
+    structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
 
     @field_validator("prompt_config", mode="before")
     @classmethod
@@ -74,3 +75,7 @@ class LLMNodeData(BaseNodeData):
         if v is None:
             return PromptConfig()
         return v
+
+    @property
+    def structured_output_enabled(self) -> bool:
+        return self.structured_output_switch_on and self.structured_output is not None
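A sketch of why the Field alias keeps this rename backward compatible: Pydantic v2 still validates the old key, while code reads the new attribute and the property re-derives the old flag. The stub below mirrors the shape of LLMNodeData, it is not the real class:

```python
# Sketch only: alias-based rename with a derived back-compat property.
from pydantic import BaseModel, Field


class NodeDataStub(BaseModel):
    structured_output: dict | None = None
    # Old payloads send "structured_output_enabled"; the alias accepts it.
    structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")

    @property
    def structured_output_enabled(self) -> bool:
        # Enabled only when the switch is on AND a schema is configured.
        return self.structured_output_switch_on and self.structured_output is not None


data = NodeDataStub.model_validate(
    {"structured_output_enabled": True, "structured_output": {"type": "object"}}
)
print(data.structured_output_enabled)  # True
print(NodeDataStub.model_validate({"structured_output_enabled": True}).structured_output_enabled)  # False: no schema
```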
@@ -12,9 +12,7 @@ from sqlalchemy.orm import Session
 
 from configs import dify_config
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.entities.model_entities import ModelStatus
 from core.entities.provider_entities import QuotaUnit
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.file import FileType, file_manager
 from core.helper.code_executor import CodeExecutor, CodeLanguage
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -74,7 +72,6 @@ from core.workflow.nodes.event import (
 from core.workflow.utils.structured_output.entities import (
     ResponseFormat,
     SpecialModelType,
-    SupportStructuredOutputStatus,
 )
 from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
@@ -277,7 +274,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                         llm_usage=usage,
                     )
                 )
-        except LLMNodeError as e:
+        except ValueError as e:
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.FAILED,
@@ -527,65 +524,53 @@ class LLMNode(BaseNode[LLMNodeData]):
     def _fetch_model_config(
         self, node_data_model: ModelConfig
     ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
-        model_name = node_data_model.name
-        provider_name = node_data_model.provider
+        if not node_data_model.mode:
+            raise LLMModeRequiredError("LLM mode is required.")
 
-        model_manager = ModelManager()
-        model_instance = model_manager.get_model_instance(
-            tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
+        model = ModelManager().get_model_instance(
+            tenant_id=self.tenant_id,
+            model_type=ModelType.LLM,
+            provider=node_data_model.provider,
+            model=node_data_model.name,
         )
 
-        provider_model_bundle = model_instance.provider_model_bundle
-        model_type_instance = model_instance.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
-
-        model_credentials = model_instance.credentials
+        model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
 
         # check model
-        provider_model = provider_model_bundle.configuration.get_provider_model(
-            model=model_name, model_type=ModelType.LLM
+        provider_model = model.provider_model_bundle.configuration.get_provider_model(
+            model=node_data_model.name, model_type=ModelType.LLM
         )
 
         if provider_model is None:
-            raise ModelNotExistError(f"Model {model_name} not exist.")
-
-        if provider_model.status == ModelStatus.NO_CONFIGURE:
-            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-        elif provider_model.status == ModelStatus.NO_PERMISSION:
-            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
-        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
-            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
+            raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
+        provider_model.raise_for_status()
 
         # model config
-        completion_params = node_data_model.completion_params
-        stop = []
-        if "stop" in completion_params:
-            stop = completion_params["stop"]
-            del completion_params["stop"]
-
-        # get model mode
-        model_mode = node_data_model.mode
-        if not model_mode:
-            raise LLMModeRequiredError("LLM mode is required.")
-
-        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
+        stop: list[str] = []
+        if "stop" in node_data_model.completion_params:
+            stop = node_data_model.completion_params.pop("stop")
 
+        model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
         if not model_schema:
-            raise ModelNotExistError(f"Model {model_name} not exist.")
-        support_structured_output = self._check_model_structured_output_support()
-        if support_structured_output == SupportStructuredOutputStatus.SUPPORTED:
-            completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
-        elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
-            # Set appropriate response format based on model capabilities
-            self._set_response_format(completion_params, model_schema.parameter_rules)
-        return model_instance, ModelConfigWithCredentialsEntity(
-            provider=provider_name,
-            model=model_name,
+            raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
+        if self.node_data.structured_output_enabled:
+            if model_schema.support_structure_output:
+                node_data_model.completion_params = self._handle_native_json_schema(
+                    node_data_model.completion_params, model_schema.parameter_rules
+                )
+            else:
+                # Set appropriate response format based on model capabilities
+                self._set_response_format(node_data_model.completion_params, model_schema.parameter_rules)
+
+        return model, ModelConfigWithCredentialsEntity(
+            provider=node_data_model.provider,
+            model=node_data_model.name,
             model_schema=model_schema,
-            mode=model_mode,
-            provider_model_bundle=provider_model_bundle,
-            credentials=model_credentials,
-            parameters=completion_params,
+            mode=node_data_model.mode,
+            provider_model_bundle=model.provider_model_bundle,
+            credentials=model.credentials,
+            parameters=node_data_model.completion_params,
             stop=stop,
         )
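One detail worth noting in the new "stop" handling above: dict.pop both reads and removes the key in a single step, where the old code indexed and then deleted. With a default argument the membership test could go too; a tiny illustration:

```python
# Illustration of the pop-based "stop" extraction used above.
completion_params = {"temperature": 0.7, "stop": ["\n\n"]}
stop: list[str] = completion_params.pop("stop", [])
print(stop)               # ['\n\n']
print(completion_params)  # {'temperature': 0.7} -- no duplicate stop list left behind
```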
@@ -786,13 +771,25 @@ class LLMNode(BaseNode[LLMNodeData]):
                 "No prompt found in the LLM configuration. "
                 "Please ensure a prompt is properly configured before proceeding."
             )
-        support_structured_output = self._check_model_structured_output_support()
-        if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
-            filtered_prompt_messages = self._handle_prompt_based_schema(
-                prompt_messages=filtered_prompt_messages,
-            )
-        stop = model_config.stop
-        return filtered_prompt_messages, stop
+
+        model = ModelManager().get_model_instance(
+            tenant_id=self.tenant_id,
+            model_type=ModelType.LLM,
+            provider=self.node_data.model.provider,
+            model=self.node_data.model.name,
+        )
+        model_schema = model.model_type_instance.get_model_schema(
+            model=self.node_data.model.name,
+            credentials=model.credentials,
+        )
+        if not model_schema:
+            raise ModelNotExistError(f"Model {self.node_data.model.name} not exist.")
+        if self.node_data.structured_output_enabled:
+            if not model_schema.support_structure_output:
+                filtered_prompt_messages = self._handle_prompt_based_schema(
+                    prompt_messages=filtered_prompt_messages,
+                )
+        return filtered_prompt_messages, model_config.stop
 
     def _parse_structured_output(self, result_text: str) -> dict[str, Any]:
         structured_output: dict[str, Any] = {}
@@ -1185,32 +1182,6 @@ class LLMNode(BaseNode[LLMNodeData]):
         except json.JSONDecodeError:
             raise LLMNodeError("structured_output_schema is not valid JSON format")
 
-    def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus:
-        """
-        Check if the current model supports structured output.
-
-        Returns:
-            SupportStructuredOutput: The support status of structured output
-        """
-        # Early return if structured output is disabled
-        if (
-            not isinstance(self.node_data, LLMNodeData)
-            or not self.node_data.structured_output_enabled
-            or not self.node_data.structured_output
-        ):
-            return SupportStructuredOutputStatus.DISABLED
-        # Get model schema and check if it exists
-        model_schema = self._fetch_model_schema(self.node_data.model.provider)
-        if not model_schema:
-            return SupportStructuredOutputStatus.DISABLED
-
-        # Check if model supports structured output feature
-        return (
-            SupportStructuredOutputStatus.SUPPORTED
-            if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features)
-            else SupportStructuredOutputStatus.UNSUPPORTED
-        )
-
     def _save_multimodal_output_and_convert_result_to_markdown(
         self,
         contents: str | list[PromptMessageContentUnionTypes] | None,
@@ -14,11 +14,3 @@ class SpecialModelType(StrEnum):
 
     GEMINI = "gemini"
     OLLAMA = "ollama"
-
-
-class SupportStructuredOutputStatus(StrEnum):
-    """Constants for structured output support status"""
-
-    SUPPORTED = "supported"
-    UNSUPPORTED = "unsupported"
-    DISABLED = "disabled"
@@ -1,6 +1,9 @@
+from datetime import datetime
 from enum import Enum
+from typing import Optional
 
-from sqlalchemy import func
+from sqlalchemy import func, text
+from sqlalchemy.orm import Mapped, mapped_column
 
 from .base import Base
 from .engine import db
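Every model below gets the same mechanical migration: implicit db.Column attributes become typed SQLAlchemy 2.0 Mapped[...] declarations via mapped_column, with Optional[...] mirroring nullable=True. A self-contained sketch of the pattern, using a stub table rather than the real Provider model:

```python
# Sketch of the db.Column -> Mapped[...] migration pattern.
from datetime import datetime
from typing import Optional

from sqlalchemy import DateTime, String, func, text
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class ProviderStub(Base):  # stand-in, not models.provider.Provider
    __tablename__ = "provider_stub"

    id: Mapped[str] = mapped_column(String(36), primary_key=True, server_default=text("uuid_generate_v4()"))
    # nullable=True columns become Optional[...] in the annotation:
    encrypted_config: Mapped[Optional[str]] = mapped_column(String, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
```

The typed form gives the same schema as before, but static checkers and IDEs now know that, e.g., encrypted_config can be None, which is what makes the explicit None checks added in provider_manager.py verifiable.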
@@ -51,20 +54,24 @@ class Provider(Base):
         ),
     )
 
-    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
-    tenant_id = db.Column(StringUUID, nullable=False)
-    provider_name = db.Column(db.String(255), nullable=False)
-    provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying"))
-    encrypted_config = db.Column(db.Text, nullable=True)
-    is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
-    last_used = db.Column(db.DateTime, nullable=True)
+    id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
+    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+    provider_type: Mapped[str] = mapped_column(
+        db.String(40), nullable=False, server_default=text("'custom'::character varying")
+    )
+    encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
+    is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
+    last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
 
-    quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying"))
-    quota_limit = db.Column(db.BigInteger, nullable=True)
-    quota_used = db.Column(db.BigInteger, default=0)
+    quota_type: Mapped[Optional[str]] = mapped_column(
+        db.String(40), nullable=True, server_default=text("''::character varying")
+    )
+    quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True)
+    quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0)
 
-    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 
     def __repr__(self):
         return (
@@ -104,15 +111,15 @@ class ProviderModel(Base):
         ),
     )
 
-    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
-    tenant_id = db.Column(StringUUID, nullable=False)
-    provider_name = db.Column(db.String(255), nullable=False)
-    model_name = db.Column(db.String(255), nullable=False)
-    model_type = db.Column(db.String(40), nullable=False)
-    encrypted_config = db.Column(db.Text, nullable=True)
-    is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
-    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
+    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+    model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+    model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
+    encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
+    is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
+    created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 
 
 class TenantDefaultModel(Base):
@@ -122,13 +129,13 @@ class TenantDefaultModel(Base):
         db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"),
     )
 
-    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
-    tenant_id = db.Column(StringUUID, nullable=False)
-    provider_name = db.Column(db.String(255), nullable=False)
-    model_name = db.Column(db.String(255), nullable=False)
-    model_type = db.Column(db.String(40), nullable=False)
-    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
+    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+    model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+    model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
+    created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 
 
 class TenantPreferredModelProvider(Base):
@@ -138,12 +145,12 @@ class TenantPreferredModelProvider(Base):
         db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"),
     )
 
-    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
-    tenant_id = db.Column(StringUUID, nullable=False)
-    provider_name = db.Column(db.String(255), nullable=False)
-    preferred_provider_type = db.Column(db.String(40), nullable=False)
-    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
+    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+    preferred_provider_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
+    created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 
 
 class ProviderOrder(Base):
@@ -153,22 +160,24 @@ class ProviderOrder(Base):
         db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"),
     )
 
-    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
-    tenant_id = db.Column(StringUUID, nullable=False)
-    provider_name = db.Column(db.String(255), nullable=False)
-    account_id = db.Column(StringUUID, nullable=False)
-    payment_product_id = db.Column(db.String(191), nullable=False)
-    payment_id = db.Column(db.String(191))
-    transaction_id = db.Column(db.String(191))
-    quantity = db.Column(db.Integer, nullable=False, server_default=db.text("1"))
-    currency = db.Column(db.String(40))
-    total_amount = db.Column(db.Integer)
-    payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying"))
-    paid_at = db.Column(db.DateTime)
-    pay_failed_at = db.Column(db.DateTime)
-    refunded_at = db.Column(db.DateTime)
-    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
+    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+    account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    payment_product_id: Mapped[str] = mapped_column(db.String(191), nullable=False)
+    payment_id: Mapped[Optional[str]] = mapped_column(db.String(191))
+    transaction_id: Mapped[Optional[str]] = mapped_column(db.String(191))
+    quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1"))
+    currency: Mapped[Optional[str]] = mapped_column(db.String(40))
+    total_amount: Mapped[Optional[int]] = mapped_column(db.Integer)
+    payment_status: Mapped[str] = mapped_column(
+        db.String(40), nullable=False, server_default=text("'wait_pay'::character varying")
+    )
+    paid_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
+    pay_failed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
+    refunded_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
+    created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 
 
 class ProviderModelSetting(Base):
@@ -182,15 +191,15 @@ class ProviderModelSetting(Base):
         db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
     )
 
-    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
-    tenant_id = db.Column(StringUUID, nullable=False)
-    provider_name = db.Column(db.String(255), nullable=False)
-    model_name = db.Column(db.String(255), nullable=False)
-    model_type = db.Column(db.String(40), nullable=False)
-    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
-    load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
-    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
+    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+    model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+    model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
+    enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
+    load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
+    created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 
 
 class LoadBalancingModelConfig(Base):
@@ -204,13 +213,13 @@ class LoadBalancingModelConfig(Base):
         db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
     )
 
-    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
-    tenant_id = db.Column(StringUUID, nullable=False)
-    provider_name = db.Column(db.String(255), nullable=False)
-    model_name = db.Column(db.String(255), nullable=False)
-    model_type = db.Column(db.String(40), nullable=False)
-    name = db.Column(db.String(255), nullable=False)
-    encrypted_config = db.Column(db.Text, nullable=True)
-    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
-    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
+    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+    model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+    model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
+    name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+    encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
+    enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
+    created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@@ -3,11 +3,16 @@ import os
 import time
 import uuid
 from collections.abc import Generator
-from unittest.mock import MagicMock
+from decimal import Decimal
+from unittest.mock import MagicMock, patch
 
 import pytest
 
+from app_factory import create_app
+from configs import dify_config
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from core.model_runtime.entities.message_entities import AssistantPromptMessage
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
 from core.workflow.enums import SystemVariableKey
@@ -19,13 +24,27 @@ from core.workflow.nodes.llm.node import LLMNode
 from extensions.ext_database import db
 from models.enums import UserFrom
 from models.workflow import WorkflowType
-from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
 
 """FOR MOCK FIXTURES, DO NOT REMOVE"""
 from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
 from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
 
 
+@pytest.fixture(scope="session")
+def app():
+    # Set up storage configuration
+    os.environ["STORAGE_TYPE"] = "opendal"
+    os.environ["OPENDAL_SCHEME"] = "fs"
+    os.environ["OPENDAL_FS_ROOT"] = "storage"
+
+    # Ensure storage directory exists
+    os.makedirs("storage", exist_ok=True)
+
+    app = create_app()
+    dify_config.LOGIN_DISABLED = True
+    return app
+
+
 def init_llm_node(config: dict) -> LLMNode:
     graph_config = {
         "edges": [
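The tests below consume this fixture by wrapping their bodies in app.app_context(), since anything resolved through the Flask app (the db engine, config) needs an active context. A toy, runnable version of the pattern, using a plain Flask app rather than Dify's create_app:

```python
# Sketch of the session-scoped app fixture + app_context pattern.
import pytest
from flask import Flask, current_app


@pytest.fixture(scope="session")
def app():
    app = Flask(__name__)  # stand-in for create_app()
    app.config["LOGIN_DISABLED"] = True
    return app


def test_uses_app_context(app):
    with app.app_context():
        # current_app (and extensions bound to it) only resolve in a context.
        assert current_app.config["LOGIN_DISABLED"] is True
```

Scoping the fixture to the session means the app is built once and shared across all tests in the run, which matters here because create_app is expensive.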
@@ -40,13 +59,19 @@ def init_llm_node(config: dict) -> LLMNode:
 
     graph = Graph.init(graph_config=graph_config)
 
+    # Use proper UUIDs for database compatibility
+    tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
+    app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c"
+    workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d"
+    user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e"
+
     init_params = GraphInitParams(
-        tenant_id="1",
-        app_id="1",
+        tenant_id=tenant_id,
+        app_id=app_id,
         workflow_type=WorkflowType.WORKFLOW,
-        workflow_id="1",
+        workflow_id=workflow_id,
         graph_config=graph_config,
-        user_id="1",
+        user_id=user_id,
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.DEBUGGER,
         call_depth=0,
@@ -77,115 +102,197 @@ def init_llm_node(config: dict) -> LLMNode:
     return node
 
 
-def test_execute_llm(setup_model_mock):
-    node = init_llm_node(
-        config={
-            "id": "llm",
-            "data": {
-                "title": "123",
-                "type": "llm",
-                "model": {
-                    "provider": "langgenius/openai/openai",
-                    "name": "gpt-3.5-turbo",
-                    "mode": "chat",
-                    "completion_params": {},
-                },
-                "prompt_template": [
-                    {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."},
-                    {"role": "user", "text": "{{#sys.query#}}"},
-                ],
-                "memory": None,
-                "context": {"enabled": False},
-                "vision": {"enabled": False},
-            },
-        },
-    )
-
-    credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
-
-    # Mock db.session.close()
-    db.session.close = MagicMock()
-
-    node._fetch_model_config = get_mocked_fetch_model_config(
-        provider="langgenius/openai/openai",
-        model="gpt-3.5-turbo",
-        mode="chat",
-        credentials=credentials,
-    )
-
-    # execute node
-    result = node._run()
-    assert isinstance(result, Generator)
-
-    for item in result:
-        if isinstance(item, RunCompletedEvent):
-            assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
-            assert item.run_result.process_data is not None
-            assert item.run_result.outputs is not None
-            assert item.run_result.outputs.get("text") is not None
-            assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
+def test_execute_llm(app):
+    with app.app_context():
+        node = init_llm_node(
+            config={
+                "id": "llm",
+                "data": {
+                    "title": "123",
+                    "type": "llm",
+                    "model": {
+                        "provider": "langgenius/openai/openai",
+                        "name": "gpt-3.5-turbo",
+                        "mode": "chat",
+                        "completion_params": {},
+                    },
+                    "prompt_template": [
+                        {
+                            "role": "system",
+                            "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.",
+                        },
+                        {"role": "user", "text": "{{#sys.query#}}"},
+                    ],
+                    "memory": None,
+                    "context": {"enabled": False},
+                    "vision": {"enabled": False},
+                },
+            },
+        )
+
+        credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
+
+        # Create a proper LLM result with real entities
+        mock_usage = LLMUsage(
+            prompt_tokens=30,
+            prompt_unit_price=Decimal("0.001"),
+            prompt_price_unit=Decimal("1000"),
+            prompt_price=Decimal("0.00003"),
+            completion_tokens=20,
+            completion_unit_price=Decimal("0.002"),
+            completion_price_unit=Decimal("1000"),
+            completion_price=Decimal("0.00004"),
+            total_tokens=50,
+            total_price=Decimal("0.00007"),
+            currency="USD",
+            latency=0.5,
+        )
+
+        mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")
+
+        mock_llm_result = LLMResult(
+            model="gpt-3.5-turbo",
+            prompt_messages=[],
+            message=mock_message,
+            usage=mock_usage,
+        )
+
+        # Create a simple mock model instance that doesn't call real providers
+        mock_model_instance = MagicMock()
+        mock_model_instance.invoke_llm.return_value = mock_llm_result
+
+        # Create a simple mock model config with required attributes
+        mock_model_config = MagicMock()
+        mock_model_config.mode = "chat"
+        mock_model_config.provider = "langgenius/openai/openai"
+        mock_model_config.model = "gpt-3.5-turbo"
+        mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
+
+        # Mock the _fetch_model_config method
+        def mock_fetch_model_config_func(_node_data_model):
+            return mock_model_instance, mock_model_config
+
+        # Also mock ModelManager.get_model_instance to avoid database calls
+        def mock_get_model_instance(_self, **kwargs):
+            return mock_model_instance
+
+        with (
+            patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
+            patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
+        ):
+            # execute node
+            result = node._run()
+            assert isinstance(result, Generator)
+
+            for item in result:
+                if isinstance(item, RunCompletedEvent):
+                    assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+                    assert item.run_result.process_data is not None
+                    assert item.run_result.outputs is not None
+                    assert item.run_result.outputs.get("text") is not None
+                    assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
 
 
 @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
-def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_model_mock):
+def test_execute_llm_with_jinja2(app, setup_code_executor_mock):
     """
     Test execute LLM node with jinja2
     """
-    node = init_llm_node(
-        config={
-            "id": "llm",
-            "data": {
-                "title": "123",
-                "type": "llm",
-                "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
-                "prompt_config": {
-                    "jinja2_variables": [
-                        {"variable": "sys_query", "value_selector": ["sys", "query"]},
-                        {"variable": "output", "value_selector": ["abc", "output"]},
-                    ]
-                },
-                "prompt_template": [
-                    {
-                        "role": "system",
-                        "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
-                        "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
-                        "edition_type": "jinja2",
-                    },
-                    {
-                        "role": "user",
-                        "text": "{{#sys.query#}}",
-                        "jinja2_text": "{{sys_query}}",
-                        "edition_type": "basic",
-                    },
-                ],
-                "memory": None,
-                "context": {"enabled": False},
-                "vision": {"enabled": False},
-            },
-        },
-    )
-
-    credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
-
-    # Mock db.session.close()
-    db.session.close = MagicMock()
-
-    node._fetch_model_config = get_mocked_fetch_model_config(
-        provider="langgenius/openai/openai",
-        model="gpt-3.5-turbo",
-        mode="chat",
-        credentials=credentials,
-    )
-
-    # execute node
-    result = node._run()
-
-    for item in result:
-        if isinstance(item, RunCompletedEvent):
-            assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
-            assert item.run_result.process_data is not None
-            assert "sunny" in json.dumps(item.run_result.process_data)
-            assert "what's the weather today?" in json.dumps(item.run_result.process_data)
+    with app.app_context():
+        node = init_llm_node(
+            config={
+                "id": "llm",
+                "data": {
+                    "title": "123",
+                    "type": "llm",
+                    "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
+                    "prompt_config": {
+                        "jinja2_variables": [
+                            {"variable": "sys_query", "value_selector": ["sys", "query"]},
+                            {"variable": "output", "value_selector": ["abc", "output"]},
+                        ]
+                    },
+                    "prompt_template": [
+                        {
+                            "role": "system",
+                            "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
+                            "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
+                            "edition_type": "jinja2",
+                        },
+                        {
+                            "role": "user",
+                            "text": "{{#sys.query#}}",
+                            "jinja2_text": "{{sys_query}}",
+                            "edition_type": "basic",
+                        },
+                    ],
+                    "memory": None,
+                    "context": {"enabled": False},
+                    "vision": {"enabled": False},
+                },
+            },
+        )
+
+        # Mock db.session.close()
+        db.session.close = MagicMock()
+
+        # Create a proper LLM result with real entities
+        mock_usage = LLMUsage(
+            prompt_tokens=30,
+            prompt_unit_price=Decimal("0.001"),
+            prompt_price_unit=Decimal("1000"),
+            prompt_price=Decimal("0.00003"),
+            completion_tokens=20,
+            completion_unit_price=Decimal("0.002"),
+            completion_price_unit=Decimal("1000"),
+            completion_price=Decimal("0.00004"),
+            total_tokens=50,
+            total_price=Decimal("0.00007"),
+            currency="USD",
+            latency=0.5,
+        )
+
+        mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
+
+        mock_llm_result = LLMResult(
+            model="gpt-3.5-turbo",
+            prompt_messages=[],
+            message=mock_message,
+            usage=mock_usage,
+        )
+
+        # Create a simple mock model instance that doesn't call real providers
+        mock_model_instance = MagicMock()
+        mock_model_instance.invoke_llm.return_value = mock_llm_result
+
+        # Create a simple mock model config with required attributes
+        mock_model_config = MagicMock()
+        mock_model_config.mode = "chat"
+        mock_model_config.provider = "openai"
+        mock_model_config.model = "gpt-3.5-turbo"
+        mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
+
+        # Mock the _fetch_model_config method
+        def mock_fetch_model_config_func(_node_data_model):
+            return mock_model_instance, mock_model_config
+
+        # Also mock ModelManager.get_model_instance to avoid database calls
+        def mock_get_model_instance(_self, **kwargs):
+            return mock_model_instance
+
+        with (
+            patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
+            patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
+        ):
+            # execute node
+            result = node._run()
+
+            for item in result:
+                if isinstance(item, RunCompletedEvent):
+                    assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+                    assert item.run_result.process_data is not None
+                    assert "sunny" in json.dumps(item.run_result.process_data)
+                    assert "what's the weather today?" in json.dumps(item.run_result.process_data)
 
 
 def test_extract_json():
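The core of the new test strategy, reduced to its minimal runnable form: patch.object swaps the method on one object instance, while patch("pkg.module.Class.method") rewires every lookup made through the module path (as the tests do for ModelManager.get_model_instance), so the node under test never reaches a real provider or the database. Stand-in classes and a stdlib patch target below, not the Dify code:

```python
# Sketch of the instance-level vs module-path patching used in the tests.
import os
from unittest.mock import MagicMock, patch


class NodeStub:  # stand-in for LLMNode
    def _fetch_model_config(self, node_data_model):
        raise RuntimeError("would hit the database")


node = NodeStub()
mock_model_instance = MagicMock()
mock_model_config = MagicMock()

with (
    # patch.object replaces the attribute on this one instance only...
    patch.object(node, "_fetch_model_config", lambda _m: (mock_model_instance, mock_model_config)),
    # ...while patch("...") rewires lookups through the module path
    # (os.getcwd stands in for core.model_manager.ModelManager.get_model_instance).
    patch("os.getcwd", lambda: "/tmp"),
):
    model, config = node._fetch_model_config("node-data-model")
    assert model is mock_model_instance
    assert os.getcwd() == "/tmp"
```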