From 10b738a296cc7d27b39540f68b79e866e1dbd865 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Tue, 24 Jun 2025 09:05:29 +0800 Subject: [PATCH] feat: Persist Variables for Enhanced Debugging Workflow (#20699) This pull request introduces a feature aimed at improving the debugging experience during workflow editing. With the addition of variable persistence, the system will automatically retain the output variables from previously executed nodes. These persisted variables can then be reused when debugging subsequent nodes, eliminating the need for repetitive manual input. By streamlining this aspect of the workflow, the feature minimizes user errors and significantly reduces debugging effort, offering a smoother and more efficient experience. Key highlights of this change: - Automatic persistence of output variables for executed nodes. - Reuse of persisted variables to simplify input steps for nodes requiring them (e.g., `code`, `template`, `variable_assigner`). - Enhanced debugging experience with reduced friction. Closes #19735. --- .github/workflows/api-tests.yml | 6 + .github/workflows/vdb-tests.yml | 6 + api/.ruff.toml | 103 ++- api/controllers/console/__init__.py | 1 + api/controllers/console/app/workflow.py | 70 +- .../console/app/workflow_draft_variable.py | 421 +++++++++ api/controllers/console/app/wraps.py | 15 +- api/controllers/web/error.py | 10 + api/core/app/app_config/entities.py | 1 + .../app/apps/advanced_chat/app_generator.py | 27 + api/core/app/apps/advanced_chat/app_runner.py | 8 +- api/core/app/apps/agent_chat/app_generator.py | 5 + api/core/app/apps/chat/app_generator.py | 5 + .../common/workflow_response_converter.py | 16 +- api/core/app/apps/completion/app_generator.py | 5 + api/core/app/apps/workflow/app_generator.py | 28 +- api/core/app/apps/workflow/app_runner.py | 7 +- api/core/app/apps/workflow_app_runner.py | 57 +- api/core/app/entities/app_invoke_entities.py | 15 + api/core/file/constants.py | 10 + ...qlalchemy_workflow_execution_repository.py | 7 +- ...hemy_workflow_node_execution_repository.py | 16 +- api/core/variables/segments.py | 14 + api/core/variables/types.py | 14 + api/core/variables/utils.py | 18 + .../workflow/conversation_variable_updater.py | 39 + api/core/workflow/entities/variable_pool.py | 37 +- .../workflow/graph_engine/entities/event.py | 2 + .../workflow/graph_engine/graph_engine.py | 9 + api/core/workflow/nodes/answer/answer_node.py | 11 +- .../nodes/answer/answer_stream_processor.py | 2 + api/core/workflow/nodes/base/node.py | 51 +- api/core/workflow/nodes/code/code_node.py | 7 + .../workflow/nodes/document_extractor/node.py | 8 +- api/core/workflow/nodes/end/end_node.py | 4 + .../nodes/end/end_stream_processor.py | 1 + api/core/workflow/nodes/http_request/node.py | 13 +- .../workflow/nodes/if_else/if_else_node.py | 23 +- .../nodes/iteration/iteration_node.py | 18 +- .../nodes/iteration/iteration_start_node.py | 4 + .../knowledge_retrieval_node.py | 8 +- api/core/workflow/nodes/list_operator/node.py | 13 +- api/core/workflow/nodes/llm/file_saver.py | 3 - api/core/workflow/nodes/llm/node.py | 6 +- api/core/workflow/nodes/loop/loop_end_node.py | 4 + api/core/workflow/nodes/loop/loop_node.py | 11 + .../workflow/nodes/loop/loop_start_node.py | 4 + api/core/workflow/nodes/node_mapping.py | 5 + .../nodes/parameter_extractor/entities.py | 17 + .../parameter_extractor_node.py | 28 +- api/core/workflow/nodes/start/start_node.py | 7 +- .../template_transform_node.py | 4 + api/core/workflow/nodes/tool/tool_node.py | 9 +- 
.../variable_aggregator_node.py | 13 +- .../nodes/variable_assigner/common/helpers.py | 66 +- .../nodes/variable_assigner/common/impl.py | 38 + .../nodes/variable_assigner/v1/node.py | 74 +- .../nodes/variable_assigner/v2/entities.py | 6 + .../nodes/variable_assigner/v2/exc.py | 5 + .../nodes/variable_assigner/v2/node.py | 67 +- api/core/workflow/variable_loader.py | 79 ++ api/core/workflow/workflow_cycle_manager.py | 14 +- api/core/workflow/workflow_entry.py | 64 +- api/core/workflow/workflow_type_encoder.py | 49 + api/factories/file_factory.py | 75 ++ api/factories/variable_factory.py | 78 ++ api/libs/datetime_utils.py | 22 + api/libs/jsonutil.py | 11 + api/models/_workflow_exc.py | 20 + api/models/model.py | 15 + api/models/workflow.py | 238 ++++- api/pyproject.toml | 1 + api/services/app_dsl_service.py | 3 + api/services/errors/app.py | 4 + .../workflow_draft_variable_service.py | 721 +++++++++++++++ api/services/workflow_service.py | 242 ++++- api/tests/integration_tests/.env.example | 264 ++++-- api/tests/integration_tests/conftest.py | 88 +- .../controllers/app_fixture.py | 25 - .../controllers/console/__init__.py | 0 .../controllers/console/app/__init__.py | 0 .../app/test_workflow_draft_variable.py | 47 + .../controllers/test_controllers.py | 9 - .../integration_tests/factories/__init__.py | 0 .../factories/test_storage_key_loader.py | 371 ++++++++ .../integration_tests/services/__init__.py | 0 .../test_workflow_draft_variable_service.py | 501 ++++++++++ .../workflow/nodes/test_llm.py | 337 ++++--- .../app/workflow_draft_variables_test.py | 302 ++++++ .../core/app/segments/test_factory.py | 165 ---- api/tests/unit_tests/core/file/test_models.py | 25 + .../segments => variables}/test_segment.py | 0 .../segments => variables}/test_variables.py | 0 .../nodes/iteration/test_iteration.py | 13 +- .../nodes/test_document_extractor_node.py | 19 +- .../core/workflow/nodes/test_list_operator.py | 2 +- .../v1/test_variable_assigner_v1.py | 66 +- .../core/workflow/test_variable_pool.py | 39 + .../utils/test_variable_template_parser.py | 48 +- .../factories/test_variable_factory.py | 865 ++++++++++++++++++ .../unit_tests/libs/test_datetime_utils.py | 20 + api/tests/unit_tests/models/__init__.py | 0 api/tests/unit_tests/models/test_workflow.py | 151 ++- .../test_workflow_draft_variable_service.py | 222 +++++ api/uv.lock | 24 + .../base/markdown-blocks/code-block.tsx | 2 +- 106 files changed, 6025 insertions(+), 718 deletions(-) create mode 100644 api/controllers/console/app/workflow_draft_variable.py create mode 100644 api/core/workflow/conversation_variable_updater.py create mode 100644 api/core/workflow/nodes/variable_assigner/common/impl.py create mode 100644 api/core/workflow/variable_loader.py create mode 100644 api/core/workflow/workflow_type_encoder.py create mode 100644 api/libs/datetime_utils.py create mode 100644 api/libs/jsonutil.py create mode 100644 api/models/_workflow_exc.py create mode 100644 api/services/workflow_draft_variable_service.py delete mode 100644 api/tests/integration_tests/controllers/app_fixture.py create mode 100644 api/tests/integration_tests/controllers/console/__init__.py create mode 100644 api/tests/integration_tests/controllers/console/app/__init__.py create mode 100644 api/tests/integration_tests/controllers/console/app/test_workflow_draft_variable.py delete mode 100644 api/tests/integration_tests/controllers/test_controllers.py create mode 100644 api/tests/integration_tests/factories/__init__.py create mode 100644 
api/tests/integration_tests/factories/test_storage_key_loader.py create mode 100644 api/tests/integration_tests/services/__init__.py create mode 100644 api/tests/integration_tests/services/test_workflow_draft_variable_service.py create mode 100644 api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py delete mode 100644 api/tests/unit_tests/core/app/segments/test_factory.py create mode 100644 api/tests/unit_tests/core/file/test_models.py rename api/tests/unit_tests/core/{app/segments => variables}/test_segment.py (100%) rename api/tests/unit_tests/core/{app/segments => variables}/test_variables.py (100%) create mode 100644 api/tests/unit_tests/factories/test_variable_factory.py create mode 100644 api/tests/unit_tests/libs/test_datetime_utils.py create mode 100644 api/tests/unit_tests/models/__init__.py create mode 100644 api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index f08befefb..76e5c04de 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -83,9 +83,15 @@ jobs: compose-file: | docker/docker-compose.middleware.yaml services: | + db + redis sandbox ssrf_proxy + - name: setup test config + run: | + cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env + - name: Run Workflow run: uv run --project api bash dev/pytest/pytest_workflow.sh diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml index 7d0a873eb..912267094 100644 --- a/.github/workflows/vdb-tests.yml +++ b/.github/workflows/vdb-tests.yml @@ -84,6 +84,12 @@ jobs: elasticsearch oceanbase + - name: setup test config + run: | + echo $(pwd) + ls -lah . + cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env + - name: Check VDB Ready (TiDB) run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py diff --git a/api/.ruff.toml b/api/.ruff.toml index facb0d541..0169613bf 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -1,6 +1,4 @@ -exclude = [ - "migrations/*", -] +exclude = ["migrations/*"] line-length = 120 [format] @@ -9,14 +7,14 @@ quote-style = "double" [lint] preview = false select = [ - "B", # flake8-bugbear rules - "C4", # flake8-comprehensions - "E", # pycodestyle E rules - "F", # pyflakes rules - "FURB", # refurb rules - "I", # isort rules - "N", # pep8-naming - "PT", # flake8-pytest-style rules + "B", # flake8-bugbear rules + "C4", # flake8-comprehensions + "E", # pycodestyle E rules + "F", # pyflakes rules + "FURB", # refurb rules + "I", # isort rules + "N", # pep8-naming + "PT", # flake8-pytest-style rules "PLC0208", # iteration-over-set "PLC0414", # useless-import-alias "PLE0604", # invalid-all-object @@ -24,19 +22,19 @@ select = [ "PLR0402", # manual-from-import "PLR1711", # useless-return "PLR1714", # repeated-equality-comparison - "RUF013", # implicit-optional - "RUF019", # unnecessary-key-check - "RUF100", # unused-noqa - "RUF101", # redirected-noqa - "RUF200", # invalid-pyproject-toml - "RUF022", # unsorted-dunder-all - "S506", # unsafe-yaml-load - "SIM", # flake8-simplify rules - "TRY400", # error-instead-of-exception - "TRY401", # verbose-log-message - "UP", # pyupgrade rules - "W191", # tab-indentation - "W605", # invalid-escape-sequence + "RUF013", # implicit-optional + "RUF019", # unnecessary-key-check + "RUF100", # unused-noqa + "RUF101", # redirected-noqa + "RUF200", # invalid-pyproject-toml + "RUF022", # unsorted-dunder-all + "S506", 
# unsafe-yaml-load + "SIM", # flake8-simplify rules + "TRY400", # error-instead-of-exception + "TRY401", # verbose-log-message + "UP", # pyupgrade rules + "W191", # tab-indentation + "W605", # invalid-escape-sequence # security related linting rules # RCE protection (sort of) "S102", # exec-builtin, disallow use of `exec` @@ -47,36 +45,37 @@ select = [ ] ignore = [ - "E402", # module-import-not-at-top-of-file - "E711", # none-comparison - "E712", # true-false-comparison - "E721", # type-comparison - "E722", # bare-except - "F821", # undefined-name - "F841", # unused-variable "FURB113", # repeated-append "FURB152", # math-constant - "UP007", # non-pep604-annotation - "UP032", # f-string - "UP045", # non-pep604-annotation-optional - "B005", # strip-with-multi-characters - "B006", # mutable-argument-default - "B007", # unused-loop-control-variable - "B026", # star-arg-unpacking-after-keyword-arg - "B903", # class-as-data-structure - "B904", # raise-without-from-inside-except - "B905", # zip-without-explicit-strict - "N806", # non-lowercase-variable-in-function - "N815", # mixed-case-variable-in-class-scope - "PT011", # pytest-raises-too-broad - "SIM102", # collapsible-if - "SIM103", # needless-bool - "SIM105", # suppressible-exception - "SIM107", # return-in-try-except-finally - "SIM108", # if-else-block-instead-of-if-exp - "SIM113", # enumerate-for-loop - "SIM117", # multiple-with-statements - "SIM210", # if-expr-with-true-false + "UP007", # non-pep604-annotation + "UP032", # f-string + "UP045", # non-pep604-annotation-optional + "B005", # strip-with-multi-characters + "B006", # mutable-argument-default + "B007", # unused-loop-control-variable + "B026", # star-arg-unpacking-after-keyword-arg + "B903", # class-as-data-structure + "B904", # raise-without-from-inside-except + "B905", # zip-without-explicit-strict + "N806", # non-lowercase-variable-in-function + "N815", # mixed-case-variable-in-class-scope + "PT011", # pytest-raises-too-broad + "SIM102", # collapsible-if + "SIM103", # needless-bool + "SIM105", # suppressible-exception + "SIM107", # return-in-try-except-finally + "SIM108", # if-else-block-instead-of-if-exp + "SIM113", # enumerate-for-loop + "SIM117", # multiple-with-statements + "SIM210", # if-expr-with-true-false + "UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/ ] [lint.per-file-ignores] diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index a974c63e3..dbdcdc46c 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -63,6 +63,7 @@ from .app import ( statistic, workflow, workflow_app_log, + workflow_draft_variable, workflow_run, workflow_statistic, ) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index cbbdd324b..a9f088a27 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -1,5 +1,6 @@ import json import logging +from collections.abc import Sequence from typing import cast from flask import abort, request @@ -18,10 +19,12 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error
import InvokeRateLimitError as InvokeRateLimitHttpError +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom +from core.file.models import File from extensions.ext_database import db -from factories import variable_factory +from factories import file_factory, variable_factory from fields.workflow_fields import workflow_fields, workflow_pagination_fields from fields.workflow_run_fields import workflow_run_node_execution_fields from libs import helper @@ -30,6 +33,7 @@ from libs.login import current_user, login_required from models import App from models.account import Account from models.model import AppMode +from models.workflow import Workflow from services.app_generate_service import AppGenerateService from services.errors.app import WorkflowHashNotEqualError from services.errors.llm import InvokeRateLimitError @@ -38,6 +42,24 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE logger = logging.getLogger(__name__) +# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing +# at the controller level rather than in the workflow logic. This would improve separation +# of concerns and make the code more maintainable. +def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence[File]: + files = files or [] + + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) + file_objs: Sequence[File] = [] + if file_extra_config is None: + return file_objs + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=workflow.tenant_id, + config=file_extra_config, + ) + return file_objs + + class DraftWorkflowApi(Resource): @setup_required @login_required @@ -402,15 +424,30 @@ class DraftWorkflowNodeRunApi(Resource): parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("query", type=str, required=False, location="json", default="") + parser.add_argument("files", type=list, location="json", default=[]) args = parser.parse_args() - inputs = args.get("inputs") - if inputs == None: + user_inputs = args.get("inputs") + if user_inputs is None: raise ValueError("missing inputs") + workflow_srv = WorkflowService() + # fetch draft workflow by app_model + draft_workflow = workflow_srv.get_draft_workflow(app_model=app_model) + if not draft_workflow: + raise ValueError("Workflow not initialized") + files = _parse_file(draft_workflow, args.get("files")) workflow_service = WorkflowService() + workflow_node_execution = workflow_service.run_draft_workflow_node( - app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user + app_model=app_model, + draft_workflow=draft_workflow, + node_id=node_id, + user_inputs=user_inputs, + account=current_user, + query=args.get("query", ""), + files=files, ) return workflow_node_execution @@ -731,6 +768,27 @@ class WorkflowByIdApi(Resource): return None, 204 +class DraftWorkflowNodeLastRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @marshal_with(workflow_run_node_execution_fields) + def get(self, app_model: App, node_id: str): + srv = WorkflowService() + workflow = srv.get_draft_workflow(app_model) + if not workflow: + raise NotFound("Workflow not found") + node_exec = 
srv.get_node_last_run( + app_model=app_model, + workflow=workflow, + node_id=node_id, + ) + if node_exec is None: + raise NotFound("last run not found") + return node_exec + + api.add_resource( DraftWorkflowApi, "/apps//workflows/draft", @@ -795,3 +853,7 @@ api.add_resource( WorkflowByIdApi, "/apps//workflows/", ) +api.add_resource( + DraftWorkflowNodeLastRunApi, + "/apps//workflows/draft/nodes//last-run", +) diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py new file mode 100644 index 000000000..00d6fa3cb --- /dev/null +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -0,0 +1,421 @@ +import logging +from typing import Any, NoReturn + +from flask import Response +from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden + +from controllers.console import api +from controllers.console.app.error import ( + DraftWorkflowNotExist, +) +from controllers.console.app.wraps import get_app_model +from controllers.console.wraps import account_initialization_required, setup_required +from controllers.web.error import InvalidArgumentError, NotFoundError +from core.variables.segment_group import SegmentGroup +from core.variables.segments import ArrayFileSegment, FileSegment, Segment +from core.variables.types import SegmentType +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from factories.file_factory import build_from_mapping, build_from_mappings +from factories.variable_factory import build_segment_with_type +from libs.login import current_user, login_required +from models import App, AppMode, db +from models.workflow import WorkflowDraftVariable +from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService +from services.workflow_service import WorkflowService + +logger = logging.getLogger(__name__) + + +def _convert_values_to_json_serializable_object(value: Segment) -> Any: + if isinstance(value, FileSegment): + return value.value.model_dump() + elif isinstance(value, ArrayFileSegment): + return [i.model_dump() for i in value.value] + elif isinstance(value, SegmentGroup): + return [_convert_values_to_json_serializable_object(i) for i in value.value] + else: + return value.value + + +def _serialize_var_value(variable: WorkflowDraftVariable) -> Any: + value = variable.get_value() + # Create a copy of the value to avoid affecting the model cache. + value = value.model_copy(deep=True) + # Refresh the URL signature before returning it to the client.
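+    # Illustrative example (shape only, field values abbreviated): after the
+    # refresh, a FileSegment serializes to a mapping such as
+    #   {"dify_model_identity": "__dify__file__",
+    #    "remote_url": ".../files/{file_id}/file-preview?timestamp=...&nonce=...&sign=...", ...}
+    # so the client always receives a freshly signed `remote_url`.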
+ if isinstance(value, FileSegment): + file = value.value + file.remote_url = file.generate_url() + elif isinstance(value, ArrayFileSegment): + files = value.value + for file in files: + file.remote_url = file.generate_url() + return _convert_values_to_json_serializable_object(value) + + +def _create_pagination_parser(): + parser = reqparse.RequestParser() + parser.add_argument( + "page", + type=inputs.int_range(1, 100_000), + required=False, + default=1, + location="args", + help="the page of data requested", + ) + parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") + return parser + + +_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = { + "id": fields.String, + "type": fields.String(attribute=lambda model: model.get_variable_type()), + "name": fields.String, + "description": fields.String, + "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), + "value_type": fields.String, + "edited": fields.Boolean(attribute=lambda model: model.edited), + "visible": fields.Boolean, +} + +_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict( + _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, + value=fields.Raw(attribute=_serialize_var_value), +) + +_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = { + "id": fields.String, + "type": fields.String(attribute=lambda _: "env"), + "name": fields.String, + "description": fields.String, + "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), + "value_type": fields.String, + "edited": fields.Boolean(attribute=lambda model: model.edited), + "visible": fields.Boolean, +} + +_WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = { + "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)), +} + + +def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]: + return var_list.variables + + +_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = { + "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items), + "total": fields.Raw(), +} + +_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = { + "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items), +} + + +def _api_prerequisite(f): + """Common prerequisites for all draft workflow variable APIs. + + It ensures the following conditions are satisfied: + + - Dify has been properly set up. + - The request user has logged in and initialized. + - The requested app is a workflow or a chat flow. + - The request user has the edit permission for the app.
+ """ + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def wrapper(*args, **kwargs): + if not current_user.is_editor: + raise Forbidden() + return f(*args, **kwargs) + + return wrapper + + +class WorkflowVariableCollectionApi(Resource): + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) + def get(self, app_model: App): + """ + Get draft workflow + """ + parser = _create_pagination_parser() + args = parser.parse_args() + + # fetch draft workflow by app_model + workflow_service = WorkflowService() + workflow_exist = workflow_service.is_workflow_exist(app_model=app_model) + if not workflow_exist: + raise DraftWorkflowNotExist() + + # fetch draft workflow by app_model + with Session(bind=db.engine, expire_on_commit=False) as session: + draft_var_srv = WorkflowDraftVariableService( + session=session, + ) + workflow_vars = draft_var_srv.list_variables_without_values( + app_id=app_model.id, + page=args.page, + limit=args.limit, + ) + + return workflow_vars + + @_api_prerequisite + def delete(self, app_model: App): + draft_var_srv = WorkflowDraftVariableService( + session=db.session(), + ) + draft_var_srv.delete_workflow_variables(app_model.id) + db.session.commit() + return Response("", 204) + + +def validate_node_id(node_id: str) -> NoReturn | None: + if node_id in [ + CONVERSATION_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, + ]: + # NOTE(QuantumGhost): While we store the system and conversation variables as node variables + # with specific `node_id` in database, we still want to make the API separated. By disallowing + # accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`, + # we mitigate the risk that user of the API depending on the implementation detail of the API. 
+ # + # ref: [Hyrum's Law](https://www.hyrumslaw.com/) + + raise InvalidArgumentError( + f"invalid node_id, please use the corresponding API for conversation and system variables, node_id={node_id}", + ) + return None + + +class NodeVariableCollectionApi(Resource): + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + def get(self, app_model: App, node_id: str): + validate_node_id(node_id) + with Session(bind=db.engine, expire_on_commit=False) as session: + draft_var_srv = WorkflowDraftVariableService( + session=session, + ) + node_vars = draft_var_srv.list_node_variables(app_model.id, node_id) + + return node_vars + + @_api_prerequisite + def delete(self, app_model: App, node_id: str): + validate_node_id(node_id) + srv = WorkflowDraftVariableService(db.session()) + srv.delete_node_variables(app_model.id, node_id) + db.session.commit() + return Response("", 204) + + +class VariableApi(Resource): + _PATCH_NAME_FIELD = "name" + _PATCH_VALUE_FIELD = "value" + + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) + def get(self, app_model: App, variable_id: str): + draft_var_srv = WorkflowDraftVariableService( + session=db.session(), + ) + variable = draft_var_srv.get_variable(variable_id=variable_id) + if variable is None: + raise NotFoundError(description=f"variable not found, id={variable_id}") + if variable.app_id != app_model.id: + raise NotFoundError(description=f"variable not found, id={variable_id}") + return variable + + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) + def patch(self, app_model: App, variable_id: str): + # Request payload for file types: + # + # Local File: + # + # { + # "type": "image", + # "transfer_method": "local_file", + # "url": "", + # "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190" + # } + # + # Remote File: + # + # + # { + # "type": "image", + # "transfer_method": "remote_url", + # "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=", + # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4" + # } + + parser = reqparse.RequestParser() + parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json") + # Parse 'value' field as-is to maintain its original data structure + parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json") + + draft_var_srv = WorkflowDraftVariableService( + session=db.session(), + ) + args = parser.parse_args(strict=True) + + variable = draft_var_srv.get_variable(variable_id=variable_id) + if variable is None: + raise NotFoundError(description=f"variable not found, id={variable_id}") + if variable.app_id != app_model.id: + raise NotFoundError(description=f"variable not found, id={variable_id}") + + new_name = args.get(self._PATCH_NAME_FIELD, None) + raw_value = args.get(self._PATCH_VALUE_FIELD, None) + if new_name is None and raw_value is None: + return variable + + new_value = None + if raw_value is not None: + if variable.value_type == SegmentType.FILE: + if not isinstance(raw_value, dict): + raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") + raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id) + elif variable.value_type == SegmentType.ARRAY_FILE: + if not isinstance(raw_value, list): + raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") + if len(raw_value) > 0 and not isinstance(raw_value[0], dict): +
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") + raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id) + new_value = build_segment_with_type(variable.value_type, raw_value) + draft_var_srv.update_variable(variable, name=new_name, value=new_value) + db.session.commit() + return variable + + @_api_prerequisite + def delete(self, app_model: App, variable_id: str): + draft_var_srv = WorkflowDraftVariableService( + session=db.session(), + ) + variable = draft_var_srv.get_variable(variable_id=variable_id) + if variable is None: + raise NotFoundError(description=f"variable not found, id={variable_id}") + if variable.app_id != app_model.id: + raise NotFoundError(description=f"variable not found, id={variable_id}") + draft_var_srv.delete_variable(variable) + db.session.commit() + return Response("", 204) + + +class VariableResetApi(Resource): + @_api_prerequisite + def put(self, app_model: App, variable_id: str): + draft_var_srv = WorkflowDraftVariableService( + session=db.session(), + ) + + workflow_srv = WorkflowService() + draft_workflow = workflow_srv.get_draft_workflow(app_model) + if draft_workflow is None: + raise NotFoundError( + f"Draft workflow not found, app_id={app_model.id}", + ) + variable = draft_var_srv.get_variable(variable_id=variable_id) + if variable is None: + raise NotFoundError(description=f"variable not found, id={variable_id}") + if variable.app_id != app_model.id: + raise NotFoundError(description=f"variable not found, id={variable_id}") + + resetted = draft_var_srv.reset_variable(draft_workflow, variable) + db.session.commit() + if resetted is None: + return Response("", 204) + else: + return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS) + + +def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList: + with Session(bind=db.engine, expire_on_commit=False) as session: + draft_var_srv = WorkflowDraftVariableService( + session=session, + ) + if node_id == CONVERSATION_VARIABLE_NODE_ID: + draft_vars = draft_var_srv.list_conversation_variables(app_model.id) + elif node_id == SYSTEM_VARIABLE_NODE_ID: + draft_vars = draft_var_srv.list_system_variables(app_model.id) + else: + draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id) + return draft_vars + + +class ConversationVariableCollectionApi(Resource): + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + def get(self, app_model: App): + # NOTE(QuantumGhost): Prefill conversation variables into the draft variables table + # so their IDs can be returned to the caller. 
+ workflow_srv = WorkflowService() + draft_workflow = workflow_srv.get_draft_workflow(app_model) + if draft_workflow is None: + raise NotFoundError(description=f"draft workflow not found, id={app_model.id}") + draft_var_srv = WorkflowDraftVariableService(db.session()) + draft_var_srv.prefill_conversation_variable_default_values(draft_workflow) + db.session.commit() + return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID) + + +class SystemVariableCollectionApi(Resource): + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + def get(self, app_model: App): + return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID) + + +class EnvironmentVariableCollectionApi(Resource): + @_api_prerequisite + def get(self, app_model: App): + """ + Get environment variables of the draft workflow + """ + # fetch draft workflow by app_model + workflow_service = WorkflowService() + workflow = workflow_service.get_draft_workflow(app_model=app_model) + if workflow is None: + raise DraftWorkflowNotExist() + + env_vars = workflow.environment_variables + env_vars_list = [] + for v in env_vars: + env_vars_list.append( + { + "id": v.id, + "type": "env", + "name": v.name, + "description": v.description, + "selector": v.selector, + "value_type": v.value_type.value, + "value": v.value, + # Do not track edited for env vars. + "edited": False, + "visible": True, + "editable": True, + } + ) + + return {"items": env_vars_list} + + +api.add_resource( + WorkflowVariableCollectionApi, + "/apps//workflows/draft/variables", +) +api.add_resource(NodeVariableCollectionApi, "/apps//workflows/draft/nodes//variables") +api.add_resource(VariableApi, "/apps//workflows/draft/variables/") +api.add_resource(VariableResetApi, "/apps//workflows/draft/variables//reset") + +api.add_resource(ConversationVariableCollectionApi, "/apps//workflows/draft/conversation-variables") +api.add_resource(SystemVariableCollectionApi, "/apps//workflows/draft/system-variables") +api.add_resource(EnvironmentVariableCollectionApi, "/apps//workflows/draft/environment-variables") diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index 9ad8c1584..03b60610a 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -8,6 +8,15 @@ from libs.login import current_user from models import App, AppMode +def _load_app_model(app_id: str) -> Optional[App]: + app_model = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) + return app_model + + def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None): def decorator(view_func): @wraps(view_func) @@ -20,11 +29,7 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[ del kwargs["app_id"] - app_model = ( - db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") - .first() - ) + app_model = _load_app_model(app_id) if not app_model: raise AppNotFoundError() diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index 4371e679d..036e11d5c 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -139,3 +139,13 @@ class InvokeRateLimitError(BaseHTTPException): error_code = "rate_limit_error" description = "Rate Limit Error" code = 429 + + +class NotFoundError(BaseHTTPException): + error_code = "not_found" + code = 404 + + +class InvalidArgumentError(BaseHTTPException): + error_code =
"invalid_param" + code = 400 diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 3f31b1c3d..75bd2f677 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -104,6 +104,7 @@ class VariableEntity(BaseModel): Variable Entity. """ + # `variable` records the name of the variable in user inputs. variable: str label: str description: str = "" diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index a8848b953..afecd9997 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -29,6 +29,7 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts @@ -36,6 +37,7 @@ from models import Account, App, Conversation, EndUser, Message, Workflow, Workf from models.enums import WorkflowRunTriggeredFrom from services.conversation_service import ConversationService from services.errors.message import MessageNotExistsError +from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService logger = logging.getLogger(__name__) @@ -116,6 +118,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): ) # parse files + # TODO(QuantumGhost): Move file parsing logic to the API controller layer + # for better separation of concerns. + # + # For implementation reference, see the `_parse_file` function and + # `DraftWorkflowNodeRunApi` class which handle this properly. 
files = args["files"] if args.get("files") else [] file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) if file_extra_config: @@ -261,6 +268,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) + var_loader = DraftVarLoader( + engine=db.engine, + app_id=application_generate_entity.app_config.app_id, + tenant_id=application_generate_entity.app_config.tenant_id, + ) + draft_var_srv = WorkflowDraftVariableService(db.session()) + draft_var_srv.prefill_conversation_variable_default_values(workflow) return self._generate( workflow=workflow, @@ -271,6 +285,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, conversation=None, stream=streaming, + variable_loader=var_loader, ) def single_loop_generate( @@ -336,6 +351,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) + var_loader = DraftVarLoader( + engine=db.engine, + app_id=application_generate_entity.app_config.app_id, + tenant_id=application_generate_entity.app_config.tenant_id, + ) + draft_var_srv = WorkflowDraftVariableService(db.session()) + draft_var_srv.prefill_conversation_variable_default_values(workflow) return self._generate( workflow=workflow, @@ -346,6 +368,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, conversation=None, stream=streaming, + variable_loader=var_loader, ) def _generate( @@ -359,6 +382,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_node_execution_repository: WorkflowNodeExecutionRepository, conversation: Optional[Conversation] = None, stream: bool = True, + variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: """ Generate App response. @@ -410,6 +434,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): "conversation_id": conversation.id, "message_id": message.id, "context": context, + "variable_loader": variable_loader, }, ) @@ -438,6 +463,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation_id: str, message_id: str, context: contextvars.Context, + variable_loader: VariableLoader, ) -> None: """ Generate worker in a new thread. 
@@ -464,6 +490,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message, dialogue_count=self._dialogue_count, + variable_loader=variable_loader, ) runner.run() diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index d9b383386..840a3c9d3 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -19,6 +19,7 @@ from core.moderation.base import ModerationError from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey +from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.enums import UserFrom @@ -40,14 +41,17 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): conversation: Conversation, message: Message, dialogue_count: int, + variable_loader: VariableLoader, ) -> None: - super().__init__(queue_manager) - + super().__init__(queue_manager, variable_loader) self.application_generate_entity = application_generate_entity self.conversation = conversation self.message = message self._dialogue_count = dialogue_count + def _get_app_id(self) -> str: + return self.application_generate_entity.app_config.app_id + def run(self) -> None: app_config = self.application_generate_entity.app_config app_config = cast(AdvancedChatAppConfig, app_config) diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index a448bf8a9..75a0b0042 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -124,6 +124,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): override_model_config_dict["retriever_resource"] = {"enabled": True} # parse files + # TODO(QuantumGhost): Move file parsing logic to the API controller layer + # for better separation of concerns. + # + # For implementation reference, see the `_parse_file` function and + # `DraftWorkflowNodeRunApi` class which handle this properly. files = args.get("files") or [] file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index a1329cb93..76fae879f 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -115,6 +115,11 @@ class ChatAppGenerator(MessageBasedAppGenerator): override_model_config_dict["retriever_resource"] = {"enabled": True} # parse files + # TODO(QuantumGhost): Move file parsing logic to the API controller layer + # for better separation of concerns. + # + # For implementation reference, see the `_parse_file` function and + # `DraftWorkflowNodeRunApi` class which handle this properly. 
files = args["files"] if args.get("files") else [] file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 6f524a587..cd1d298ca 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -48,6 +48,7 @@ from core.workflow.entities.workflow_execution import WorkflowExecution from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import ( Account, CreatorUserRole, @@ -125,7 +126,7 @@ class WorkflowResponseConverter: id=workflow_execution.id_, workflow_id=workflow_execution.workflow_id, status=workflow_execution.status, - outputs=workflow_execution.outputs, + outputs=WorkflowRuntimeTypeConverter().to_json_encodable(workflow_execution.outputs), error=workflow_execution.error_message, elapsed_time=workflow_execution.elapsed_time, total_tokens=workflow_execution.total_tokens, @@ -202,6 +203,8 @@ class WorkflowResponseConverter: if not workflow_node_execution.finished_at: return None + json_converter = WorkflowRuntimeTypeConverter() + return NodeFinishStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_execution_id, @@ -214,7 +217,7 @@ class WorkflowResponseConverter: predecessor_node_id=workflow_node_execution.predecessor_node_id, inputs=workflow_node_execution.inputs, process_data=workflow_node_execution.process_data, - outputs=workflow_node_execution.outputs, + outputs=json_converter.to_json_encodable(workflow_node_execution.outputs), status=workflow_node_execution.status, error=workflow_node_execution.error, elapsed_time=workflow_node_execution.elapsed_time, @@ -245,6 +248,8 @@ class WorkflowResponseConverter: if not workflow_node_execution.finished_at: return None + json_converter = WorkflowRuntimeTypeConverter() + return NodeRetryStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_execution_id, @@ -257,7 +262,7 @@ class WorkflowResponseConverter: predecessor_node_id=workflow_node_execution.predecessor_node_id, inputs=workflow_node_execution.inputs, process_data=workflow_node_execution.process_data, - outputs=workflow_node_execution.outputs, + outputs=json_converter.to_json_encodable(workflow_node_execution.outputs), status=workflow_node_execution.status, error=workflow_node_execution.error, elapsed_time=workflow_node_execution.elapsed_time, @@ -376,6 +381,7 @@ class WorkflowResponseConverter: workflow_execution_id: str, event: QueueIterationCompletedEvent, ) -> IterationNodeCompletedStreamResponse: + json_converter = WorkflowRuntimeTypeConverter() return IterationNodeCompletedStreamResponse( task_id=task_id, workflow_run_id=workflow_execution_id, @@ -384,7 +390,7 @@ class WorkflowResponseConverter: node_id=event.node_id, node_type=event.node_type.value, title=event.node_data.title, - outputs=event.outputs, + outputs=json_converter.to_json_encodable(event.outputs), created_at=int(time.time()), extras={}, inputs=event.inputs or {}, @@ -463,7 +469,7 @@ class WorkflowResponseConverter: node_id=event.node_id, node_type=event.node_type.value, title=event.node_data.title, - outputs=event.outputs, + 
outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs), created_at=int(time.time()), extras={}, inputs=event.inputs or {}, diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index adcbaad3e..7bc4a0a5c 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -101,6 +101,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator): ) # parse files + # TODO(QuantumGhost): Move file parsing logic to the API controller layer + # for better separation of concerns. + # + # For implementation reference, see the `_parse_file` function and + # `DraftWorkflowNodeRunApi` class which handle this properly. files = args["files"] if args.get("files") else [] file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index fd15bd9f5..369fa0e48 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -27,11 +27,13 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.enums import WorkflowRunTriggeredFrom +from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService logger = logging.getLogger(__name__) @@ -94,6 +96,11 @@ class WorkflowAppGenerator(BaseAppGenerator): files: Sequence[Mapping[str, Any]] = args.get("files") or [] # parse files + # TODO(QuantumGhost): Move file parsing logic to the API controller layer + # for better separation of concerns. + # + # For implementation reference, see the `_parse_file` function and + # `DraftWorkflowNodeRunApi` class which handle this properly. file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) system_files = file_factory.build_from_mappings( mappings=files, @@ -186,6 +193,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_node_execution_repository: WorkflowNodeExecutionRepository, streaming: bool = True, workflow_thread_pool_id: Optional[str] = None, + variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: """ Generate App response. 
@@ -219,6 +227,7 @@ class WorkflowAppGenerator(BaseAppGenerator): "queue_manager": queue_manager, "context": context, "workflow_thread_pool_id": workflow_thread_pool_id, + "variable_loader": variable_loader, }, ) @@ -303,6 +312,13 @@ class WorkflowAppGenerator(BaseAppGenerator): app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) + draft_var_srv = WorkflowDraftVariableService(db.session()) + draft_var_srv.prefill_conversation_variable_default_values(workflow) + var_loader = DraftVarLoader( + engine=db.engine, + app_id=application_generate_entity.app_config.app_id, + tenant_id=application_generate_entity.app_config.tenant_id, + ) return self._generate( app_model=app_model, @@ -313,6 +329,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, + variable_loader=var_loader, ) def single_loop_generate( @@ -379,7 +396,13 @@ class WorkflowAppGenerator(BaseAppGenerator): app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) - + draft_var_srv = WorkflowDraftVariableService(db.session()) + draft_var_srv.prefill_conversation_variable_default_values(workflow) + var_loader = DraftVarLoader( + engine=db.engine, + app_id=application_generate_entity.app_config.app_id, + tenant_id=application_generate_entity.app_config.tenant_id, + ) return self._generate( app_model=app_model, workflow=workflow, @@ -389,6 +412,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, + variable_loader=var_loader, ) def _generate_worker( @@ -397,6 +421,7 @@ class WorkflowAppGenerator(BaseAppGenerator): application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager, context: contextvars.Context, + variable_loader: VariableLoader, workflow_thread_pool_id: Optional[str] = None, ) -> None: """ @@ -415,6 +440,7 @@ class WorkflowAppGenerator(BaseAppGenerator): application_generate_entity=application_generate_entity, queue_manager=queue_manager, workflow_thread_pool_id=workflow_thread_pool_id, + variable_loader=variable_loader, ) runner.run() diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index b59e34e22..07aeb57fa 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -12,6 +12,7 @@ from core.app.entities.app_invoke_entities import ( from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey +from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.enums import UserFrom @@ -30,6 +31,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager, + variable_loader: VariableLoader, workflow_thread_pool_id: Optional[str] = None, ) -> None: """ @@ -37,10 +39,13 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): :param queue_manager: application queue manager :param workflow_thread_pool_id: workflow thread pool id """ + super().__init__(queue_manager, variable_loader) 
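+        # `variable_loader` supplies output variables persisted from earlier node
+        # executions (see `DraftVarLoader`), letting single-step debug runs reuse
+        # those values instead of asking the user to re-enter them.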
self.application_generate_entity = application_generate_entity - self.queue_manager = queue_manager self.workflow_thread_pool_id = workflow_thread_pool_id + def _get_app_id(self) -> str: + return self.application_generate_entity.app_config.app_id + def run(self) -> None: """ Run application diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index facc24b4c..dc6c381e8 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -1,6 +1,8 @@ from collections.abc import Mapping from typing import Any, Optional, cast +from sqlalchemy.orm import Session + from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.queue_entities import ( @@ -33,6 +35,7 @@ from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.graph_engine.entities.event import ( AgentLogEvent, + BaseNodeEvent, GraphEngineEvent, GraphRunFailedEvent, GraphRunPartialSucceededEvent, @@ -62,15 +65,23 @@ from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes import NodeType from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.model import App from models.workflow import Workflow +from services.workflow_draft_variable_service import ( + DraftVariableSaver, +) class WorkflowBasedAppRunner(AppRunner): - def __init__(self, queue_manager: AppQueueManager): + def __init__(self, queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER) -> None: self.queue_manager = queue_manager + self._variable_loader = variable_loader + + def _get_app_id(self) -> str: + raise NotImplementedError("not implemented") def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph: """ @@ -173,6 +184,13 @@ class WorkflowBasedAppRunner(AppRunner): except NotImplementedError: variable_mapping = {} + load_into_variable_pool( + variable_loader=self._variable_loader, + variable_pool=variable_pool, + variable_mapping=variable_mapping, + user_inputs=user_inputs, + ) + WorkflowEntry.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, user_inputs=user_inputs, @@ -262,6 +280,12 @@ class WorkflowBasedAppRunner(AppRunner): ) except NotImplementedError: variable_mapping = {} + load_into_variable_pool( + self._variable_loader, + variable_pool=variable_pool, + variable_mapping=variable_mapping, + user_inputs=user_inputs, + ) WorkflowEntry.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, @@ -376,6 +400,8 @@ class WorkflowBasedAppRunner(AppRunner): in_loop_id=event.in_loop_id, ) ) + self._save_draft_var_for_event(event) + elif isinstance(event, NodeRunFailedEvent): self._publish_event( QueueNodeFailedEvent( @@ -438,6 +464,8 @@ class WorkflowBasedAppRunner(AppRunner): in_loop_id=event.in_loop_id, ) ) + self._save_draft_var_for_event(event) + elif isinstance(event, NodeInIterationFailedEvent): self._publish_event( QueueNodeInIterationFailedEvent( @@ -690,3 +718,30 @@ class WorkflowBasedAppRunner(AppRunner): def _publish_event(self, event: AppQueueEvent) -> None: self.queue_manager.publish(event, 
PublishFrom.APPLICATION_MANAGER) + + def _save_draft_var_for_event(self, event: BaseNodeEvent): + run_result = event.route_node_state.node_run_result + if run_result is None: + return + process_data = run_result.process_data + outputs = run_result.outputs + with Session(bind=db.engine) as session, session.begin(): + draft_var_saver = DraftVariableSaver( + session=session, + app_id=self._get_app_id(), + node_id=event.node_id, + node_type=event.node_type, + # FIXME(QuantumGhost): relying on private state of queue_manager is not ideal. + invoke_from=self.queue_manager._invoke_from, + node_execution_id=event.id, + enclosing_node_id=event.in_loop_id or event.in_iteration_id or None, + ) + draft_var_saver.save(process_data=process_data, outputs=outputs) + + +def _remove_first_element_from_variable_string(key: str) -> str: + """ + Remove the first element (the node ID prefix) from a dot-separated variable key. + """ + prefix, remaining = key.split(".", maxsplit=1) + return remaining diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index c0d99693b..65ed26795 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -17,9 +17,24 @@ class InvokeFrom(Enum): Invoke From. """ + # SERVICE_API indicates that this invocation is from an API call to a Dify app. + # + # Description of the service API in Dify docs: + # https://docs.dify.ai/en/guides/application-publishing/developing-with-apis SERVICE_API = "service-api" + + # WEB_APP indicates that this invocation is from + # the web app of the workflow (or chatflow). + # + # Description of the web app in Dify docs: + # https://docs.dify.ai/en/guides/application-publishing/launch-your-webapp-quickly/README WEB_APP = "web-app" + + # EXPLORE indicates that this invocation is from + # the workflow (or chatflow) explore page. EXPLORE = "explore" + # DEBUGGER indicates that this invocation is from + # the workflow (or chatflow) edit page. DEBUGGER = "debugger" @classmethod diff --git a/api/core/file/constants.py b/api/core/file/constants.py index ce1d238e9..0665ed7e0 100644 --- a/api/core/file/constants.py +++ b/api/core/file/constants.py @@ -1 +1,11 @@ +from typing import Any + +# TODO(QuantumGhost): Refactor variable type identification. Instead of directly +# comparing `dify_model_identity` with constants throughout the codebase, extract +# this logic into a dedicated function. This would encapsulate the implementation
+# details of how different variable types are identified.
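+# `maybe_file_object` below is a first step in that direction. For example
+# (illustrative): `maybe_file_object({"dify_model_identity": FILE_MODEL_IDENTITY})`
+# returns True, while any other dict, or a non-dict value, returns False.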
FILE_MODEL_IDENTITY = "__dify__file__" + + +def maybe_file_object(o: Any) -> bool: + return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index e30538742..cdec92aee 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -16,6 +16,7 @@ from core.workflow.entities.workflow_execution import ( WorkflowType, ) from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import ( Account, CreatorUserRole, @@ -152,7 +153,11 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): db_model.version = domain_model.workflow_version db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None - db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None + db_model.outputs = ( + json.dumps(WorkflowRuntimeTypeConverter().to_json_encodable(domain_model.outputs)) + if domain_model.outputs + else None + ) db_model.status = domain_model.status db_model.error = domain_model.error_message if domain_model.error_message else None db_model.total_tokens = domain_model.total_tokens diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 2f2744261..797cce935 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -19,6 +19,7 @@ from core.workflow.entities.workflow_node_execution import ( ) from core.workflow.nodes.enums import NodeType from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import ( Account, CreatorUserRole, @@ -146,6 +147,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) if not self._creator_user_role: raise ValueError("created_by_role is required in repository constructor") + json_converter = WorkflowRuntimeTypeConverter() db_model = WorkflowNodeExecutionModel() db_model.id = domain_model.id db_model.tenant_id = self._tenant_id @@ -160,9 +162,17 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) db_model.node_id = domain_model.node_id db_model.node_type = domain_model.node_type db_model.title = domain_model.title - db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None - db_model.process_data = json.dumps(domain_model.process_data) if domain_model.process_data else None - db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None + db_model.inputs = ( + json.dumps(json_converter.to_json_encodable(domain_model.inputs)) if domain_model.inputs else None + ) + db_model.process_data = ( + json.dumps(json_converter.to_json_encodable(domain_model.process_data)) + if domain_model.process_data + else None + ) + db_model.outputs = ( + json.dumps(json_converter.to_json_encodable(domain_model.outputs)) if domain_model.outputs else None + ) db_model.status = domain_model.status 
db_model.error = domain_model.error db_model.elapsed_time = domain_model.elapsed_time diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 64ba16c36..6cf09e037 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -75,6 +75,20 @@ class StringSegment(Segment): class FloatSegment(Segment): value_type: SegmentType = SegmentType.NUMBER value: float + # NOTE(QuantumGhost): it seems that equality for FloatSegment with `NaN` values has some problems. + # The following tests do not pass. + # + # def test_float_segment_and_nan(): + # nan = float("nan") + # assert nan != nan + # + # f1 = FloatSegment(value=float("nan")) + # f2 = FloatSegment(value=float("nan")) + # assert f1 != f2 + # + # f3 = FloatSegment(value=nan) + # f4 = FloatSegment(value=nan) + # assert f3 != f4 class IntegerSegment(Segment): diff --git a/api/core/variables/types.py b/api/core/variables/types.py index 4387e9693..68d3d8288 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -18,3 +18,17 @@ class SegmentType(StrEnum): NONE = "none" GROUP = "group" + + def is_array_type(self): + return self in _ARRAY_TYPES + + +_ARRAY_TYPES = frozenset( + [ + SegmentType.ARRAY_ANY, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_OBJECT, + SegmentType.ARRAY_FILE, + ] +) diff --git a/api/core/variables/utils.py b/api/core/variables/utils.py index e5d222af7..692db3502 100644 --- a/api/core/variables/utils.py +++ b/api/core/variables/utils.py @@ -1,8 +1,26 @@ +import json from collections.abc import Iterable, Sequence +from .segment_group import SegmentGroup +from .segments import ArrayFileSegment, FileSegment, Segment + def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]: selectors = [node_id, name] if paths: selectors.extend(paths) return selectors + + +class SegmentJSONEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, ArrayFileSegment): + return [v.model_dump() for v in o.value] + elif isinstance(o, FileSegment): + return o.value.model_dump() + elif isinstance(o, SegmentGroup): + return [self.default(seg) for seg in o.value] + elif isinstance(o, Segment): + return o.value + else: + return super().default(o) diff --git a/api/core/workflow/conversation_variable_updater.py b/api/core/workflow/conversation_variable_updater.py new file mode 100644 index 000000000..84e99bb58 --- /dev/null +++ b/api/core/workflow/conversation_variable_updater.py @@ -0,0 +1,39 @@ +import abc +from typing import Protocol + +from core.variables import Variable + + +class ConversationVariableUpdater(Protocol): + """ + ConversationVariableUpdater defines an abstraction for updating conversation variable values. + + It is intended for use by `v1.VariableAssignerNode` and `v2.VariableAssignerNode` when updating + conversation variables. + + Implementations may choose to batch updates. If batching is used, the `flush` method + should be implemented to persist buffered changes, and `update` + should handle buffering accordingly. + + Note: Since implementations may buffer updates, instances of ConversationVariableUpdater + are not thread-safe. Each VariableAssignerNode should create its own instance during execution. + """ + + @abc.abstractmethod + def update(self, conversation_id: str, variable: "Variable") -> None: + """ + Updates the value of the specified conversation variable in the underlying storage. + + :param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.conversation_id`.
+ :param variable: The `Variable` instance containing the updated value. + """ + pass + + @abc.abstractmethod + def flush(self): + """ + Flushes all pending updates to the underlying storage system. + + If the implementation does not buffer updates, this method can be a no-op. + """ + pass diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index af26864c0..80dda2632 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -7,12 +7,12 @@ from pydantic import BaseModel, Field from core.file import File, FileAttribute, file_manager from core.variables import Segment, SegmentGroup, Variable +from core.variables.consts import MIN_SELECTORS_LENGTH from core.variables.segments import FileSegment, NoneSegment +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.enums import SystemVariableKey from factories import variable_factory -from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from ..enums import SystemVariableKey - VariableValue = Union[str, int, float, dict, list, File] VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") @@ -30,9 +30,11 @@ class VariablePool(BaseModel): # TODO: These user inputs are not used by the pool. user_inputs: Mapping[str, Any] = Field( description="User inputs", + default_factory=dict, ) system_variables: Mapping[SystemVariableKey, Any] = Field( description="System variables", + default_factory=dict, ) environment_variables: Sequence[Variable] = Field( description="Environment variables.", default_factory=list, ) @@ -43,28 +45,7 @@ default_factory=list, ) - def __init__( - self, - *, - system_variables: Mapping[SystemVariableKey, Any] | None = None, - user_inputs: Mapping[str, Any] | None = None, - environment_variables: Sequence[Variable] | None = None, - conversation_variables: Sequence[Variable] | None = None, - **kwargs, - ): - environment_variables = environment_variables or [] - conversation_variables = conversation_variables or [] - user_inputs = user_inputs or {} - system_variables = system_variables or {} - - super().__init__( - system_variables=system_variables, - user_inputs=user_inputs, - environment_variables=environment_variables, - conversation_variables=conversation_variables, - **kwargs, - ) - + def model_post_init(self, context: Any, /) -> None: for key, value in self.system_variables.items(): self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) # Add environment variables to the variable pool @@ -91,12 +72,12 @@ class VariablePool(BaseModel): Returns: None """ - if len(selector) < 2: + if len(selector) < MIN_SELECTORS_LENGTH: raise ValueError("Invalid selector") if isinstance(value, Variable): variable = value - if isinstance(value, Segment): + elif isinstance(value, Segment): variable = variable_factory.segment_to_variable(segment=value, selector=selector) else: segment = variable_factory.build_segment(value) @@ -118,7 +99,7 @@ class VariablePool(BaseModel): Raises: ValueError: If the selector is invalid.
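Example (illustrative):

    pool.get(("node_id", "output")) returns the Segment previously added
    under that selector, or None if nothing has been stored for it.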
""" - if len(selector) < 2: + if len(selector) < MIN_SELECTORS_LENGTH: return None hash_key = hash(tuple(selector[1:])) diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 9a4939502..e57e9e4d6 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -66,6 +66,8 @@ class BaseNodeEvent(GraphEngineEvent): """iteration id if node is in iteration""" in_loop_id: Optional[str] = None """loop id if node is in loop""" + # The version of the node, or "1" if not specified. + node_version: str = "1" class NodeRunStartedEvent(BaseNodeEvent): diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index ee2164f22..2809afad0 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -314,6 +314,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + node_version=node_instance.version(), ) raise e @@ -627,6 +628,7 @@ class GraphEngine: parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, agent_strategy=agent_strategy, + node_version=node_instance.version(), ) max_retries = node_instance.node_data.retry_config.max_retries @@ -677,6 +679,7 @@ class GraphEngine: error=run_result.error or "Unknown error", retry_index=retries, start_at=retry_start_at, + node_version=node_instance.version(), ) time.sleep(retry_interval) break @@ -712,6 +715,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + node_version=node_instance.version(), ) should_continue_retry = False else: @@ -726,6 +730,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + node_version=node_instance.version(), ) should_continue_retry = False elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: @@ -786,6 +791,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + node_version=node_instance.version(), ) should_continue_retry = False @@ -803,6 +809,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + node_version=node_instance.version(), ) elif isinstance(event, RunRetrieverResourceEvent): yield NodeRunRetrieverResourceEvent( @@ -817,6 +824,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + node_version=node_instance.version(), ) except GenerateTaskStoppedError: # trigger node run failed event @@ -833,6 +841,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + node_version=node_instance.version(), ) return except Exception as e: diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index aa030870e..38c2bcbdf 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ 
b/api/core/workflow/nodes/answer/answer_node.py @@ -18,7 +18,11 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser class AnswerNode(BaseNode[AnswerNodeData]): _node_data_cls = AnswerNodeData - _node_type: NodeType = NodeType.ANSWER + _node_type = NodeType.ANSWER + + @classmethod + def version(cls) -> str: + return "1" def _run(self) -> NodeRunResult: """ @@ -45,7 +49,10 @@ class AnswerNode(BaseNode[AnswerNodeData]): part = cast(TextGenerateRouteChunk, part) answer += part.text - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer, "files": files}) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={"answer": answer, "files": ArrayFileSegment(value=files)}, + ) @classmethod def _extract_variable_selector_to_variable_mapping( diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index ba6ba16e3..f3e4a62ad 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -109,6 +109,7 @@ class AnswerStreamProcessor(StreamProcessor): parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, from_variable_selector=[answer_node_id, "answer"], + node_version=event.node_version, ) else: route_chunk = cast(VarGenerateRouteChunk, route_chunk) @@ -134,6 +135,7 @@ class AnswerStreamProcessor(StreamProcessor): route_node_state=event.route_node_state, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, + node_version=event.node_version, ) self.route_position[answer_node_id] += 1 diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 7da0c1974..697340142 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -1,7 +1,7 @@ import logging from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -23,7 +23,7 @@ GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData) class BaseNode(Generic[GenericNodeData]): _node_data_cls: type[GenericNodeData] - _node_type: NodeType + _node_type: ClassVar[NodeType] def __init__( self, @@ -90,8 +90,38 @@ class BaseNode(Generic[GenericNodeData]): graph_config: Mapping[str, Any], config: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping + """Extracts referenced variable selectors from the node configuration. + + The `config` parameter represents the configuration for a specific node type and corresponds + to the `data` field in the node definition object. + + The returned mapping has the following structure: + + {'1747829548239.#1747829667553.result#': ['1747829667553', 'result']} + + For loop and iteration nodes, the mapping may look like this: + + { + "1748332301644.input_selector": ["1748332363630", "result"], + "1748332325079.1748332325079.#sys.workflow_id#": ["sys", "workflow_id"], + } + + where `1748332301644` is the ID of the loop / iteration node, + and `1748332325079` is the ID of the node inside the loop or iteration node. + + Here, the key consists of two parts: the current node ID (provided as the `node_id` + parameter to `_extract_variable_selector_to_variable_mapping`) and the variable selector, + enclosed in `#` symbols. These two parts are separated by a dot (`.`). + + The value is a list of strings representing the variable selector, where the first element is the node ID + of the referenced variable, and the second element is the variable name within that node. + + The meaning of the first mapping above is: + + The node with ID `1747829548239` references the variable `result` from the node with + ID `1747829667553`. For example, if `1747829548239` is an LLM node, its prompt may contain a + reference to the `result` output variable of node `1747829667553`. + :param graph_config: graph config :param config: node config :return: @@ -101,9 +131,10 @@ class BaseNode(Generic[GenericNodeData]): raise ValueError("Node ID is required when extracting variable selector to variable mapping.") node_data = cls._node_data_cls(**config.get("data", {})) - return cls._extract_variable_selector_to_variable_mapping( + data = cls._extract_variable_selector_to_variable_mapping( graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data) ) + return data @classmethod def _extract_variable_selector_to_variable_mapping( @@ -139,6 +170,16 @@ class BaseNode(Generic[GenericNodeData]): """ return self._node_type + @classmethod + @abstractmethod + def version(cls) -> str: + """`version` returns the version of the current node type.""" + # NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`. + # + # If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING` + # in `api/core/workflow/nodes/__init__.py`. + raise NotImplementedError("subclasses of BaseNode must implement `version` method.") + @property def should_continue_on_error(self) -> bool: """judge if should continue on error diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 61c08a7d7..22ed9e265 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -40,6 +40,10 @@ class CodeNode(BaseNode[CodeNodeData]): return code_provider.get_default_config() + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: # Get code language code_language = self.node_data.code_language @@ -126,6 +130,9 @@ class CodeNode(BaseNode[CodeNodeData]): prefix: str = "", depth: int = 1, ): + # TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes. + # Note that `_transform_result` may produce lists containing `None` values, + # which don't conform to the type requirements of `Array*Segment` classes.
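# For example (illustrative): a code node may return {"result": [1, None, 2]},
# and `ArrayNumberSegment` cannot represent the `None` element, so the
# transformed result keeps a native list for now.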
if depth > dify_config.CODE_MAX_DEPTH: raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.") diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 429fed2d0..9f48b4886 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -24,7 +24,7 @@ from configs import dify_config from core.file import File, FileTransferMethod, file_manager from core.helper import ssrf_proxy from core.variables import ArrayFileSegment -from core.variables.segments import FileSegment +from core.variables.segments import ArrayStringSegment, FileSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode @@ -45,6 +45,10 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): _node_data_cls = DocumentExtractorNodeData _node_type = NodeType.DOCUMENT_EXTRACTOR + @classmethod + def version(cls) -> str: + return "1" + def _run(self): variable_selector = self.node_data.variable_selector variable = self.graph_runtime_state.variable_pool.get(variable_selector) @@ -67,7 +71,7 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, process_data=process_data, - outputs={"text": extracted_text_list}, + outputs={"text": ArrayStringSegment(value=extracted_text_list)}, ) elif isinstance(value, File): extracted_text = _extract_text_from_file(value) diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 0e9756b24..17a0b3ade 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -9,6 +9,10 @@ class EndNode(BaseNode[EndNodeData]): _node_data_cls = EndNodeData _node_type = NodeType.END + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: """ Run node diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py index 3ae5af713..a6fb2ffc1 100644 --- a/api/core/workflow/nodes/end/end_stream_processor.py +++ b/api/core/workflow/nodes/end/end_stream_processor.py @@ -139,6 +139,7 @@ class EndStreamProcessor(StreamProcessor): route_node_state=event.route_node_state, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, + node_version=event.node_version, ) self.route_position[end_node_id] += 1 diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 6b1ac57c0..971e0f73e 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -6,6 +6,7 @@ from typing import Any, Optional from configs import dify_config from core.file import File, FileTransferMethod from core.tools.tool_file_manager import ToolFileManager +from core.variables.segments import ArrayFileSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -60,6 +61,10 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): }, } + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: process_data = {} try: @@ -92,7 +97,7 @@ class 
HttpRequestNode(BaseNode[HttpRequestNodeData]): status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={ "status_code": response.status_code, - "body": response.text if not files else "", + "body": response.text if not files.value else "", "headers": response.headers, "files": files, }, @@ -166,7 +171,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): return mapping - def extract_files(self, url: str, response: Response) -> list[File]: + def extract_files(self, url: str, response: Response) -> ArrayFileSegment: """ Extract files from response by checking both Content-Type header and URL """ @@ -178,7 +183,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): content_disposition_type = None if not is_file: - return files + return ArrayFileSegment(value=[]) if parsed_content_disposition: content_disposition_filename = parsed_content_disposition.get_filename() @@ -211,4 +216,4 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): ) files.append(file) - return files + return ArrayFileSegment(value=files) diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 976922f75..22b748030 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,4 +1,5 @@ -from typing import Literal +from collections.abc import Mapping, Sequence +from typing import Any, Literal from typing_extensions import deprecated @@ -16,6 +17,10 @@ class IfElseNode(BaseNode[IfElseNodeData]): _node_data_cls = IfElseNodeData _node_type = NodeType.IF_ELSE + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: """ Run node @@ -87,6 +92,22 @@ class IfElseNode(BaseNode[IfElseNodeData]): return data + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: IfElseNodeData, + ) -> Mapping[str, Sequence[str]]: + var_mapping: dict[str, list[str]] = {} + for case in node_data.cases or []: + for condition in case.conditions: + key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector)) + var_mapping[key] = condition.variable_selector + + return var_mapping + @deprecated("This function is deprecated. You should use the new cases structure.") def _should_not_use_old_function( diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 42b6795fb..151efc28e 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -11,6 +11,7 @@ from flask import Flask, current_app from configs import dify_config from core.variables import ArrayVariable, IntegerVariable, NoneVariable +from core.variables.segments import ArrayAnySegment, ArraySegment from core.workflow.entities.node_entities import ( NodeRunResult, ) @@ -37,6 +38,7 @@ from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData +from factories.variable_factory import build_segment from libs.flask_utils import preserve_flask_contexts from .exc import ( @@ -72,6 +74,10 @@ class IterationNode(BaseNode[IterationNodeData]): }, } + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: """ Run the node. 
@@ -85,10 +91,17 @@ raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") if isinstance(variable, NoneVariable) or len(variable.value) == 0: + # Try our best to preserve the type information. + if isinstance(variable, ArraySegment): + output = variable.model_copy(update={"value": []}) + else: + output = ArrayAnySegment(value=[]) yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"output": []}, + # TODO(QuantumGhost): is it possible to compute the type of `output` + # from graph definition? + outputs={"output": output}, ) ) return @@ -231,6 +244,7 @@ class IterationNode(BaseNode[IterationNodeData]): # Flatten the list of lists if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs): outputs = [item for sublist in outputs for item in sublist] + output_segment = build_segment(outputs) yield IterationRunSucceededEvent( iteration_id=self.id, @@ -247,7 +261,7 @@ class IterationNode(BaseNode[IterationNodeData]): yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"output": outputs}, + outputs={"output": output_segment}, metadata={ WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index bee481ebd..9900aa225 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -13,6 +13,10 @@ class IterationStartNode(BaseNode[IterationStartNodeData]): _node_data_cls = IterationStartNodeData _node_type = NodeType.ITERATION_START + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: """ Run the node.
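The empty-iterator branch above tries to keep the declared element type: copying an existing `ArraySegment` subclass with an empty value preserves, say, an array[string] output, while `ArrayAnySegment` is only the untyped fallback. A minimal sketch of that behavior (illustrative; `empty_output_for` is a hypothetical helper, while the segment classes and the `model_copy` call are the ones used in the diff above):

from core.variables.segments import ArrayAnySegment, ArraySegment, ArrayStringSegment

def empty_output_for(variable) -> ArraySegment:
    # Keep the concrete segment type (and its value_type) when it is known.
    if isinstance(variable, ArraySegment):
        return variable.model_copy(update={"value": []})
    # Otherwise fall back to an untyped array segment.
    return ArrayAnySegment(value=[])

typed = empty_output_for(ArrayStringSegment(value=["a", "b"]))
assert isinstance(typed, ArrayStringSegment) and typed.value == []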
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 5cf5848d5..2995f0682 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -24,6 +24,7 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables import StringSegment +from core.variables.segments import ArrayObjectSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.enums import NodeType @@ -115,9 +116,12 @@ class KnowledgeRetrievalNode(LLMNode): # retrieve knowledge try: results = self._fetch_dataset_retriever(node_data=node_data, query=query) - outputs = {"result": results} + outputs = {"result": ArrayObjectSegment(value=results)} return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variables, + process_data=None, + outputs=outputs, # type: ignore ) except KnowledgeRetrievalNodeError as e: diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index e698d3f5d..3c9ba44cf 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -3,6 +3,7 @@ from typing import Any, Literal, Union from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment +from core.variables.segments import ArrayAnySegment, ArraySegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode @@ -16,6 +17,10 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): _node_data_cls = ListOperatorNodeData _node_type = NodeType.LIST_OPERATOR + @classmethod + def version(cls) -> str: + return "1" + def _run(self): inputs: dict[str, list] = {} process_data: dict[str, list] = {} @@ -30,7 +35,11 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): if not variable.value: inputs = {"variable": []} process_data = {"variable": []} - outputs = {"result": [], "first_record": None, "last_record": None} + if isinstance(variable, ArraySegment): + result = variable.model_copy(update={"value": []}) + else: + result = ArrayAnySegment(value=[]) + outputs = {"result": result, "first_record": None, "last_record": None} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, @@ -71,7 +80,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): variable = self._apply_slice(variable) outputs = { - "result": variable.value, + "result": variable, "first_record": variable.value[0] if variable.value else None, "last_record": variable.value[-1] if variable.value else None, } diff --git a/api/core/workflow/nodes/llm/file_saver.py b/api/core/workflow/nodes/llm/file_saver.py index c85baade0..a4b45ce65 100644 --- a/api/core/workflow/nodes/llm/file_saver.py +++ b/api/core/workflow/nodes/llm/file_saver.py @@ -119,9 +119,6 @@ class FileSaverImpl(LLMFileSaver): size=len(data), related_id=tool_file.id, url=url, - # TODO(QuantumGhost): how 
should I set the following key? - # What's the difference between `remote_url` and `url`? - # What's the purpose of `storage_key` and `dify_model_identity`? storage_key=tool_file.file_key, ) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index d27124d62..124ae6d75 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -138,6 +138,10 @@ class LLMNode(BaseNode[LLMNodeData]): ) self._llm_file_saver = llm_file_saver + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: def process_structured_output(text: str) -> Optional[dict[str, Any]]: """Process structured output if enabled""" @@ -255,7 +259,7 @@ class LLMNode(BaseNode[LLMNodeData]): if structured_output: outputs["structured_output"] = structured_output if self._file_outputs is not None: - outputs["files"] = self._file_outputs + outputs["files"] = ArrayFileSegment(value=self._file_outputs) yield RunCompletedEvent( run_result=NodeRunResult( diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/core/workflow/nodes/loop/loop_end_node.py index 327b9e234..b144021ba 100644 --- a/api/core/workflow/nodes/loop/loop_end_node.py +++ b/api/core/workflow/nodes/loop/loop_end_node.py @@ -13,6 +13,10 @@ class LoopEndNode(BaseNode[LoopEndNodeData]): _node_data_cls = LoopEndNodeData _node_type = NodeType.LOOP_END + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: """ Run the node. diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index fafa20538..368d662a7 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -54,6 +54,10 @@ class LoopNode(BaseNode[LoopNodeData]): _node_data_cls = LoopNodeData _node_type = NodeType.LOOP + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: """Run the node.""" # Get inputs @@ -482,6 +486,13 @@ class LoopNode(BaseNode[LoopNodeData]): variable_mapping.update(sub_node_variable_mapping) + for loop_variable in node_data.loop_variables or []: + if loop_variable.value_type == "variable": + assert loop_variable.value is not None, "Loop variable value must be provided for variable type" + # add loop variable to variable mapping + selector = loop_variable.value + variable_mapping[f"{node_id}.{loop_variable.label}"] = selector + # remove variable out from loop variable_mapping = { key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/core/workflow/nodes/loop/loop_start_node.py index 5a15f3604..f5e38b751 100644 --- a/api/core/workflow/nodes/loop/loop_start_node.py +++ b/api/core/workflow/nodes/loop/loop_start_node.py @@ -13,6 +13,10 @@ class LoopStartNode(BaseNode[LoopStartNodeData]): _node_data_cls = LoopStartNodeData _node_type = NodeType.LOOP_START + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: """ Run the node. diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py index 1f1be5954..67cc884f2 100644 --- a/api/core/workflow/nodes/node_mapping.py +++ b/api/core/workflow/nodes/node_mapping.py @@ -25,6 +25,11 @@ from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as Var LATEST_VERSION = "latest" +# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode. 
+# Specifically, if you have introduced new node types, you should add them here. +# +# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__` +# hook. Try to avoid duplication of node information. NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = { NodeType.START: { LATEST_VERSION: StartNode, diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 369eb13b0..916778d16 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -7,6 +7,10 @@ from core.workflow.nodes.base import BaseNodeData from core.workflow.nodes.llm import ModelConfig, VisionConfig +class _ParameterConfigError(Exception): + pass + + class ParameterConfig(BaseModel): """ Parameter Config. @@ -27,6 +31,19 @@ class ParameterConfig(BaseModel): raise ValueError("Invalid parameter name, __reason and __is_success are reserved") return str(value) + def is_array_type(self) -> bool: + return self.type in ("array[string]", "array[number]", "array[object]") + + def element_type(self) -> Literal["string", "number", "object"]: + if self.type == "array[number]": + return "number" + elif self.type == "array[string]": + return "string" + elif self.type == "array[object]": + return "object" + else: + raise _ParameterConfigError(f"{self.type} is not array type.") + class ParameterExtractorNodeData(BaseNodeData): """ diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 255278476..8d6c2d0a5 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -25,6 +25,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.variables.types import SegmentType from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -32,6 +33,7 @@ from core.workflow.nodes.base.node import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.llm import ModelConfig, llm_utils from core.workflow.utils import variable_template_parser +from factories.variable_factory import build_segment_with_type from .entities import ParameterExtractorNodeData from .exc import ( @@ -109,6 +111,10 @@ class ParameterExtractorNode(BaseNode): } } + @classmethod + def version(cls) -> str: + return "1" + def _run(self): """ Run the node. 
@@ -584,28 +590,30 @@ class ParameterExtractorNode(BaseNode): elif parameter.type in {"string", "select"}: if isinstance(result[parameter.name], str): transformed_result[parameter.name] = result[parameter.name] - elif parameter.type.startswith("array"): + elif parameter.is_array_type(): if isinstance(result[parameter.name], list): - nested_type = parameter.type[6:-1] - transformed_result[parameter.name] = [] + nested_type = parameter.element_type() + assert nested_type is not None + segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[]) + transformed_result[parameter.name] = segment_value for item in result[parameter.name]: if nested_type == "number": if isinstance(item, int | float): - transformed_result[parameter.name].append(item) + segment_value.value.append(item) elif isinstance(item, str): try: if "." in item: - transformed_result[parameter.name].append(float(item)) + segment_value.value.append(float(item)) else: - transformed_result[parameter.name].append(int(item)) + segment_value.value.append(int(item)) except ValueError: pass elif nested_type == "string": if isinstance(item, str): - transformed_result[parameter.name].append(item) + segment_value.value.append(item) elif nested_type == "object": if isinstance(item, dict): - transformed_result[parameter.name].append(item) + segment_value.value.append(item) if parameter.name not in transformed_result: if parameter.type == "number": @@ -615,7 +623,9 @@ class ParameterExtractorNode(BaseNode): elif parameter.type in {"string", "select"}: transformed_result[parameter.name] = "" elif parameter.type.startswith("array"): - transformed_result[parameter.name] = [] + transformed_result[parameter.name] = build_segment_with_type( + segment_type=SegmentType(parameter.type), value=[] + ) return transformed_result diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 8839aec9d..5ee9bc331 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -10,6 +10,10 @@ class StartNode(BaseNode[StartNodeData]): _node_data_cls = StartNodeData _node_type = NodeType.START + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) system_inputs = self.graph_runtime_state.variable_pool.system_variables @@ -18,5 +22,6 @@ class StartNode(BaseNode[StartNodeData]): # Set system variables as node outputs. for var in system_inputs: node_inputs[SYSTEM_VARIABLE_NODE_ID + "." 
+ var] = system_inputs[var] + outputs = dict(node_inputs) - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs) diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 476cf7eee..ba573074c 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -28,6 +28,10 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]): "config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"}, } + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: # Get variables variables = {} diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index aaecc7b98..aa15d6993 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -12,7 +12,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.errors import ToolInvokeError from core.tools.tool_engine import ToolEngine from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.variables.segments import ArrayAnySegment +from core.variables.segments import ArrayAnySegment, ArrayFileSegment from core.variables.variables import ArrayAnyVariable from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool @@ -44,6 +44,10 @@ class ToolNode(BaseNode[ToolNodeData]): _node_data_cls = ToolNodeData _node_type = NodeType.TOOL + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> Generator: """ Run the tool node @@ -300,6 +304,7 @@ class ToolNode(BaseNode[ToolNodeData]): variables[variable_name] = variable_value elif message.type == ToolInvokeMessage.MessageType.FILE: assert message.meta is not None + assert isinstance(message.meta["file"], File) files.append(message.meta["file"]) elif message.type == ToolInvokeMessage.MessageType.LOG: assert isinstance(message.message, ToolInvokeMessage.LogMessage) @@ -363,7 +368,7 @@ class ToolNode(BaseNode[ToolNodeData]): yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"text": text, "files": files, "json": json, **variables}, + outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json, **variables}, metadata={ **agent_execution_metadata, WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index db3e25b01..96bb3e793 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,3 +1,6 @@ +from collections.abc import Mapping + +from core.variables.segments import Segment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode @@ -9,16 +12,20 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): _node_data_cls = VariableAssignerNodeData _node_type = NodeType.VARIABLE_AGGREGATOR
+ @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: # Get variables - outputs = {} + outputs: dict[str, Segment | Mapping[str, Segment]] = {} inputs = {} if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled: for selector in self.node_data.variables: variable = self.graph_runtime_state.variable_pool.get(selector) if variable is not None: - outputs = {"output": variable.to_object()} + outputs = {"output": variable} inputs = {".".join(selector[1:]): variable.to_object()} break @@ -28,7 +35,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): variable = self.graph_runtime_state.variable_pool.get(selector) if variable is not None: - outputs[group.group_name] = {"output": variable.to_object()} + outputs[group.group_name] = {"output": variable} inputs[".".join(selector[1:])] = variable.to_object() break diff --git a/api/core/workflow/nodes/variable_assigner/common/helpers.py b/api/core/workflow/nodes/variable_assigner/common/helpers.py index 8031b57fa..0d2822233 100644 --- a/api/core/workflow/nodes/variable_assigner/common/helpers.py +++ b/api/core/workflow/nodes/variable_assigner/common/helpers.py @@ -1,19 +1,55 @@ -from sqlalchemy import select -from sqlalchemy.orm import Session +from collections.abc import Mapping, MutableMapping, Sequence +from typing import Any, TypeVar -from core.variables import Variable -from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from extensions.ext_database import db -from models import ConversationVariable +from pydantic import BaseModel + +from core.variables import Segment +from core.variables.consts import MIN_SELECTORS_LENGTH +from core.variables.types import SegmentType + +# Use double underscore (`__`) prefix for internal variables +# to minimize risk of collision with user-defined variable names. 
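# For example (illustrative), after a successful assignment a node's process_data
# may carry an entry like:
#   {"__updated_variables": [UpdatedVariable(name="answer", selector=["conversation", "answer"],
#                                            value_type=SegmentType.STRING, new_value="hello")]}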
+_UPDATED_VARIABLES_KEY = "__updated_variables" -def update_conversation_variable(conversation_id: str, variable: Variable): - stmt = select(ConversationVariable).where( - ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id +class UpdatedVariable(BaseModel): + name: str + selector: Sequence[str] + value_type: SegmentType + new_value: Any + + +_T = TypeVar("_T", bound=MutableMapping[str, Any]) + + +def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable: + if len(selector) < MIN_SELECTORS_LENGTH: + raise Exception("selector too short") + node_id, var_name = selector[:2] + return UpdatedVariable( + name=var_name, + selector=list(selector[:2]), + value_type=seg.value_type, + new_value=seg.value, ) - with Session(db.engine) as session: - row = session.scalar(stmt) - if not row: - raise VariableOperatorNodeError("conversation variable not found in the database") - row.data = variable.model_dump_json() - session.commit() + + +def set_updated_variables(m: _T, updates: Sequence[UpdatedVariable]) -> _T: + m[_UPDATED_VARIABLES_KEY] = updates + return m + + +def get_updated_variables(m: Mapping[str, Any]) -> Sequence[UpdatedVariable] | None: + updated_values = m.get(_UPDATED_VARIABLES_KEY, None) + if updated_values is None: + return None + result = [] + for items in updated_values: + if isinstance(items, UpdatedVariable): + result.append(items) + elif isinstance(items, dict): + items = UpdatedVariable.model_validate(items) + result.append(items) + else: + raise TypeError(f"Invalid updated variable: {items}, type={type(items)}") + return result diff --git a/api/core/workflow/nodes/variable_assigner/common/impl.py b/api/core/workflow/nodes/variable_assigner/common/impl.py new file mode 100644 index 000000000..8f7a44bb6 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/common/impl.py @@ -0,0 +1,38 @@ +from sqlalchemy import Engine, select +from sqlalchemy.orm import Session + +from core.variables.variables import Variable +from models.engine import db +from models.workflow import ConversationVariable + +from .exc import VariableOperatorNodeError + + +class ConversationVariableUpdaterImpl: + _engine: Engine | None + + def __init__(self, engine: Engine | None = None) -> None: + self._engine = engine + + def _get_engine(self) -> Engine: + if self._engine: + return self._engine + return db.engine + + def update(self, conversation_id: str, variable: Variable): + stmt = select(ConversationVariable).where( + ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id + ) + with Session(self._get_engine()) as session: + row = session.scalar(stmt) + if not row: + raise VariableOperatorNodeError("conversation variable not found in the database") + row.data = variable.model_dump_json() + session.commit() + + def flush(self): + pass + + +def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl: + return ConversationVariableUpdaterImpl() diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index 835e1d77b..be5083c9c 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -1,4 +1,9 @@ +from collections.abc import Callable, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Optional, TypeAlias + from core.variables import SegmentType, Variable +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID +from 
core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode @@ -7,16 +12,71 @@ from core.workflow.nodes.variable_assigner.common import helpers as common_helpe from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError from factories import variable_factory +from ..common.impl import conversation_variable_updater_factory from .node_data import VariableAssignerData, WriteMode +if TYPE_CHECKING: + from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState + + +_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] + class VariableAssignerNode(BaseNode[VariableAssignerData]): _node_data_cls = VariableAssignerData _node_type = NodeType.VARIABLE_ASSIGNER + _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph: "Graph", + graph_runtime_state: "GraphRuntimeState", + previous_node_id: Optional[str] = None, + thread_pool_id: Optional[str] = None, + conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph=graph, + graph_runtime_state=graph_runtime_state, + previous_node_id=previous_node_id, + thread_pool_id=thread_pool_id, + ) + self._conv_var_updater_factory = conv_var_updater_factory + + @classmethod + def version(cls) -> str: + return "1" + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: VariableAssignerData, + ) -> Mapping[str, Sequence[str]]: + mapping = {} + assigned_variable_node_id = node_data.assigned_variable_selector[0] + if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID: + selector_key = ".".join(node_data.assigned_variable_selector) + key = f"{node_id}.#{selector_key}#" + mapping[key] = node_data.assigned_variable_selector + + selector_key = ".".join(node_data.input_variable_selector) + key = f"{node_id}.#{selector_key}#" + mapping[key] = node_data.input_variable_selector + return mapping def _run(self) -> NodeRunResult: + assigned_variable_selector = self.node_data.assigned_variable_selector # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject - original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector) + original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) if not isinstance(original_variable, Variable): raise VariableOperatorNodeError("assigned variable not found") @@ -44,20 +104,28 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]): raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}") # Over write the variable. - self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable) + self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable) # TODO: Move database operation to the pipeline. # Update conversation variable. 
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) if not conversation_id: raise VariableOperatorNodeError("conversation_id not found") - common_helpers.update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) + conv_var_updater = self._conv_var_updater_factory() + conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable) + conv_var_updater.flush() + updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)] return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={ "value": income_value.to_object(), }, + # NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`, + # we still record the updated variables as a list to ensure the output schema is + # compatible with `v2.VariableAssignerNode`. + process_data=common_helpers.set_updated_variables({}, updated_variables), + outputs={}, ) diff --git a/api/core/workflow/nodes/variable_assigner/v2/entities.py b/api/core/workflow/nodes/variable_assigner/v2/entities.py index 01df33b6d..d93affcd1 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/entities.py +++ b/api/core/workflow/nodes/variable_assigner/v2/entities.py @@ -12,6 +12,12 @@ class VariableOperationItem(BaseModel): variable_selector: Sequence[str] input_type: InputType operation: Operation + # NOTE(QuantumGhost): The `value` field serves multiple purposes depending on context: + # + # 1. For CONSTANT input_type: Contains the literal value to be used in the operation. + # 2. For VARIABLE input_type: Initially contains the selector of the source variable. + # 3. During the variable updating procedure: The `value` field is reassigned to hold + # the resolved actual value that will be applied to the target variable.
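# For example (illustrative): with input_type=VARIABLE, `value` may start as the
# selector ["1747829667553", "result"] and is later reassigned to that variable's
# resolved value (e.g. "hello") before the operation is applied.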
value: Any | None = None diff --git a/api/core/workflow/nodes/variable_assigner/v2/exc.py b/api/core/workflow/nodes/variable_assigner/v2/exc.py index b67af6d73..fd6c304a9 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/exc.py +++ b/api/core/workflow/nodes/variable_assigner/v2/exc.py @@ -29,3 +29,8 @@ class InvalidInputValueError(VariableOperatorNodeError): class ConversationIDNotFoundError(VariableOperatorNodeError): def __init__(self): super().__init__("conversation_id not found") + + +class InvalidDataError(VariableOperatorNodeError): + def __init__(self, message: str) -> None: + super().__init__(message) diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index 8759a55b3..9292da6f1 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -1,34 +1,84 @@ import json -from collections.abc import Sequence -from typing import Any, cast +from collections.abc import Callable, Mapping, MutableMapping, Sequence +from typing import Any, TypeAlias, cast from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import SegmentType, Variable +from core.variables.consts import MIN_SELECTORS_LENGTH from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID +from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError +from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory from . import helpers from .constants import EMPTY_VALUE_MAPPING -from .entities import VariableAssignerNodeData +from .entities import VariableAssignerNodeData, VariableOperationItem from .enums import InputType, Operation from .exc import ( ConversationIDNotFoundError, InputTypeNotSupportedError, + InvalidDataError, InvalidInputValueError, OperationNotSupportedError, VariableNotFoundError, ) +_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] + + +def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): + selector_node_id = item.variable_selector[0] + if selector_node_id != CONVERSATION_VARIABLE_NODE_ID: + return + selector_str = ".".join(item.variable_selector) + key = f"{node_id}.#{selector_str}#" + mapping[key] = item.variable_selector + + +def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): + # Keep this in sync with the logic in _run methods... 
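# For example (illustrative): an item whose value is the selector
# ["1747829667553", "result"], inside node "1747829548239", contributes the entry
#   {"1747829548239.#1747829667553.result#": ["1747829667553", "result"]}.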
+ if item.input_type != InputType.VARIABLE: + return + selector = item.value + if not isinstance(selector, list): + raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}") + if len(selector) < MIN_SELECTORS_LENGTH: + raise InvalidDataError(f"selector too short, {node_id=}, {item=}") + selector_str = ".".join(selector) + key = f"{node_id}.#{selector_str}#" + mapping[key] = selector + class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): _node_data_cls = VariableAssignerNodeData _node_type = NodeType.VARIABLE_ASSIGNER + def _conv_var_updater_factory(self) -> ConversationVariableUpdater: + return conversation_variable_updater_factory() + + @classmethod + def version(cls) -> str: + return "2" + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: VariableAssignerNodeData, + ) -> Mapping[str, Sequence[str]]: + var_mapping: dict[str, Sequence[str]] = {} + for item in node_data.items: + _target_mapping_from_item(var_mapping, node_id, item) + _source_mapping_from_item(var_mapping, node_id, item) + return var_mapping + def _run(self) -> NodeRunResult: inputs = self.node_data.model_dump() process_data: dict[str, Any] = {} @@ -114,6 +164,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): # remove the duplicated items first. updated_variable_selectors = list(set(map(tuple, updated_variable_selectors))) + conv_var_updater = self._conv_var_updater_factory() # Update variables for selector in updated_variable_selectors: variable = self.graph_runtime_state.variable_pool.get(selector) @@ -128,15 +179,23 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): raise ConversationIDNotFoundError else: conversation_id = conversation_id.value - common_helpers.update_conversation_variable( + conv_var_updater.update( conversation_id=cast(str, conversation_id), variable=variable, ) + conv_var_updater.flush() + updated_variables = [ + common_helpers.variable_to_processed_data(selector, seg) + for selector in updated_variable_selectors + if (seg := self.graph_runtime_state.variable_pool.get(selector)) is not None + ] + process_data = common_helpers.set_updated_variables(process_data, updated_variables) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, process_data=process_data, + outputs={}, ) def _handle_item( diff --git a/api/core/workflow/variable_loader.py b/api/core/workflow/variable_loader.py new file mode 100644 index 000000000..4842ee00a --- /dev/null +++ b/api/core/workflow/variable_loader.py @@ -0,0 +1,79 @@ +import abc +from collections.abc import Mapping, Sequence +from typing import Any, Protocol + +from core.variables import Variable +from core.workflow.entities.variable_pool import VariablePool + + +class VariableLoader(Protocol): + """Interface for loading variables based on selectors. + + A `VariableLoader` is responsible for retrieving additional variables required during the execution + of a single node, which are not provided as user inputs. + + NOTE(QuantumGhost): Typically, all variables loaded by a `VariableLoader` should belong to the same + application and share the same `app_id`. However, this interface does not enforce that constraint, + and the `app_id` parameter is intentionally omitted from `load_variables` to achieve separation of + concern and allow for flexible implementations. + + Implementations of `VariableLoader` should almost always have an `app_id` parameter in + their constructor. 
+ + TODO(QuantumGhost): this is a temporary workaround. If we can move the creation of the node instance into + `WorkflowService.single_step_run`, we may get rid of this interface. + """ + + @abc.abstractmethod + def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + """Load variables based on the provided selectors. If the selectors are empty, + this method should return an empty list. + + The order of the returned variables is not guaranteed. If the caller wants to ensure + a specific order, they should sort the returned list themselves. + + :param selectors: a list of string lists, each inner list should have at least two elements: + - the first element is the node ID, + - the second element is the variable name. + :return: a list of Variable objects that match the provided selectors. + """ + pass + + +class _DummyVariableLoader(VariableLoader): + """A dummy implementation of VariableLoader that does not load any variables. + Serves as a placeholder when no variable loading is needed. + """ + + def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + return [] + + +DUMMY_VARIABLE_LOADER = _DummyVariableLoader() + + +def load_into_variable_pool( + variable_loader: VariableLoader, + variable_pool: VariablePool, + variable_mapping: Mapping[str, Sequence[str]], + user_inputs: Mapping[str, Any], +): + # Load missing variables from the draft variable storage and set them into + # the variable pool. + variables_to_load: list[list[str]] = [] + for key, selector in variable_mapping.items(): + # NOTE(QuantumGhost): this logic needs to be in sync with + # `WorkflowEntry.mapping_user_inputs_to_variable_pool`. + node_variable_list = key.split(".") + if len(node_variable_list) < 1: + raise ValueError(f"Invalid variable key: {key}. It should have at least one element.") + if key in user_inputs: + continue + node_variable_key = ".".join(node_variable_list[1:]) + if node_variable_key in user_inputs: + continue + if variable_pool.get(selector) is None: + variables_to_load.append(list(selector)) + loaded = variable_loader.load_variables(variables_to_load) + for var in loaded: + variable_pool.add(var.selector, var) diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index b88f9edd0..6ee562fc8 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -92,7 +92,7 @@ class WorkflowCycleManager: ) -> WorkflowExecution: workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) - outputs = WorkflowEntry.handle_special_values(outputs) + # outputs = WorkflowEntry.handle_special_values(outputs) workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED workflow_execution.outputs = outputs or {} @@ -125,7 +125,7 @@ class WorkflowCycleManager: trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowExecution: execution = self._get_workflow_execution_or_raise_error(workflow_run_id) - outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) + # outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED execution.outputs = outputs or {} @@ -242,9 +242,9 @@ class WorkflowCycleManager: raise ValueError(f"Domain node execution not found: {event.node_execution_id}") # Process data - inputs = WorkflowEntry.handle_special_values(event.inputs) - process_data = WorkflowEntry.handle_special_values(event.process_data) - outputs = 
WorkflowEntry.handle_special_values(event.outputs) + inputs = event.inputs + process_data = event.process_data + outputs = event.outputs # Convert metadata keys to strings execution_metadata_dict = {} @@ -289,7 +289,7 @@ class WorkflowCycleManager: # Process data inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) - outputs = WorkflowEntry.handle_special_values(event.outputs) + outputs = event.outputs # Convert metadata keys to strings execution_metadata_dict = {} @@ -326,7 +326,7 @@ class WorkflowCycleManager: finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - created_at).total_seconds() inputs = WorkflowEntry.handle_special_values(event.inputs) - outputs = WorkflowEntry.handle_special_values(event.outputs) + outputs = event.outputs # Convert metadata keys to strings origin_metadata = { diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 7648947fc..182c54fa7 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -21,6 +21,7 @@ from core.workflow.nodes import NodeType from core.workflow.nodes.base import BaseNode from core.workflow.nodes.event import NodeEvent from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from factories import file_factory from models.enums import UserFrom from models.workflow import ( @@ -119,7 +120,9 @@ class WorkflowEntry: workflow: Workflow, node_id: str, user_id: str, - user_inputs: dict, + user_inputs: Mapping[str, Any], + variable_pool: VariablePool, + variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, ) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]: """ Single step run workflow node :param workflow: Workflow instance :param node_id: node id :param user_id: user id :param user_inputs: user inputs :param variable_pool: variable pool used for this run :param variable_loader: loader used to fetch variables that are missing from the pool :return: """ - # fetch node info from workflow graph - workflow_graph = workflow.graph_dict - if not workflow_graph: - raise ValueError("workflow graph not found") - - nodes = workflow_graph.get("nodes") - if not nodes: - raise ValueError("nodes not found in workflow graph") - - # fetch node config from node id - try: - node_config = next(filter(lambda node: node["id"] == node_id, nodes)) - except StopIteration: - raise ValueError("node id not found in workflow graph") + node_config = workflow.get_node_config_by_id(node_id) + node_config_data = node_config.get("data", {}) # Get node class - node_type = NodeType(node_config.get("data", {}).get("type")) - node_version = node_config.get("data", {}).get("version", "1") + node_type = NodeType(node_config_data.get("type")) + node_version = node_config_data.get("version", "1") node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] - # init variable pool - variable_pool = VariablePool(environment_variables=workflow.environment_variables) - # init graph graph = Graph.init(graph_config=workflow.graph_dict) @@ -182,16 +170,33 @@ class WorkflowEntry: except NotImplementedError: variable_mapping = {} + # Load missing variables from the draft variable storage and set them into + # the variable pool. 
+ load_into_variable_pool( + variable_loader=variable_loader, + variable_pool=variable_pool, + variable_mapping=variable_mapping, + user_inputs=user_inputs, + ) + cls.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, user_inputs=user_inputs, variable_pool=variable_pool, tenant_id=workflow.tenant_id, ) + try: # run node generator = node_instance.run() except Exception as e: + logger.exception( + "error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s", + workflow.id, + node_instance.id, + node_instance.node_type, + node_instance.version(), + ) raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) return node_instance, generator @@ -294,10 +299,20 @@ class WorkflowEntry: return node_instance, generator except Exception as e: + logger.exception( + "error while running node_instance, node_id=%s, type=%s, version=%s", + node_instance.id, + node_instance.node_type, + node_instance.version(), + ) raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) @staticmethod def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: + # NOTE(QuantumGhost): Avoid using this function in new code. + # Keep values structured as long as possible and only convert to dict + # immediately before serialization (e.g., JSON serialization) to maintain + # data integrity and type information. result = WorkflowEntry._handle_special_values(value) return result if isinstance(result, Mapping) or result is None else dict(result) @@ -324,10 +339,17 @@ class WorkflowEntry: cls, *, variable_mapping: Mapping[str, Sequence[str]], - user_inputs: dict, + user_inputs: Mapping[str, Any], variable_pool: VariablePool, tenant_id: str, ) -> None: + # NOTE(QuantumGhost): This logic should remain synchronized with + # the implementation of `load_into_variable_pool`, specifically the logic about + # variable existence checking. + + # WARNING(QuantumGhost): The semantics of this method are not clearly defined, + # and multiple parts of the codebase depend on its current behavior. + # Modify with caution. 
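A simplified model of the checks performed by `load_into_variable_pool` may help here. The sketch below uses plain dicts for the user inputs and the pool (the real `VariablePool` is not a dict), and assumes mapping keys of the form `"{node_id}.#{selector}#"` as built by the variable-assigner helpers above:

```python
def selectors_to_load(
    variable_mapping: dict[str, list[str]],
    user_inputs: dict[str, object],
    pool: dict[tuple[str, ...], object],
) -> list[list[str]]:
    to_load: list[list[str]] = []
    for key, selector in variable_mapping.items():
        # Keys look like "node1.#start.query#"; skip anything the user already
        # supplied, either under the full key or under the node-local suffix.
        if key in user_inputs:
            continue
        node_variable_key = ".".join(key.split(".")[1:])  # -> "#start.query#"
        if node_variable_key in user_inputs:
            continue
        # Only selectors that are absent from the pool need loading.
        if tuple(selector) not in pool:
            to_load.append(list(selector))
    return to_load

mapping = {"node1.#start.query#": ["start", "query"]}
print(selectors_to_load(mapping, {}, {}))  # [['start', 'query']]
```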
for node_variable, variable_selector in variable_mapping.items(): # fetch node id and variable key from node_variable node_variable_list = node_variable.split(".") diff --git a/api/core/workflow/workflow_type_encoder.py b/api/core/workflow/workflow_type_encoder.py new file mode 100644 index 000000000..0123fdac1 --- /dev/null +++ b/api/core/workflow/workflow_type_encoder.py @@ -0,0 +1,49 @@ +import json +from collections.abc import Mapping +from typing import Any + +from pydantic import BaseModel + +from core.file.models import File +from core.variables import Segment + + +class WorkflowRuntimeTypeEncoder(json.JSONEncoder): + def default(self, o: Any): + if isinstance(o, Segment): + return o.value + elif isinstance(o, File): + return o.to_dict() + elif isinstance(o, BaseModel): + return o.model_dump(mode="json") + else: + return super().default(o) + + +class WorkflowRuntimeTypeConverter: + def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None: + result = self._to_json_encodable_recursive(value) + return result if isinstance(result, Mapping) or result is None else dict(result) + + def _to_json_encodable_recursive(self, value: Any) -> Any: + if value is None: + return value + if isinstance(value, (bool, int, str, float)): + return value + if isinstance(value, Segment): + return self._to_json_encodable_recursive(value.value) + if isinstance(value, File): + return value.to_dict() + if isinstance(value, BaseModel): + return value.model_dump(mode="json") + if isinstance(value, dict): + res = {} + for k, v in value.items(): + res[k] = self._to_json_encodable_recursive(v) + return res + if isinstance(value, list): + res_list = [] + for item in value: + res_list.append(self._to_json_encodable_recursive(item)) + return res_list + return value diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 52f119936..e0beef40c 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -5,6 +5,7 @@ from typing import Any, cast import httpx from sqlalchemy import select +from sqlalchemy.orm import Session from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers @@ -91,6 +92,8 @@ def build_from_mappings( tenant_id: str, strict_type_validation: bool = False, ) -> Sequence[File]: + # TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query. + # Implement batch processing to reduce database load when handling multiple files. files = [ build_from_mapping( mapping=mapping, @@ -377,3 +380,75 @@ def _get_file_type_by_mimetype(mime_type: str) -> FileType | None: def get_file_type_by_mime_type(mime_type: str) -> FileType: return _get_file_type_by_mimetype(mime_type) or FileType.CUSTOM + + +class StorageKeyLoader: + """StorageKeyLoader loads storage keys from the database for a list of files. 
+ This loader is batched: the number of database queries stays constant + regardless of the number of files. + """ + + def __init__(self, session: Session, tenant_id: str) -> None: + self._session = session + self._tenant_id = tenant_id + + def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]: + stmt = select(UploadFile).where( + UploadFile.id.in_(upload_file_ids), + UploadFile.tenant_id == self._tenant_id, + ) + + return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)} + + def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]: + stmt = select(ToolFile).where( + ToolFile.id.in_(tool_file_ids), + ToolFile.tenant_id == self._tenant_id, + ) + return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)} + + def load_storage_keys(self, files: Sequence[File]): + """Loads storage keys for a sequence of files by retrieving the corresponding + `UploadFile` or `ToolFile` records from the database based on their transfer method. + + This method doesn't modify the input sequence structure but updates the `_storage_key` + property of each file object by extracting the relevant key from its database record. + + Performance note: This is a batched operation where database query count remains constant + regardless of input size. However, for optimal performance, input sequences should contain + fewer than 1000 files. For larger collections, split into smaller batches and process each + batch separately. + """ + + upload_file_ids: list[uuid.UUID] = [] + tool_file_ids: list[uuid.UUID] = [] + for file in files: + related_model_id = file.related_id + if file.related_id is None: + raise ValueError("file.related_id should not be None.") + if file.tenant_id != self._tenant_id: + err_msg = ( + f"invalid file, expected tenant_id={self._tenant_id}, " + f"got tenant_id={file.tenant_id}, file_id={file.id}, related_model_id={related_model_id}" + ) + raise ValueError(err_msg) + model_id = uuid.UUID(related_model_id) + + if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): + upload_file_ids.append(model_id) + elif file.transfer_method == FileTransferMethod.TOOL_FILE: + tool_file_ids.append(model_id) + + tool_files = self._load_tool_files(tool_file_ids) + upload_files = self._load_upload_files(upload_file_ids) + for file in files: + model_id = uuid.UUID(file.related_id) + if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): + upload_file_row = upload_files.get(model_id) + if upload_file_row is None: + raise ValueError(f"upload file not found, related_model_id={model_id}") + file._storage_key = upload_file_row.key + elif file.transfer_method == FileTransferMethod.TOOL_FILE: + tool_file_row = tool_files.get(model_id) + if tool_file_row is None: + raise ValueError(f"tool file not found, related_model_id={model_id}") 
+ file._storage_key = tool_file_row.file_key diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index a41ef4ae4..250ee4695 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -43,6 +43,10 @@ class UnsupportedSegmentTypeError(Exception): pass +class TypeMismatchError(Exception): + pass + + # Define the constant SEGMENT_TO_VARIABLE_MAP = { StringSegment: StringVariable, @@ -110,6 +114,10 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen return cast(Variable, result) +def infer_segment_type_from_value(value: Any, /) -> SegmentType: + return build_segment(value).value_type + + def build_segment(value: Any, /) -> Segment: if value is None: return NoneSegment() @@ -140,10 +148,80 @@ def build_segment(value: Any, /) -> Segment: case SegmentType.NONE: return ArrayAnySegment(value=value) case _: + # This should be unreachable. raise ValueError(f"not supported value {value}") raise ValueError(f"not supported value {value}") +def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: + """ + Build a segment with explicit type checking. + + This function creates a segment from a value while enforcing type compatibility + with the specified segment_type. It provides stricter type validation compared + to the standard build_segment function. + + Args: + segment_type: The expected SegmentType for the resulting segment + value: The value to be converted into a segment + + Returns: + Segment: A segment instance of the appropriate type + + Raises: + TypeMismatchError: If the value type doesn't match the expected segment_type + + Special Cases: + - For empty list [] values, if segment_type is array[*], returns the corresponding array type + - Type validation is performed before segment creation + + Examples: + >>> build_segment_with_type(SegmentType.STRING, "hello") + StringSegment(value="hello") + + >>> build_segment_with_type(SegmentType.ARRAY_STRING, []) + ArrayStringSegment(value=[]) + + >>> build_segment_with_type(SegmentType.STRING, 123) + # Raises TypeMismatchError + """ + # Handle None values + if value is None: + if segment_type == SegmentType.NONE: + return NoneSegment() + else: + raise TypeMismatchError(f"Expected {segment_type}, but got None") + + # Handle empty list special case for array types + if isinstance(value, list) and len(value) == 0: + if segment_type == SegmentType.ARRAY_ANY: + return ArrayAnySegment(value=value) + elif segment_type == SegmentType.ARRAY_STRING: + return ArrayStringSegment(value=value) + elif segment_type == SegmentType.ARRAY_NUMBER: + return ArrayNumberSegment(value=value) + elif segment_type == SegmentType.ARRAY_OBJECT: + return ArrayObjectSegment(value=value) + elif segment_type == SegmentType.ARRAY_FILE: + return ArrayFileSegment(value=value) + else: + raise TypeMismatchError(f"Expected {segment_type}, but got empty list") + + # Build segment using existing logic to infer actual type + inferred_segment = build_segment(value) + inferred_type = inferred_segment.value_type + + # Type compatibility checking + if inferred_type == segment_type: + return inferred_segment + + # Type mismatch - raise error with descriptive message + raise TypeMismatchError( + f"Type mismatch: expected {segment_type}, but value '{value}' " + f"(type: {type(value).__name__}) corresponds to {inferred_type}" + ) + + def segment_to_variable( *, segment: Segment, diff --git a/api/libs/datetime_utils.py b/api/libs/datetime_utils.py new file mode 100644 index 
000000000..e576a3462 --- /dev/null +++ b/api/libs/datetime_utils.py @@ -0,0 +1,22 @@ +import abc +import datetime +from typing import Protocol + + +class _NowFunction(Protocol): + @abc.abstractmethod + def __call__(self, tz: datetime.timezone | None) -> datetime.datetime: + pass + + +# _now_func is a callable with the _NowFunction signature. +# Its sole purpose is to abstract time retrieval, enabling +# developers to mock this behavior in tests and time-dependent scenarios. +_now_func: _NowFunction = datetime.datetime.now + + +def naive_utc_now() -> datetime.datetime: + """Return a naive datetime object (without timezone information) + representing current UTC time. + """ + return _now_func(datetime.UTC).replace(tzinfo=None) diff --git a/api/libs/jsonutil.py b/api/libs/jsonutil.py new file mode 100644 index 000000000..fa2967103 --- /dev/null +++ b/api/libs/jsonutil.py @@ -0,0 +1,11 @@ +import json + +from pydantic import BaseModel + + +class PydanticModelEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, BaseModel): + return o.model_dump() + else: + return super().default(o) diff --git a/api/models/_workflow_exc.py b/api/models/_workflow_exc.py new file mode 100644 index 000000000..f6271bda4 --- /dev/null +++ b/api/models/_workflow_exc.py @@ -0,0 +1,20 @@ +"""All these exceptions are not meant to be caught by callers.""" + + +class WorkflowDataError(Exception): + """Base class for all workflow data related exceptions. + + This should be used to indicate issues with workflow data integrity, such as + no `graph` configuration, missing `nodes` field in `graph` configuration, or + similar issues. + """ + + pass + + +class NodeNotFoundError(WorkflowDataError): + """Raised when a node with the specified ID is not found in the workflow.""" + + def __init__(self, node_id: str): + super().__init__(f"Node with ID '{node_id}' not found in the workflow.") + self.node_id = node_id diff --git a/api/models/model.py b/api/models/model.py index fa83baa9c..75202d936 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -611,6 +611,14 @@ class InstalledApp(Base): return tenant +class ConversationSource(StrEnum): + """This enumeration is designed for use with `Conversation.from_source`.""" + + # NOTE(QuantumGhost): The enumeration members may not cover all possible cases. + API = "api" + CONSOLE = "console" + + class Conversation(Base): __tablename__ = "conversations" __table_args__ = ( @@ -632,7 +640,14 @@ class Conversation(Base): system_instruction = db.Column(db.Text) system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) status = db.Column(db.String(255), nullable=False) + + # The `invoke_from` records how the conversation was created. + # + # Its value corresponds to the members of `InvokeFrom`. + # (api/core/app/entities/app_invoke_entities.py) invoke_from = db.Column(db.String(255), nullable=True) + + # ref: ConversationSource. 
from_source = db.Column(db.String(255), nullable=False) from_end_user_id = db.Column(StringUUID) from_account_id = db.Column(StringUUID) diff --git a/api/models/workflow.py b/api/models/workflow.py index 1733dec0f..7f01135af 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -7,10 +7,16 @@ from typing import TYPE_CHECKING, Any, Optional, Union from uuid import uuid4 from flask_login import current_user +from sqlalchemy import orm +from core.file.constants import maybe_file_object +from core.file.models import File from core.variables import utils as variable_utils from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from factories.variable_factory import build_segment +from core.workflow.nodes.enums import NodeType +from factories.variable_factory import TypeMismatchError, build_segment_with_type + +from ._workflow_exc import NodeNotFoundError, WorkflowDataError if TYPE_CHECKING: from models.model import AppMode @@ -72,6 +78,10 @@ class WorkflowType(Enum): return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT +class _InvalidGraphDefinitionError(Exception): + pass + + class Workflow(Base): """ Workflow, for `Workflow App` and `Chat App workflow mode`. @@ -136,6 +146,8 @@ class Workflow(Base): "conversation_variables", db.Text, nullable=False, server_default="{}" ) + VERSION_DRAFT = "draft" + @classmethod def new( cls, @@ -179,8 +191,72 @@ class Workflow(Base): @property def graph_dict(self) -> Mapping[str, Any]: + # TODO(QuantumGhost): Consider caching `graph_dict` to avoid repeated JSON decoding. + # + # Using `functools.cached_property` could help, but some code in the codebase may + # modify the returned dict, which can cause issues elsewhere. + # + # For example, changing this property to a cached property led to errors like the + # following when single stepping an `Iteration` node: + # + # Root node id 1748401971780start not found in the graph + # + # There is currently no standard way to make a dict deeply immutable in Python, + # and tracking modifications to the returned dict is difficult. For now, we leave + # the code as-is to avoid these issues. + # + # Currently, the following functions / methods would mutate the returned dict: + # + # - `_get_graph_and_variable_pool_of_single_iteration`. + # - `_get_graph_and_variable_pool_of_single_loop`. return json.loads(self.graph) if self.graph else {} + def get_node_config_by_id(self, node_id: str) -> Mapping[str, Any]: + """Extract a node configuration from the workflow graph by node ID. + A node configuration is a dictionary containing the node's properties, including + the node's id, title, and its data as a dict. 
+ """ + workflow_graph = self.graph_dict + + if not workflow_graph: + raise WorkflowDataError(f"workflow graph not found, workflow_id={self.id}") + + nodes = workflow_graph.get("nodes") + if not nodes: + raise WorkflowDataError("nodes not found in workflow graph") + + try: + node_config = next(filter(lambda node: node["id"] == node_id, nodes)) + except StopIteration: + raise NodeNotFoundError(node_id) + assert isinstance(node_config, dict) + return node_config + + @staticmethod + def get_node_type_from_node_config(node_config: Mapping[str, Any]) -> NodeType: + """Extract type of a node from the node configuration returned by `get_node_config_by_id`.""" + node_config_data = node_config.get("data", {}) + # Get node class + node_type = NodeType(node_config_data.get("type")) + return node_type + + @staticmethod + def get_enclosing_node_type_and_id(node_config: Mapping[str, Any]) -> tuple[NodeType, str] | None: + in_loop = node_config.get("isInLoop", False) + in_iteration = node_config.get("isInIteration", False) + if in_loop: + loop_id = node_config.get("loop_id") + if loop_id is None: + raise _InvalidGraphDefinitionError("invalid graph") + return NodeType.LOOP, loop_id + elif in_iteration: + iteration_id = node_config.get("iteration_id") + if iteration_id is None: + raise _InvalidGraphDefinitionError("invalid graph") + return NodeType.ITERATION, iteration_id + else: + return None + @property def features(self) -> str: """ @@ -376,6 +452,10 @@ class Workflow(Base): ensure_ascii=False, ) + @staticmethod + def version_from_datetime(d: datetime) -> str: + return str(d) + class WorkflowRun(Base): """ @@ -835,8 +915,18 @@ def _naive_utc_datetime(): class WorkflowDraftVariable(Base): + """`WorkflowDraftVariable` record variables and outputs generated during + debugging worfklow or chatflow. + + IMPORTANT: This model maintains multiple invariant rules that must be preserved. + Do not instantiate this class directly with the constructor. + + Instead, use the factory methods (`new_conversation_variable`, `new_sys_variable`, + `new_node_variable`) defined below to ensure all invariants are properly maintained. + """ + @staticmethod - def unique_columns() -> list[str]: + def unique_app_id_node_id_name() -> list[str]: return [ "app_id", "node_id", @@ -844,7 +934,9 @@ class WorkflowDraftVariable(Base): ] __tablename__ = "workflow_draft_variables" - __table_args__ = (UniqueConstraint(*unique_columns()),) + __table_args__ = (UniqueConstraint(*unique_app_id_node_id_name()),) + # Required for instance variable annotation. + __allow_unmapped__ = True # id is the unique identifier of a draft variable. id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) @@ -925,6 +1017,36 @@ class WorkflowDraftVariable(Base): default=None, ) + # Cache for deserialized value + # + # NOTE(QuantumGhost): This field serves two purposes: + # + # 1. Caches deserialized values to reduce repeated parsing costs + # 2. Allows modification of the deserialized value after retrieval, + # particularly important for `File`` variables which require database + # lookups to obtain storage_key and other metadata + # + # Use double underscore prefix for better encapsulation, + # making this attribute harder to access from outside the class. + __value: Segment | None + + def __init__(self, *args, **kwargs): + """ + The constructor of `WorkflowDraftVariable` is not intended for + direct use outside this file. Its solo purpose is setup private state + used by the model instance. 
+ + Please use the factory methods + (`new_conversation_variable`, `new_sys_variable`, `new_node_variable`) + defined below to create instances of this class. + """ + super().__init__(*args, **kwargs) + self.__value = None + + @orm.reconstructor + def _init_on_load(self): + self.__value = None + def get_selector(self) -> list[str]: selector = json.loads(self.selector) if not isinstance(selector, list): @@ -939,15 +1061,92 @@ def _set_selector(self, value: list[str]): self.selector = json.dumps(value) - def get_value(self) -> Segment | None: - return build_segment(json.loads(self.value)) + def _loads_value(self) -> Segment: + value = json.loads(self.value) + return self.build_segment_with_type(self.value_type, value) + + @staticmethod + def rebuild_file_types(value: Any) -> Any: + # NOTE(QuantumGhost): Temporary workaround for structured data handling. + # By this point, `output` has been converted to dict by + # `WorkflowEntry.handle_special_values`, so we need to + # reconstruct File objects from their serialized form + # to maintain proper variable saving behavior. + # + # Ideally, we should work with structured data objects directly + # rather than their serialized forms. + # However, multiple components in the codebase depend on + # `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging. + if isinstance(value, dict): + if not maybe_file_object(value): + return value + return File.model_validate(value) + elif isinstance(value, list) and value: + first = value[0] + if not maybe_file_object(first): + return value + return [File.model_validate(i) for i in value] + else: + return value + + @classmethod + def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment: + # Extends `variable_factory.build_segment_with_type` functionality by + # reconstructing `FileSegment` or `ArrayFileSegment` objects from + # their serialized dictionary or list representations, respectively. + if segment_type == SegmentType.FILE: + if isinstance(value, File): + return build_segment_with_type(segment_type, value) + elif isinstance(value, dict): + file = cls.rebuild_file_types(value) + return build_segment_with_type(segment_type, file) + else: + raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}") + if segment_type == SegmentType.ARRAY_FILE: + if not isinstance(value, list): + raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}") + file_list = cls.rebuild_file_types(value) + return build_segment_with_type(segment_type=segment_type, value=file_list) + + return build_segment_with_type(segment_type=segment_type, value=value) + + def get_value(self) -> Segment: + """Decode the serialized value into its corresponding `Segment` object. + + This method caches the result, so repeated calls will return the same + object instance without re-parsing the serialized data. + + If you need to modify the returned `Segment`, use `value.model_copy()` + to create a copy first to avoid affecting the cached instance. + + For more information about the caching mechanism, see the documentation + of the `__value` field. + + Returns: + Segment: The deserialized value as a Segment object. 
+ """ + + if self.__value is not None: + return self.__value + value = self._loads_value() + self.__value = value + return value def set_name(self, name: str): self.name = name self._set_selector([self.node_id, name]) def set_value(self, value: Segment): - self.value = json.dumps(value.value) + """Updates the `value` and corresponding `value_type` fields in the database model. + + This method also stores the provided Segment object in the deserialized cache + without creating a copy, allowing for efficient value access. + + Args: + value: The Segment object to store as the variable's value. + """ + self.__value = value + self.value = json.dumps(value, cls=variable_utils.SegmentJSONEncoder) self.value_type = value.value_type def get_node_id(self) -> str | None: @@ -973,6 +1172,7 @@ class WorkflowDraftVariable(Base): node_id: str, name: str, value: Segment, + node_execution_id: str | None, description: str = "", ) -> "WorkflowDraftVariable": variable = WorkflowDraftVariable() @@ -984,6 +1184,7 @@ class WorkflowDraftVariable(Base): variable.name = name variable.set_value(value) variable._set_selector(list(variable_utils.to_selector(node_id, name))) + variable.node_execution_id = node_execution_id return variable @classmethod @@ -993,13 +1194,17 @@ class WorkflowDraftVariable(Base): app_id: str, name: str, value: Segment, + description: str = "", ) -> "WorkflowDraftVariable": variable = cls._new( app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, name=name, value=value, + description=description, + node_execution_id=None, ) + variable.editable = True return variable @classmethod @@ -1009,9 +1214,16 @@ class WorkflowDraftVariable(Base): app_id: str, name: str, value: Segment, + node_execution_id: str, editable: bool = False, ) -> "WorkflowDraftVariable": - variable = cls._new(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name, value=value) + variable = cls._new( + app_id=app_id, + node_id=SYSTEM_VARIABLE_NODE_ID, + name=name, + node_execution_id=node_execution_id, + value=value, + ) variable.editable = editable return variable @@ -1023,11 +1235,19 @@ class WorkflowDraftVariable(Base): node_id: str, name: str, value: Segment, + node_execution_id: str, visible: bool = True, + editable: bool = True, ) -> "WorkflowDraftVariable": - variable = cls._new(app_id=app_id, node_id=node_id, name=name, value=value) + variable = cls._new( + app_id=app_id, + node_id=node_id, + name=name, + node_execution_id=node_execution_id, + value=value, + ) variable.visible = visible - variable.editable = True + variable.editable = editable return variable @property diff --git a/api/pyproject.toml b/api/pyproject.toml index 38cc9ae75..fed0128b9 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -149,6 +149,7 @@ dev = [ "types-ujson>=5.10.0", "boto3-stubs>=1.38.20", "types-jmespath>=1.0.2.20240106", + "hypothesis>=6.131.15", "types_pyOpenSSL>=24.1.0", "types_cffi>=1.17.0", "types_setuptools>=80.9.0", diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 1b026acfd..20257fa34 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -32,6 +32,7 @@ from models import Account, App, AppMode from models.model import AppModelConfig from models.workflow import Workflow from services.plugin.dependencies_analysis import DependenciesAnalysisService +from services.workflow_draft_variable_service import WorkflowDraftVariableService from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) @@ -292,6 +293,8 @@ class 
AppDslService: dependencies=check_dependencies_pending_data, ) + draft_var_srv = WorkflowDraftVariableService(session=self._session) + draft_var_srv.delete_workflow_variables(app_id=app.id) return Import( id=import_id, status=status, diff --git a/api/services/errors/app.py b/api/services/errors/app.py index 87e9e9247..5d348c61b 100644 --- a/api/services/errors/app.py +++ b/api/services/errors/app.py @@ -4,3 +4,7 @@ class MoreLikeThisDisabledError(Exception): class WorkflowHashNotEqualError(Exception): pass + + +class IsDraftWorkflowError(Exception): + pass diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py new file mode 100644 index 000000000..cd30440b4 --- /dev/null +++ b/api/services/workflow_draft_variable_service.py @@ -0,0 +1,721 @@ +import dataclasses +import datetime +import logging +from collections.abc import Mapping, Sequence +from enum import StrEnum +from typing import Any, ClassVar + +from sqlalchemy import Engine, orm, select +from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.orm import Session +from sqlalchemy.sql.expression import and_, or_ + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.file.models import File +from core.variables import Segment, StringSegment, Variable +from core.variables.consts import MIN_SELECTORS_LENGTH +from core.variables.segments import ArrayFileSegment, FileSegment +from core.variables.types import SegmentType +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.enums import SystemVariableKey +from core.workflow.nodes import NodeType +from core.workflow.nodes.variable_assigner.common.helpers import get_updated_variables +from core.workflow.variable_loader import VariableLoader +from factories.file_factory import StorageKeyLoader +from factories.variable_factory import build_segment, segment_to_variable +from models import App, Conversation +from models.enums import DraftVariableType +from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable + +_logger = logging.getLogger(__name__) + + +@dataclasses.dataclass(frozen=True) +class WorkflowDraftVariableList: + variables: list[WorkflowDraftVariable] + total: int | None = None + + +class WorkflowDraftVariableError(Exception): + pass + + +class VariableResetError(WorkflowDraftVariableError): + pass + + +class UpdateNotSupportedError(WorkflowDraftVariableError): + pass + + +class DraftVarLoader(VariableLoader): + # This implements the VariableLoader interface for loading draft variables. + # + # ref: core.workflow.variable_loader.VariableLoader + + # Database engine used for loading variables. + _engine: Engine + # Application ID for which variables are being loaded. + _app_id: str + _tenant_id: str + _fallback_variables: Sequence[Variable] + + def __init__( + self, + engine: Engine, + app_id: str, + tenant_id: str, + fallback_variables: Sequence[Variable] | None = None, + ) -> None: + self._engine = engine + self._app_id = app_id + self._tenant_id = tenant_id + self._fallback_variables = fallback_variables or [] + + def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]: + return (selector[0], selector[1]) + + def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + if not selectors: + return [] + + # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding Variable instance. 
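As a rough illustration of the batched lookup that `get_draft_variables_by_selectors` (defined further down in this service) performs, here is a self-contained sketch; the table definition is simplified and does not match the real model:

```python
from sqlalchemy import Column, MetaData, String, Table, and_, or_, select

metadata = MetaData()
draft_vars = Table(
    "workflow_draft_variables", metadata,
    Column("app_id", String), Column("node_id", String), Column("name", String),
)

def selectors_query(app_id: str, selectors: list[list[str]]):
    # One AND branch per (node_id, name) pair; the OR of all branches fetches
    # every requested variable in a single round trip.
    branches = [
        and_(draft_vars.c.node_id == s[0], draft_vars.c.name == s[1]) for s in selectors
    ]
    return select(draft_vars).where(draft_vars.c.app_id == app_id, or_(*branches))

print(selectors_query("app-1", [["start", "query"], ["llm", "text"]]))
```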
+ variable_by_selector: dict[tuple[str, str], Variable] = {} + + with Session(bind=self._engine, expire_on_commit=False) as session: + srv = WorkflowDraftVariableService(session) + draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors) + + for draft_var in draft_vars: + segment = draft_var.get_value() + variable = segment_to_variable( + segment=segment, + selector=draft_var.get_selector(), + id=draft_var.id, + name=draft_var.name, + description=draft_var.description, + ) + selector_tuple = self._selector_to_tuple(variable.selector) + variable_by_selector[selector_tuple] = variable + + # Important: batch-load storage keys for any File values so that the + # returned variables reference fully usable File objects. + files: list[File] = [] + for draft_var in draft_vars: + value = draft_var.get_value() + if isinstance(value, FileSegment): + files.append(value.value) + elif isinstance(value, ArrayFileSegment): + files.extend(value.value) + with Session(bind=self._engine) as session: + storage_key_loader = StorageKeyLoader(session, tenant_id=self._tenant_id) + storage_key_loader.load_storage_keys(files) + + return list(variable_by_selector.values()) + + +class WorkflowDraftVariableService: + _session: Session + + def __init__(self, session: Session) -> None: + self._session = session + + def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None: + return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first() + + def get_draft_variables_by_selectors( + self, + app_id: str, + selectors: Sequence[list[str]], + ) -> list[WorkflowDraftVariable]: + ors = [] + for selector in selectors: + node_id, name = selector + ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name)) + + # NOTE(QuantumGhost): Although the number of `or` expressions may be large, as long as + # each expression includes conditions on both `node_id` and `name` (which are covered by the unique index), + # PostgreSQL can efficiently retrieve the results using a bitmap index scan. + # + # Alternatively, a `SELECT` statement could be constructed for each selector and + # combined using `UNION` to fetch all rows. + # Benchmarking indicates that both approaches yield comparable performance. + variables = ( + self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app_id, or_(*ors)).all() + ) + return variables + + def list_variables_without_values(self, app_id: str, page: int, limit: int) -> WorkflowDraftVariableList: + criteria = WorkflowDraftVariable.app_id == app_id + total = None + query = self._session.query(WorkflowDraftVariable).filter(criteria) + if page == 1: + total = query.count() + variables = ( + # Do not load the `value` field. 
+ query.options(orm.defer(WorkflowDraftVariable.value)) + .order_by(WorkflowDraftVariable.id.desc()) + .limit(limit) + .offset((page - 1) * limit) + .all() + ) + + return WorkflowDraftVariableList(variables=variables, total=total) + + def _list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList: + criteria = ( + WorkflowDraftVariable.app_id == app_id, + WorkflowDraftVariable.node_id == node_id, + ) + query = self._session.query(WorkflowDraftVariable).filter(*criteria) + variables = query.order_by(WorkflowDraftVariable.id.desc()).all() + return WorkflowDraftVariableList(variables=variables) + + def list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList: + return self._list_node_variables(app_id, node_id) + + def list_conversation_variables(self, app_id: str) -> WorkflowDraftVariableList: + return self._list_node_variables(app_id, CONVERSATION_VARIABLE_NODE_ID) + + def list_system_variables(self, app_id: str) -> WorkflowDraftVariableList: + return self._list_node_variables(app_id, SYSTEM_VARIABLE_NODE_ID) + + def get_conversation_variable(self, app_id: str, name: str) -> WorkflowDraftVariable | None: + return self._get_variable(app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, name=name) + + def get_system_variable(self, app_id: str, name: str) -> WorkflowDraftVariable | None: + return self._get_variable(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name) + + def get_node_variable(self, app_id: str, node_id: str, name: str) -> WorkflowDraftVariable | None: + return self._get_variable(app_id, node_id, name) + + def _get_variable(self, app_id: str, node_id: str, name: str) -> WorkflowDraftVariable | None: + variable = ( + self._session.query(WorkflowDraftVariable) + .where( + WorkflowDraftVariable.app_id == app_id, + WorkflowDraftVariable.node_id == node_id, + WorkflowDraftVariable.name == name, + ) + .first() + ) + return variable + + def update_variable( + self, + variable: WorkflowDraftVariable, + name: str | None = None, + value: Segment | None = None, + ) -> WorkflowDraftVariable: + if not variable.editable: + raise UpdateNotSupportedError(f"variable does not support updating, id={variable.id}") + if name is not None: + variable.set_name(name) + if value is not None: + variable.set_value(value) + variable.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + self._session.flush() + return variable + + def _reset_conv_var(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None: + conv_var_by_name = {i.name: i for i in workflow.conversation_variables} + conv_var = conv_var_by_name.get(variable.name) + + if conv_var is None: + self._session.delete(instance=variable) + self._session.flush() + _logger.warning( + "Conversation variable not found for draft variable, id=%s, name=%s", variable.id, variable.name + ) + return None + + variable.set_value(conv_var) + variable.last_edited_at = None + self._session.add(variable) + self._session.flush() + return variable + + def _reset_node_var(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None: + # If a variable does not allow updating, it makes no sense to reset it. + if not variable.editable: + return variable + # No execution record for this variable, delete the variable instead. 
+ if variable.node_execution_id is None: + self._session.delete(instance=variable) + self._session.flush() + _logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name) + return None + + query = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == variable.node_execution_id) + node_exec = self._session.scalars(query).first() + if node_exec is None: + _logger.warning( + "Node execution not found for draft variable, id=%s, name=%s, node_execution_id=%s", + variable.id, + variable.name, + variable.node_execution_id, + ) + self._session.delete(instance=variable) + self._session.flush() + return None + + # Get node type for proper value extraction + node_config = workflow.get_node_config_by_id(variable.node_id) + node_type = workflow.get_node_type_from_node_config(node_config) + + outputs_dict = node_exec.outputs_dict or {} + + # Note: Based on the implementation in `_build_from_variable_assigner_mapping`, + # VariableAssignerNode (both v1 and v2) can only create conversation draft variables. + # For consistency, we should simply return when processing VARIABLE_ASSIGNER nodes. + # + # This implementation must remain synchronized with the `_build_from_variable_assigner_mapping` + # and `save` methods. + if node_type == NodeType.VARIABLE_ASSIGNER: + return variable + + if variable.name not in outputs_dict: + # If variable not found in execution data, delete the variable + self._session.delete(instance=variable) + self._session.flush() + return None + value = outputs_dict[variable.name] + value_seg = WorkflowDraftVariable.build_segment_with_type(variable.value_type, value) + # Extract variable value using unified logic + variable.set_value(value_seg) + variable.last_edited_at = None # Reset to indicate this is a reset operation + self._session.flush() + return variable + + def reset_variable(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None: + variable_type = variable.get_variable_type() + if variable_type == DraftVariableType.CONVERSATION: + return self._reset_conv_var(workflow, variable) + elif variable_type == DraftVariableType.NODE: + return self._reset_node_var(workflow, variable) + else: + raise VariableResetError(f"cannot reset system variable, variable_id={variable.id}") + + def delete_variable(self, variable: WorkflowDraftVariable): + self._session.delete(variable) + + def delete_workflow_variables(self, app_id: str): + ( + self._session.query(WorkflowDraftVariable) + .filter(WorkflowDraftVariable.app_id == app_id) + .delete(synchronize_session=False) + ) + + def delete_node_variables(self, app_id: str, node_id: str): + return self._delete_node_variables(app_id, node_id) + + def _delete_node_variables(self, app_id: str, node_id: str): + self._session.query(WorkflowDraftVariable).where( + WorkflowDraftVariable.app_id == app_id, + WorkflowDraftVariable.node_id == node_id, + ).delete() + + def _get_conversation_id_from_draft_variable(self, app_id: str) -> str | None: + draft_var = self._get_variable( + app_id=app_id, + node_id=SYSTEM_VARIABLE_NODE_ID, + name=str(SystemVariableKey.CONVERSATION_ID), + ) + if draft_var is None: + return None + segment = draft_var.get_value() + if not isinstance(segment, StringSegment): + _logger.warning( + "sys.conversation_id variable is not a string: app_id=%s, id=%s", + app_id, + draft_var.id, + ) + return None + return segment.value + + def get_or_create_conversation( + self, + account_id: str, + app: App, + workflow: Workflow, + ) -> str: + """ 
get_or_create_conversation creates and returns the ID of a conversation for debugging. + + If a conversation already exists, as determined by the following criteria, its ID is returned: + - The system variable `sys.conversation_id` exists in the draft variable table, and + - A corresponding conversation record is found in the database. + + If no such conversation exists, a new conversation is created and its ID is returned. + """ + conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id) + + if conv_id is not None: + conversation = ( + self._session.query(Conversation) + .filter( + Conversation.id == conv_id, + Conversation.app_id == workflow.app_id, + ) + .first() + ) + # Only return the conversation ID if it exists and is valid (has a corresponding conversation record in DB). + if conversation is not None: + return conv_id + conversation = Conversation( + app_id=workflow.app_id, + app_model_config_id=app.app_model_config_id, + model_provider=None, + model_id="", + override_model_configs=None, + mode=app.mode, + name="Draft Debugging Conversation", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + invoke_from=InvokeFrom.DEBUGGER.value, + from_source="console", + from_end_user_id=None, + from_account_id=account_id, + ) + + self._session.add(conversation) + self._session.flush() + return conversation.id + + def prefill_conversation_variable_default_values(self, workflow: Workflow): + """Pre-populate draft conversation variables with the default values defined by the workflow.""" + draft_conv_vars: list[WorkflowDraftVariable] = [] + for conv_var in workflow.conversation_variables: + draft_var = WorkflowDraftVariable.new_conversation_variable( + app_id=workflow.app_id, + name=conv_var.name, + value=conv_var, + description=conv_var.description, + ) + draft_conv_vars.append(draft_var) + _batch_upsert_draft_varaible( + self._session, + draft_conv_vars, + policy=_UpsertPolicy.IGNORE, + ) + + +class _UpsertPolicy(StrEnum): + IGNORE = "ignore" + OVERWRITE = "overwrite" + + +def _batch_upsert_draft_varaible( + session: Session, + draft_vars: Sequence[WorkflowDraftVariable], + policy: _UpsertPolicy = _UpsertPolicy.OVERWRITE, +) -> None: + if not draft_vars: + return None + # Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons: + # + # 1. The variable saving process involves writing multiple rows to the + # `workflow_draft_variables` table. Batch insertion significantly improves performance. + # 2. Using the ORM would require either: + # + # a. Checking for the existence of each variable before insertion, + # resulting in 2n SQL statements for n variables and potential concurrency issues. + # b. Attempting insertion first, then updating if a unique index violation occurs, + # which still results in n to 2n SQL statements. + # + # Both approaches are inefficient and suboptimal. + # 3. We do not need to retrieve the results of the SQL execution or populate ORM + # model instances with the returned values. + # 4. Batch insertion with `ON CONFLICT DO UPDATE` allows us to insert or update all + # variables in a single SQL statement, avoiding the issues above. + # + # For these reasons, we use the SQLAlchemy query builder and rely on dialect-specific + # insert operations instead of the ORM layer. 
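A generic, self-contained illustration of this batch-upsert pattern follows; it uses a hypothetical `items` table standing in for `workflow_draft_variables`, not the actual Dify model:

```python
from sqlalchemy import Column, Integer, String, UniqueConstraint
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import DeclarativeBase, Session

class Base(DeclarativeBase):
    pass

class Item(Base):  # hypothetical table, standing in for workflow_draft_variables
    __tablename__ = "items"
    __table_args__ = (UniqueConstraint("app_id", "name"),)
    id = Column(Integer, primary_key=True)
    app_id = Column(String, nullable=False)
    name = Column(String, nullable=False)
    value = Column(String, nullable=False)

def upsert_items(session: Session, rows: list[dict]) -> None:
    # A single statement inserts all rows and, on unique-index conflict,
    # overwrites the tracked columns: the OVERWRITE policy described above.
    stmt = insert(Item).values(rows)
    stmt = stmt.on_conflict_do_update(
        index_elements=["app_id", "name"],
        set_={"value": stmt.excluded.value},
    )
    session.execute(stmt)
```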
+ stmt = insert(WorkflowDraftVariable).values([_model_to_insertion_dict(v) for v in draft_vars]) + if policy == _UpsertPolicy.OVERWRITE: + stmt = stmt.on_conflict_do_update( + index_elements=WorkflowDraftVariable.unique_app_id_node_id_name(), + set_={ + "updated_at": stmt.excluded.updated_at, + "last_edited_at": stmt.excluded.last_edited_at, + "description": stmt.excluded.description, + "value_type": stmt.excluded.value_type, + "value": stmt.excluded.value, + "visible": stmt.excluded.visible, + "editable": stmt.excluded.editable, + "node_execution_id": stmt.excluded.node_execution_id, + }, + ) + elif policy == _UpsertPolicy.IGNORE: + stmt = stmt.on_conflict_do_nothing(index_elements=WorkflowDraftVariable.unique_app_id_node_id_name()) + else: + raise Exception("Invalid value for update policy.") + session.execute(stmt) + + +def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]: + d: dict[str, Any] = { + "app_id": model.app_id, + "last_edited_at": None, + "node_id": model.node_id, + "name": model.name, + "selector": model.selector, + "value_type": model.value_type, + "value": model.value, + "node_execution_id": model.node_execution_id, + } + if model.visible is not None: + d["visible"] = model.visible + if model.editable is not None: + d["editable"] = model.editable + if model.created_at is not None: + d["created_at"] = model.created_at + if model.updated_at is not None: + d["updated_at"] = model.updated_at + if model.description is not None: + d["description"] = model.description + return d + + +def _build_segment_for_serialized_values(v: Any) -> Segment: + """ + Reconstructs Segment objects from serialized values, with special handling + for FileSegment and ArrayFileSegment types. + + This function should only be used when: + 1. No explicit type information is available + 2. The input value is in serialized form (dict or list) + + It detects potential file objects in the serialized data and properly rebuilds the + appropriate segment type. + """ + return build_segment(WorkflowDraftVariable.rebuild_file_types(v)) + + +class DraftVariableSaver: + # _DUMMY_OUTPUT_IDENTITY is a placeholder output for workflow nodes. + # Its sole possible value is `None`. + # + # This is used to signal the execution of a workflow node when it has no other outputs. + _DUMMY_OUTPUT_IDENTITY: ClassVar[str] = "__dummy__" + _DUMMY_OUTPUT_VALUE: ClassVar[None] = None + + # _EXCLUDE_VARIABLE_NAMES_MAPPING maps node types and versions to variable names that + # should be excluded when saving draft variables. This prevents certain internal or + # technical variables from being exposed in the draft environment, particularly those + # that aren't meant to be directly edited or viewed by users. + _EXCLUDE_VARIABLE_NAMES_MAPPING: dict[NodeType, frozenset[str]] = { + NodeType.LLM: frozenset(["finish_reason"]), + NodeType.LOOP: frozenset(["loop_round"]), + } + + # Database session used for persisting draft variables. + _session: Session + + # The application ID associated with the draft variables. + # This should match the `Workflow.app_id` of the workflow to which the current node belongs. + _app_id: str + + # The ID of the node for which DraftVariableSaver is saving output variables. + _node_id: str + + # The type of the current node (see NodeType). + _node_type: NodeType + + # Indicates how the workflow execution was triggered (see InvokeFrom). + _invoke_from: InvokeFrom + + # The ID of the node execution that produced the variables being saved. + _node_execution_id: str + + # _enclosing_node_id identifies the container node that the current node belongs to. 
+ # For example, if the current node is an LLM node inside an Iteration node + # or Loop node, then `_enclosing_node_id` refers to the ID of + # the containing Iteration or Loop node. + # + # If the current node is not nested within another node, `_enclosing_node_id` is + # `None`. + _enclosing_node_id: str | None + + def __init__( + self, + session: Session, + app_id: str, + node_id: str, + node_type: NodeType, + invoke_from: InvokeFrom, + node_execution_id: str, + enclosing_node_id: str | None = None, + ): + self._session = session + self._app_id = app_id + self._node_id = node_id + self._node_type = node_type + self._invoke_from = invoke_from + self._node_execution_id = node_execution_id + self._enclosing_node_id = enclosing_node_id + + def _create_dummy_output_variable(self): + return WorkflowDraftVariable.new_node_variable( + app_id=self._app_id, + node_id=self._node_id, + name=self._DUMMY_OUTPUT_IDENTITY, + node_execution_id=self._node_execution_id, + value=build_segment(self._DUMMY_OUTPUT_VALUE), + visible=False, + editable=False, + ) + + def _should_save_output_variables_for_draft(self) -> bool: + # Only save output variables for debugging execution of workflow. + if self._invoke_from != InvokeFrom.DEBUGGER: + return False + if self._enclosing_node_id is not None and self._node_type != NodeType.VARIABLE_ASSIGNER: + # Currently we do not save output variables for nodes inside loop or iteration. + return False + return True + + def _build_from_variable_assigner_mapping(self, process_data: Mapping[str, Any]) -> list[WorkflowDraftVariable]: + draft_vars: list[WorkflowDraftVariable] = [] + updated_variables = get_updated_variables(process_data) or [] + + for item in updated_variables: + selector = item.selector + if len(selector) < MIN_SELECTORS_LENGTH: + raise Exception("selector too short") + # NOTE(QuantumGhost): only the following two kinds of variable could be updated by + # VariableAssigner: ConversationVariable and iteration variable. + # We only save conversation variable here. + if selector[0] != CONVERSATION_VARIABLE_NODE_ID: + continue + segment = WorkflowDraftVariable.build_segment_with_type(segment_type=item.value_type, value=item.new_value) + draft_vars.append( + WorkflowDraftVariable.new_conversation_variable( + app_id=self._app_id, + name=item.name, + value=segment, + ) + ) + # Add a dummy output variable to indicate that this node is executed. + draft_vars.append(self._create_dummy_output_variable()) + return draft_vars + + def _build_variables_from_start_mapping(self, output: Mapping[str, Any]) -> list[WorkflowDraftVariable]: + draft_vars = [] + has_non_sys_variables = False + for name, value in output.items(): + value_seg = _build_segment_for_serialized_values(value) + node_id, name = self._normalize_variable_for_start_node(name) + # If node_id is not `sys`, it means that the variable is a user-defined input field + # in `Start` node. + if node_id != SYSTEM_VARIABLE_NODE_ID: + draft_vars.append( + WorkflowDraftVariable.new_node_variable( + app_id=self._app_id, + node_id=self._node_id, + name=name, + node_execution_id=self._node_execution_id, + value=value_seg, + visible=True, + editable=True, + ) + ) + has_non_sys_variables = True + else: + if name == SystemVariableKey.FILES: + # Here we know the type of variable must be `array[file]`, we + # just build files from the value. 
+ files = [File.model_validate(v) for v in value] + if files: + value_seg = WorkflowDraftVariable.build_segment_with_type(SegmentType.ARRAY_FILE, files) + else: + value_seg = ArrayFileSegment(value=[]) + + draft_vars.append( + WorkflowDraftVariable.new_sys_variable( + app_id=self._app_id, + name=name, + node_execution_id=self._node_execution_id, + value=value_seg, + editable=self._should_variable_be_editable(node_id, name), + ) + ) + if not has_non_sys_variables: + draft_vars.append(self._create_dummy_output_variable()) + return draft_vars + + def _normalize_variable_for_start_node(self, name: str) -> tuple[str, str]: + if not name.startswith(f"{SYSTEM_VARIABLE_NODE_ID}."): + return self._node_id, name + _, name_ = name.split(".", maxsplit=1) + return SYSTEM_VARIABLE_NODE_ID, name_ + + def _build_variables_from_mapping(self, output: Mapping[str, Any]) -> list[WorkflowDraftVariable]: + draft_vars = [] + for name, value in output.items(): + if not self._should_variable_be_saved(name): + _logger.debug( + "Skip saving variable as it has been excluded by its node_type, name=%s, node_type=%s", + name, + self._node_type, + ) + continue + if isinstance(value, Segment): + value_seg = value + else: + value_seg = _build_segment_for_serialized_values(value) + draft_vars.append( + WorkflowDraftVariable.new_node_variable( + app_id=self._app_id, + node_id=self._node_id, + name=name, + node_execution_id=self._node_execution_id, + value=value_seg, + visible=self._should_variable_be_visible(self._node_id, self._node_type, name), + ) + ) + return draft_vars + + def save( + self, + process_data: Mapping[str, Any] | None = None, + outputs: Mapping[str, Any] | None = None, + ): + draft_vars: list[WorkflowDraftVariable] = [] + if outputs is None: + outputs = {} + if process_data is None: + process_data = {} + if not self._should_save_output_variables_for_draft(): + return + if self._node_type == NodeType.VARIABLE_ASSIGNER: + draft_vars = self._build_from_variable_assigner_mapping(process_data=process_data) + elif self._node_type == NodeType.START: + draft_vars = self._build_variables_from_start_mapping(outputs) + else: + draft_vars = self._build_variables_from_mapping(outputs) + _batch_upsert_draft_variable(self._session, draft_vars) + + @staticmethod + def _should_variable_be_editable(node_id: str, name: str) -> bool: + if node_id in (CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID): + return False + if node_id == SYSTEM_VARIABLE_NODE_ID and not is_system_variable_editable(name): + return False + return True + + @staticmethod + def _should_variable_be_visible(node_id: str, node_type: NodeType, name: str) -> bool: + if node_type == NodeType.IF_ELSE: + return False + if node_id == SYSTEM_VARIABLE_NODE_ID and not is_system_variable_editable(name): + return False + return True + + def _should_variable_be_saved(self, name: str) -> bool: + exclude_var_names = self._EXCLUDE_VARIABLE_NAMES_MAPPING.get(self._node_type) + if exclude_var_names is None: + return True + return name not in exclude_var_names diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index bc213ccce..0fd94ac86 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,6 +1,7 @@ import json import time -from collections.abc import Callable, Generator, Sequence +import uuid +from collections.abc import Callable, Generator, Mapping, Sequence from datetime import UTC, datetime from typing import Any, Optional from uuid import uuid4 @@ -8,12 +9,17 @@ from uuid import uuid4 from
sqlalchemy import select from sqlalchemy.orm import Session +from core.app.app_config.entities import VariableEntityType from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.app.entities.app_invoke_entities import InvokeFrom +from core.file import File from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.variables import Variable from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus +from core.workflow.enums import SystemVariableKey from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes import NodeType @@ -22,9 +28,11 @@ from core.workflow.nodes.enums import ErrorStrategy from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.event.types import NodeEvent from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING +from core.workflow.nodes.start.entities import StartNodeData from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db +from factories.file_factory import build_from_mapping, build_from_mappings from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider @@ -34,10 +42,15 @@ from models.workflow import ( WorkflowNodeExecutionTriggeredFrom, WorkflowType, ) -from services.errors.app import WorkflowHashNotEqualError +from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError from services.workflow.workflow_converter import WorkflowConverter from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError +from .workflow_draft_variable_service import ( + DraftVariableSaver, + DraftVarLoader, + WorkflowDraftVariableService, +) class WorkflowService: @@ -45,6 +58,33 @@ class WorkflowService: Workflow Service """ + def get_node_last_run(self, app_model: App, workflow: Workflow, node_id: str) -> WorkflowNodeExecutionModel | None: + # TODO(QuantumGhost): This query is not fully covered by index. 
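+ # A composite index along these lines would cover it (a sketch only; the
+ # index name and Alembic migration wiring are assumptions, not part of
+ # this patch):
+ #
+ #     op.create_index(
+ #         "workflow_node_execution_tenant_app_workflow_node_idx",
+ #         "workflow_node_executions",
+ #         ["tenant_id", "app_id", "workflow_id", "node_id", sa.text("created_at DESC")],
+ #     )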
+ criteria = ( + WorkflowNodeExecutionModel.tenant_id == app_model.tenant_id, + WorkflowNodeExecutionModel.app_id == app_model.id, + WorkflowNodeExecutionModel.workflow_id == workflow.id, + WorkflowNodeExecutionModel.node_id == node_id, + ) + node_exec = ( + db.session.query(WorkflowNodeExecutionModel) + .filter(*criteria) + .order_by(WorkflowNodeExecutionModel.created_at.desc()) + .first() + ) + return node_exec + + def is_workflow_exist(self, app_model: App) -> bool: + return ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.version == Workflow.VERSION_DRAFT, + ) + .count() + ) > 0 + def get_draft_workflow(self, app_model: App) -> Optional[Workflow]: """ Get draft workflow @@ -61,6 +101,23 @@ class WorkflowService: # return draft workflow return workflow + def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Optional[Workflow]: + # fetch published workflow by workflow_id + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.id == workflow_id, + ) + .first() + ) + if not workflow: + return None + if workflow.version == Workflow.VERSION_DRAFT: + raise IsDraftWorkflowError(f"Workflow is draft version, id={workflow_id}") + return workflow + def get_published_workflow(self, app_model: App) -> Optional[Workflow]: """ Get published workflow @@ -199,7 +256,7 @@ class WorkflowService: tenant_id=app_model.tenant_id, app_id=app_model.id, type=draft_workflow.type, - version=str(datetime.now(UTC).replace(tzinfo=None)), + version=Workflow.version_from_datetime(datetime.now(UTC).replace(tzinfo=None)), graph=draft_workflow.graph, features=draft_workflow.features, created_by=account.id, @@ -253,26 +310,85 @@ class WorkflowService: return default_config def run_draft_workflow_node( - self, app_model: App, node_id: str, user_inputs: dict, account: Account + self, + app_model: App, + draft_workflow: Workflow, + node_id: str, + user_inputs: Mapping[str, Any], + account: Account, + query: str = "", + files: Sequence[File] | None = None, ) -> WorkflowNodeExecutionModel: """ Run draft workflow node """ - # fetch draft workflow by app_model - draft_workflow = self.get_draft_workflow(app_model=app_model) - if not draft_workflow: - raise ValueError("Workflow not initialized") + files = files or [] + + with Session(bind=db.engine, expire_on_commit=False) as session, session.begin(): + draft_var_srv = WorkflowDraftVariableService(session) + draft_var_srv.prefill_conversation_variable_default_values(draft_workflow) + + node_config = draft_workflow.get_node_config_by_id(node_id) + node_type = Workflow.get_node_type_from_node_config(node_config) + node_data = node_config.get("data", {}) + if node_type == NodeType.START: + with Session(bind=db.engine) as session, session.begin(): + draft_var_srv = WorkflowDraftVariableService(session) + conversation_id = draft_var_srv.get_or_create_conversation( + account_id=account.id, + app=app_model, + workflow=draft_workflow, + ) + start_data = StartNodeData.model_validate(node_data) + user_inputs = _rebuild_file_for_user_inputs_in_start_node( + tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs + ) + # init variable pool + variable_pool = _setup_variable_pool( + query=query, + files=files or [], + user_id=account.id, + user_inputs=user_inputs, + workflow=draft_workflow, + # NOTE(QuantumGhost): We rely on `DraftVarLoader` to load conversation 
variables. + conversation_variables=[], + node_type=node_type, + conversation_id=conversation_id, + ) + + else: + variable_pool = VariablePool( + system_variables={}, + user_inputs=user_inputs, + environment_variables=draft_workflow.environment_variables, + conversation_variables=[], + ) + + variable_loader = DraftVarLoader( + engine=db.engine, + app_id=app_model.id, + tenant_id=app_model.tenant_id, + ) + + enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config) + if enclosing_node_type_and_id: + _, enclosing_node_id = enclosing_node_type_and_id + else: + enclosing_node_id = None + + run = WorkflowEntry.single_step_run( + workflow=draft_workflow, + node_id=node_id, + user_inputs=user_inputs, + user_id=account.id, + variable_pool=variable_pool, + variable_loader=variable_loader, + ) # run draft workflow node start_at = time.perf_counter() - node_execution = self._handle_node_run_result( - invoke_node_fn=lambda: WorkflowEntry.single_step_run( - workflow=draft_workflow, - node_id=node_id, - user_inputs=user_inputs, - user_id=account.id, - ), + invoke_node_fn=lambda: run, start_at=start_at, node_id=node_id, ) @@ -292,6 +408,18 @@ class WorkflowService: # Convert node_execution to WorkflowNodeExecution after save workflow_node_execution = repository.to_db_model(node_execution) + with Session(bind=db.engine) as session, session.begin(): + draft_var_saver = DraftVariableSaver( + session=session, + app_id=app_model.id, + node_id=workflow_node_execution.node_id, + node_type=NodeType(workflow_node_execution.node_type), + invoke_from=InvokeFrom.DEBUGGER, + enclosing_node_id=enclosing_node_id, + node_execution_id=node_execution.id, + ) + draft_var_saver.save(process_data=node_execution.process_data, outputs=node_execution.outputs) + session.commit() return workflow_node_execution def run_free_workflow_node( @@ -332,7 +460,7 @@ class WorkflowService: node_run_result = event.run_result # sign output files - node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) + # node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) break if not node_run_result: @@ -394,7 +522,7 @@ class WorkflowService: if node_run_result.process_data else None ) - outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None + outputs = node_run_result.outputs node_execution.inputs = inputs node_execution.process_data = process_data @@ -531,3 +659,83 @@ class WorkflowService: session.delete(workflow) return True + + +def _setup_variable_pool( + query: str, + files: Sequence[File], + user_id: str, + user_inputs: Mapping[str, Any], + workflow: Workflow, + node_type: NodeType, + conversation_id: str, + conversation_variables: list[Variable], +): + # Only inject system variables for START node type. + if node_type == NodeType.START: + # Create a variable pool. + system_inputs: dict[SystemVariableKey, Any] = { + # From inputs: + SystemVariableKey.FILES: files, + SystemVariableKey.USER_ID: user_id, + # From workflow model + SystemVariableKey.APP_ID: workflow.app_id, + SystemVariableKey.WORKFLOW_ID: workflow.id, + # Randomly generated.
+ SystemVariableKey.WORKFLOW_EXECUTION_ID: str(uuid.uuid4()), + } + + # Only add chatflow-specific variables for non-workflow types + if workflow.type != WorkflowType.WORKFLOW.value: + system_inputs.update( + { + SystemVariableKey.QUERY: query, + SystemVariableKey.CONVERSATION_ID: conversation_id, + SystemVariableKey.DIALOGUE_COUNT: 0, + } + ) + else: + system_inputs = {} + + # init variable pool + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=user_inputs, + environment_variables=workflow.environment_variables, + conversation_variables=conversation_variables, + ) + + return variable_pool + + +def _rebuild_file_for_user_inputs_in_start_node( + tenant_id: str, start_node_data: StartNodeData, user_inputs: Mapping[str, Any] +) -> Mapping[str, Any]: + inputs_copy = dict(user_inputs) + + for variable in start_node_data.variables: + if variable.type not in (VariableEntityType.FILE, VariableEntityType.FILE_LIST): + continue + if variable.variable not in user_inputs: + continue + value = user_inputs[variable.variable] + file = _rebuild_single_file(tenant_id=tenant_id, value=value, variable_entity_type=variable.type) + inputs_copy[variable.variable] = file + return inputs_copy + + +def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: VariableEntityType) -> File | Sequence[File]: + if variable_entity_type == VariableEntityType.FILE: + if not isinstance(value, dict): + raise ValueError(f"expected dict for file object, got {type(value)}") + return build_from_mapping(mapping=value, tenant_id=tenant_id) + elif variable_entity_type == VariableEntityType.FILE_LIST: + if not isinstance(value, list): + raise ValueError(f"expected list for file list object, got {type(value)}") + if len(value) == 0: + return [] + if not isinstance(value[0], dict): + raise ValueError(f"expected dict for first element in the file list, got {type(value)}") + return build_from_mappings(mappings=value, tenant_id=tenant_id) + else: + raise Exception("unreachable") diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index 9e40a8494..4046096c2 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -1,107 +1,217 @@ -# OpenAI API Key -OPENAI_API_KEY= +FLASK_APP=app.py +FLASK_DEBUG=0 +SECRET_KEY='uhySf6a3aZuvRNfAlcr47paOw9TRYBY6j8ZHXpVw1yx5RP27Yj3w2uvI' -# Azure OpenAI API Base Endpoint & API Key -AZURE_OPENAI_API_BASE= -AZURE_OPENAI_API_KEY= +CONSOLE_API_URL=http://127.0.0.1:5001 +CONSOLE_WEB_URL=http://127.0.0.1:3000 -# Anthropic API Key -ANTHROPIC_API_KEY= +# Service API base URL +SERVICE_API_URL=http://127.0.0.1:5001 -# Replicate API Key -REPLICATE_API_KEY= +# Web APP base URL +APP_WEB_URL=http://127.0.0.1:3000 -# Hugging Face API Key -HUGGINGFACE_API_KEY= -HUGGINGFACE_TEXT_GEN_ENDPOINT_URL= -HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL= -HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL= +# Files URL +FILES_URL=http://127.0.0.1:5001 -# Minimax Credentials -MINIMAX_API_KEY= -MINIMAX_GROUP_ID= +# The time in seconds after the signature is rejected +FILES_ACCESS_TIMEOUT=300 -# Spark Credentials -SPARK_APP_ID= -SPARK_API_KEY= -SPARK_API_SECRET= +# Access token expiration time in minutes +ACCESS_TOKEN_EXPIRE_MINUTES=60 -# Tongyi Credentials -TONGYI_DASHSCOPE_API_KEY= +# Refresh token expiration time in days +REFRESH_TOKEN_EXPIRE_DAYS=30 -# Wenxin Credentials -WENXIN_API_KEY= -WENXIN_SECRET_KEY= +# celery configuration +CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1 -# ZhipuAI Credentials 
-ZHIPUAI_API_KEY= +# redis configuration +REDIS_HOST=localhost +REDIS_PORT=6379 +REDIS_USERNAME= +REDIS_PASSWORD=difyai123456 +REDIS_USE_SSL=false +REDIS_DB=0 -# Baichuan Credentials -BAICHUAN_API_KEY= -BAICHUAN_SECRET_KEY= +# PostgreSQL database configuration +DB_USERNAME=postgres +DB_PASSWORD=difyai123456 +DB_HOST=localhost +DB_PORT=5432 +DB_DATABASE=dify -# ChatGLM Credentials -CHATGLM_API_BASE= +# Storage configuration +# use for store upload files, private keys... +# storage type: opendal, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase +STORAGE_TYPE=opendal -# Xinference Credentials -XINFERENCE_SERVER_URL= -XINFERENCE_GENERATION_MODEL_UID= -XINFERENCE_CHAT_MODEL_UID= -XINFERENCE_EMBEDDINGS_MODEL_UID= -XINFERENCE_RERANK_MODEL_UID= +# Apache OpenDAL storage configuration, refer to https://github.com/apache/opendal +OPENDAL_SCHEME=fs +OPENDAL_FS_ROOT=storage -# OpenLLM Credentials -OPENLLM_SERVER_URL= +# CORS configuration +WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* +CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* -# LocalAI Credentials -LOCALAI_SERVER_URL= +# Vector database configuration +# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase +VECTOR_STORE=weaviate +# Weaviate configuration +WEAVIATE_ENDPOINT=http://localhost:8080 +WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih +WEAVIATE_GRPC_ENABLED=false +WEAVIATE_BATCH_SIZE=100 -# Cohere Credentials -COHERE_API_KEY= -# Jina Credentials -JINA_API_KEY= +# Upload configuration +UPLOAD_FILE_SIZE_LIMIT=15 +UPLOAD_FILE_BATCH_LIMIT=5 +UPLOAD_IMAGE_FILE_SIZE_LIMIT=10 +UPLOAD_VIDEO_FILE_SIZE_LIMIT=100 +UPLOAD_AUDIO_FILE_SIZE_LIMIT=50 -# Ollama Credentials -OLLAMA_BASE_URL= +# Model configuration +MULTIMODAL_SEND_FORMAT=base64 +PROMPT_GENERATION_MAX_TOKENS=4096 +CODE_GENERATION_MAX_TOKENS=1024 -# Together API Key -TOGETHER_API_KEY= +# Mail configuration, support: resend, smtp +MAIL_TYPE= +MAIL_DEFAULT_SEND_FROM=no-reply +RESEND_API_KEY= +RESEND_API_URL=https://api.resend.com +# smtp configuration +SMTP_SERVER=smtp.example.com +SMTP_PORT=465 +SMTP_USERNAME=123 +SMTP_PASSWORD=abc +SMTP_USE_TLS=true +SMTP_OPPORTUNISTIC_TLS=false -# Mock Switch -MOCK_SWITCH=false +# Sentry configuration +SENTRY_DSN= + +# DEBUG +DEBUG=false +SQLALCHEMY_ECHO=false + +# Notion import configuration, support public and internal +NOTION_INTEGRATION_TYPE=public +NOTION_CLIENT_SECRET=you-client-secret +NOTION_CLIENT_ID=you-client-id +NOTION_INTERNAL_SECRET=you-internal-secret + +ETL_TYPE=dify +UNSTRUCTURED_API_URL= +UNSTRUCTURED_API_KEY= +SCARF_NO_ANALYTICS=false + +#ssrf +SSRF_PROXY_HTTP_URL= +SSRF_PROXY_HTTPS_URL= +SSRF_DEFAULT_MAX_RETRIES=3 +SSRF_DEFAULT_TIME_OUT=5 +SSRF_DEFAULT_CONNECT_TIME_OUT=5 +SSRF_DEFAULT_READ_TIME_OUT=5 +SSRF_DEFAULT_WRITE_TIME_OUT=5 + +BATCH_UPLOAD_LIMIT=10 +KEYWORD_DATA_SOURCE_TYPE=database + +# Workflow file upload limit +WORKFLOW_FILE_UPLOAD_LIMIT=10 # CODE EXECUTION CONFIGURATION -CODE_EXECUTION_ENDPOINT= -CODE_EXECUTION_API_KEY= +CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194 +CODE_EXECUTION_API_KEY=dify-sandbox +CODE_MAX_NUMBER=9223372036854775807 +CODE_MIN_NUMBER=-9223372036854775808 +CODE_MAX_STRING_LENGTH=80000 +TEMPLATE_TRANSFORM_MAX_LENGTH=80000 +CODE_MAX_STRING_ARRAY_LENGTH=30 +CODE_MAX_OBJECT_ARRAY_LENGTH=30 +CODE_MAX_NUMBER_ARRAY_LENGTH=1000 -# Volcengine MaaS Credentials -VOLC_API_KEY= -VOLC_SECRET_KEY= -VOLC_MODEL_ENDPOINT_ID= 
-VOLC_EMBEDDING_ENDPOINT_ID= +# API Tool configuration +API_TOOL_DEFAULT_CONNECT_TIMEOUT=10 +API_TOOL_DEFAULT_READ_TIMEOUT=60 -# 360 AI Credentials -ZHINAO_API_KEY= +# HTTP Node configuration +HTTP_REQUEST_MAX_CONNECT_TIMEOUT=300 +HTTP_REQUEST_MAX_READ_TIMEOUT=600 +HTTP_REQUEST_MAX_WRITE_TIMEOUT=600 +HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 +HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 + +# Respect X-* headers to redirect clients +RESPECT_XFORWARD_HEADERS_ENABLED=false + +# Log file path +LOG_FILE= +# Log file max size, the unit is MB +LOG_FILE_MAX_SIZE=20 +# Log file max backup count +LOG_FILE_BACKUP_COUNT=5 +# Log dateformat +LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S +# Log Timezone +LOG_TZ=UTC +# Log format +LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s + +# Indexing configuration +INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000 + +# Workflow runtime configuration +WORKFLOW_MAX_EXECUTION_STEPS=500 +WORKFLOW_MAX_EXECUTION_TIME=1200 +WORKFLOW_CALL_MAX_DEPTH=5 +WORKFLOW_PARALLEL_DEPTH_LIMIT=3 +MAX_VARIABLE_SIZE=204800 + +# App configuration +APP_MAX_EXECUTION_TIME=1200 +APP_MAX_ACTIVE_REQUESTS=0 + +# Celery beat configuration +CELERY_BEAT_SCHEDULER_TIME=1 + +# Position configuration +POSITION_TOOL_PINS= +POSITION_TOOL_INCLUDES= +POSITION_TOOL_EXCLUDES= + +POSITION_PROVIDER_PINS= +POSITION_PROVIDER_INCLUDES= +POSITION_PROVIDER_EXCLUDES= # Plugin configuration -PLUGIN_DAEMON_KEY= -PLUGIN_DAEMON_URL= +PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi +PLUGIN_DAEMON_URL=http://127.0.0.1:5002 +PLUGIN_REMOTE_INSTALL_PORT=5003 +PLUGIN_REMOTE_INSTALL_HOST=localhost +PLUGIN_MAX_PACKAGE_SIZE=15728640 +INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1 # Marketplace configuration -MARKETPLACE_API_URL= -# VESSL AI Credentials -VESSL_AI_MODEL_NAME= -VESSL_AI_API_KEY= -VESSL_AI_ENDPOINT_URL= +MARKETPLACE_ENABLED=true +MARKETPLACE_API_URL=https://marketplace.dify.ai -# GPUStack Credentials -GPUSTACK_SERVER_URL= -GPUSTACK_API_KEY= +# Endpoint configuration +ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id} -# Gitee AI Credentials -GITEE_AI_API_KEY= +# Reset password token expiry minutes +RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 -# xAI Credentials -XAI_API_KEY= -XAI_API_BASE= +CREATE_TIDB_SERVICE_JOB_ENABLED=false + +# Maximum number of submitted thread count in a ThreadPool for parallel node execution +MAX_SUBMIT_COUNT=100 +# Lockout duration in seconds +LOGIN_LOCKOUT_DURATION=86400 + +HTTP_PROXY='http://127.0.0.1:1092' +HTTPS_PROXY='http://127.0.0.1:1092' +NO_PROXY='localhost,127.0.0.1' +LOG_LEVEL=INFO diff --git a/api/tests/integration_tests/conftest.py b/api/tests/integration_tests/conftest.py index 6e3ab4b74..d9f90f992 100644 --- a/api/tests/integration_tests/conftest.py +++ b/api/tests/integration_tests/conftest.py @@ -1,19 +1,91 @@ -import os +import pathlib +import random +import secrets +from collections.abc import Generator -# Getting the absolute path of the current file's directory -ABS_PATH = os.path.dirname(os.path.abspath(__file__)) +import pytest +from flask import Flask +from flask.testing import FlaskClient +from sqlalchemy.orm import Session -# Getting the absolute path of the project's root directory -PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir)) +from app_factory import create_app +from models import Account, DifySetup, Tenant, TenantAccountJoin, db +from services.account_service import AccountService, RegisterService # Loading the .env file if it exists def 
_load_env() -> None: - dotenv_path = os.path.join(PROJECT_DIR, "tests", "integration_tests", ".env") - if os.path.exists(dotenv_path): + current_file_path = pathlib.Path(__file__).absolute() + # Items later in the list have higher precedence. + files_to_load = [".env", "vdb.env"] + + env_file_paths = [current_file_path.parent / i for i in files_to_load] + for path in env_file_paths: + if not path.exists(): + continue + from dotenv import load_dotenv - load_dotenv(dotenv_path) + # Set `override=True` to ensure values from `vdb.env` take priority over values from `.env` + load_dotenv(str(path), override=True) _load_env() + +_CACHED_APP = create_app() + + +@pytest.fixture +def flask_app() -> Flask: + return _CACHED_APP + + +@pytest.fixture(scope="session") +def setup_account(request) -> Generator[Account, None, None]: + """`setup_account` completes the setup process for the Dify application. + + It creates `Account` and `Tenant`, and inserts a `DifySetup` record into the database. + + Most tests in the `controllers` package require that Dify has been successfully set up. + """ + with _CACHED_APP.test_request_context(): + rand_suffix = random.randint(int(1e6), int(1e7)) # noqa + name = f"test-user-{rand_suffix}" + email = f"{name}@example.com" + RegisterService.setup( + email=email, + name=name, + password=secrets.token_hex(16), + ip_address="localhost", + ) + + with _CACHED_APP.test_request_context(): + with Session(bind=db.engine, expire_on_commit=False) as session: + account = session.query(Account).filter_by(email=email).one() + + yield account + + with _CACHED_APP.test_request_context(): + db.session.query(DifySetup).delete() + db.session.query(TenantAccountJoin).delete() + db.session.query(Account).delete() + db.session.query(Tenant).delete() + db.session.commit() + + +@pytest.fixture +def flask_req_ctx(): + with _CACHED_APP.test_request_context(): + yield + + +@pytest.fixture +def auth_header(setup_account) -> dict[str, str]: + token = AccountService.get_account_jwt_token(setup_account) + return {"Authorization": f"Bearer {token}"} + + +@pytest.fixture +def test_client() -> Generator[FlaskClient, None, None]: + with _CACHED_APP.test_client() as client: + yield client diff --git a/api/tests/integration_tests/controllers/app_fixture.py b/api/tests/integration_tests/controllers/app_fixture.py deleted file mode 100644 index 32e8c11d1..000000000 --- a/api/tests/integration_tests/controllers/app_fixture.py +++ /dev/null @@ -1,25 +0,0 @@ -import pytest - -from app_factory import create_app -from configs import dify_config - -mock_user = type( - "MockUser", - (object,), - { - "is_authenticated": True, - "id": "123", - "is_editor": True, - "is_dataset_editor": True, - "status": "active", - "get_id": "123", - "current_tenant_id": "9d2074fc-6f86-45a9-b09d-6ecc63b9056b", - }, -) - - -@pytest.fixture -def app(): - app = create_app() - dify_config.LOGIN_DISABLED = True - return app diff --git a/api/tests/integration_tests/controllers/console/__init__.py b/api/tests/integration_tests/controllers/console/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/tests/integration_tests/controllers/console/app/__init__.py b/api/tests/integration_tests/controllers/console/app/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/tests/integration_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/integration_tests/controllers/console/app/test_workflow_draft_variable.py new file mode 100644 index 000000000..038f37af5 --- /dev/null +++
b/api/tests/integration_tests/controllers/console/app/test_workflow_draft_variable.py @@ -0,0 +1,47 @@ +import uuid +from unittest import mock + +from controllers.console.app import workflow_draft_variable as draft_variable_api +from controllers.console.app import wraps +from factories.variable_factory import build_segment +from models import App, AppMode +from models.workflow import WorkflowDraftVariable +from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService + + +def _get_mock_srv_class() -> type[WorkflowDraftVariableService]: + return mock.create_autospec(WorkflowDraftVariableService) + + +class TestWorkflowDraftNodeVariableListApi: + def test_get(self, test_client, auth_header, monkeypatch): + srv_class = _get_mock_srv_class() + mock_app_model: App = App() + mock_app_model.id = str(uuid.uuid4()) + test_node_id = "test_node_id" + mock_app_model.mode = AppMode.ADVANCED_CHAT + mock_load_app_model = mock.Mock(return_value=mock_app_model) + + monkeypatch.setattr(draft_variable_api, "WorkflowDraftVariableService", srv_class) + monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model) + + var1 = WorkflowDraftVariable.new_node_variable( + app_id="test_app_1", + node_id="test_node_1", + name="str_var", + value=build_segment("str_value"), + node_execution_id=str(uuid.uuid4()), + ) + srv_instance = mock.create_autospec(WorkflowDraftVariableService, instance=True) + srv_class.return_value = srv_instance + srv_instance.list_node_variables.return_value = WorkflowDraftVariableList(variables=[var1]) + + response = test_client.get( + f"/console/api/apps/{mock_app_model.id}/workflows/draft/nodes/{test_node_id}/variables", + headers=auth_header, + ) + assert response.status_code == 200 + response_dict = response.json + assert isinstance(response_dict, dict) + assert "items" in response_dict + assert len(response_dict["items"]) == 1 diff --git a/api/tests/integration_tests/controllers/test_controllers.py b/api/tests/integration_tests/controllers/test_controllers.py deleted file mode 100644 index 276ad3a7e..000000000 --- a/api/tests/integration_tests/controllers/test_controllers.py +++ /dev/null @@ -1,9 +0,0 @@ -from unittest.mock import patch - -from app_fixture import mock_user # type: ignore - - -def test_post_requires_login(app): - with app.test_client() as client, patch("flask_login.utils._get_user", mock_user): - response = client.get("/console/api/data-source/integrates") - assert response.status_code == 200 diff --git a/api/tests/integration_tests/factories/__init__.py b/api/tests/integration_tests/factories/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py new file mode 100644 index 000000000..fecb3f6d9 --- /dev/null +++ b/api/tests/integration_tests/factories/test_storage_key_loader.py @@ -0,0 +1,371 @@ +import unittest +from datetime import UTC, datetime +from typing import Optional +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session + +from core.file import File, FileTransferMethod, FileType +from extensions.ext_database import db +from factories.file_factory import StorageKeyLoader +from models import ToolFile, UploadFile +from models.enums import CreatorUserRole + + +@pytest.mark.usefixtures("flask_req_ctx") +class TestStorageKeyLoader(unittest.TestCase): + """ + Integration tests for StorageKeyLoader class. 
+ + Tests the batched loading of storage keys from the database for files + with different transfer methods: LOCAL_FILE, REMOTE_URL, and TOOL_FILE. + """ + + def setUp(self): + """Set up test data before each test method.""" + self.session = db.session() + self.tenant_id = str(uuid4()) + self.user_id = str(uuid4()) + self.conversation_id = str(uuid4()) + + # Create test data that will be cleaned up after each test + self.test_upload_files = [] + self.test_tool_files = [] + + # Create StorageKeyLoader instance + self.loader = StorageKeyLoader(self.session, self.tenant_id) + + def tearDown(self): + """Clean up test data after each test method.""" + self.session.rollback() + + def _create_upload_file( + self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None + ) -> UploadFile: + """Helper method to create an UploadFile record for testing.""" + if file_id is None: + file_id = str(uuid4()) + if storage_key is None: + storage_key = f"test_storage_key_{uuid4()}" + if tenant_id is None: + tenant_id = self.tenant_id + + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type="local", + key=storage_key, + name="test_file.txt", + size=1024, + extension=".txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=self.user_id, + created_at=datetime.now(UTC), + used=False, + ) + upload_file.id = file_id + + self.session.add(upload_file) + self.session.flush() + self.test_upload_files.append(upload_file) + + return upload_file + + def _create_tool_file( + self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None + ) -> ToolFile: + """Helper method to create a ToolFile record for testing.""" + if file_id is None: + file_id = str(uuid4()) + if file_key is None: + file_key = f"test_file_key_{uuid4()}" + if tenant_id is None: + tenant_id = self.tenant_id + + tool_file = ToolFile() + tool_file.id = file_id + tool_file.user_id = self.user_id + tool_file.tenant_id = tenant_id + tool_file.conversation_id = self.conversation_id + tool_file.file_key = file_key + tool_file.mimetype = "text/plain" + tool_file.original_url = "http://example.com/file.txt" + tool_file.name = "test_tool_file.txt" + tool_file.size = 2048 + + self.session.add(tool_file) + self.session.flush() + self.test_tool_files.append(tool_file) + + return tool_file + + def _create_file( + self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None + ) -> File: + """Helper method to create a File object for testing.""" + if tenant_id is None: + tenant_id = self.tenant_id + + # Set related_id for LOCAL_FILE and TOOL_FILE transfer methods + file_related_id = None + remote_url = None + + if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE): + file_related_id = related_id + elif transfer_method == FileTransferMethod.REMOTE_URL: + remote_url = "https://example.com/test_file.txt" + file_related_id = related_id + + return File( + id=str(uuid4()), # Generate new UUID for File.id + tenant_id=tenant_id, + type=FileType.DOCUMENT, + transfer_method=transfer_method, + related_id=file_related_id, + remote_url=remote_url, + filename="test_file.txt", + extension=".txt", + mime_type="text/plain", + size=1024, + storage_key="initial_key", + ) + + def test_load_storage_keys_local_file(self): + """Test loading storage keys for LOCAL_FILE transfer method.""" + # Create test data + upload_file = self._create_upload_file() + file = self._create_file(related_id=upload_file.id, 
transfer_method=FileTransferMethod.LOCAL_FILE) + + # Load storage keys + self.loader.load_storage_keys([file]) + + # Verify storage key was loaded correctly + assert file._storage_key == upload_file.key + + def test_load_storage_keys_remote_url(self): + """Test loading storage keys for REMOTE_URL transfer method.""" + # Create test data + upload_file = self._create_upload_file() + file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL) + + # Load storage keys + self.loader.load_storage_keys([file]) + + # Verify storage key was loaded correctly + assert file._storage_key == upload_file.key + + def test_load_storage_keys_tool_file(self): + """Test loading storage keys for TOOL_FILE transfer method.""" + # Create test data + tool_file = self._create_tool_file() + file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) + + # Load storage keys + self.loader.load_storage_keys([file]) + + # Verify storage key was loaded correctly + assert file._storage_key == tool_file.file_key + + def test_load_storage_keys_mixed_methods(self): + """Test batch loading with mixed transfer methods.""" + # Create test data for different transfer methods + upload_file1 = self._create_upload_file() + upload_file2 = self._create_upload_file() + tool_file = self._create_tool_file() + + file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE) + file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL) + file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) + + files = [file1, file2, file3] + + # Load storage keys + self.loader.load_storage_keys(files) + + # Verify all storage keys were loaded correctly + assert file1._storage_key == upload_file1.key + assert file2._storage_key == upload_file2.key + assert file3._storage_key == tool_file.file_key + + def test_load_storage_keys_empty_list(self): + """Test with empty file list.""" + # Should not raise any exceptions + self.loader.load_storage_keys([]) + + def test_load_storage_keys_tenant_mismatch(self): + """Test tenant_id validation.""" + # Create file with different tenant_id + upload_file = self._create_upload_file() + file = self._create_file( + related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()) + ) + + # Should raise ValueError for tenant mismatch + with pytest.raises(ValueError) as context: + self.loader.load_storage_keys([file]) + + assert "invalid file, expected tenant_id" in str(context.value) + + def test_load_storage_keys_missing_file_id(self): + """Test with None file.related_id.""" + # Create a file with valid parameters first, then manually set related_id to None + file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE) + file.related_id = None + + # Should raise ValueError for None file related_id + with pytest.raises(ValueError) as context: + self.loader.load_storage_keys([file]) + + assert str(context.value) == "file id should not be None." 
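+ # The query-count assertions below assume load_storage_keys batches its
+ # lookups by transfer method: one IN-clause query for UploadFile rows and
+ # one for ToolFile rows, roughly like this sketch (the real implementation
+ # lives in factories/file_factory.py):
+ #
+ #     upload_ids = [f.related_id for f in files if f.transfer_method
+ #                   in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL)]
+ #     tool_ids = [f.related_id for f in files
+ #                 if f.transfer_method == FileTransferMethod.TOOL_FILE]
+ #     upload_rows = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_ids)))
+ #     tool_rows = session.scalars(select(ToolFile).where(ToolFile.id.in_(tool_ids)))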
+ + def test_load_storage_keys_nonexistent_upload_file_records(self): + """Test with missing UploadFile database records.""" + # Create file with non-existent upload file id + non_existent_id = str(uuid4()) + file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE) + + # Should raise ValueError for missing record + with pytest.raises(ValueError): + self.loader.load_storage_keys([file]) + + def test_load_storage_keys_nonexistent_tool_file_records(self): + """Test with missing ToolFile database records.""" + # Create file with non-existent tool file id + non_existent_id = str(uuid4()) + file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE) + + # Should raise ValueError for missing record + with pytest.raises(ValueError): + self.loader.load_storage_keys([file]) + + def test_load_storage_keys_invalid_uuid(self): + """Test with invalid UUID format.""" + # Create a file with valid parameters first, then manually set invalid related_id + file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE) + file.related_id = "invalid-uuid-format" + + # Should raise ValueError for invalid UUID + with pytest.raises(ValueError): + self.loader.load_storage_keys([file]) + + def test_load_storage_keys_batch_efficiency(self): + """Test batched operations use efficient queries.""" + # Create multiple files of different types + upload_files = [self._create_upload_file() for _ in range(3)] + tool_files = [self._create_tool_file() for _ in range(2)] + + files = [] + files.extend( + [self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files] + ) + files.extend( + [self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files] + ) + + # Mock the session to count queries + with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars: + self.loader.load_storage_keys(files) + + # Should make exactly 2 queries (one for upload_files, one for tool_files) + assert mock_scalars.call_count == 2 + + # Verify all storage keys were loaded correctly + for i, file in enumerate(files[:3]): + assert file._storage_key == upload_files[i].key + for i, file in enumerate(files[3:]): + assert file._storage_key == tool_files[i].file_key + + def test_load_storage_keys_tenant_isolation(self): + """Test that tenant isolation works correctly.""" + # Create files for different tenants + other_tenant_id = str(uuid4()) + + # Create upload file for current tenant + upload_file_current = self._create_upload_file() + file_current = self._create_file( + related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE + ) + + # Create upload file for other tenant (but don't add to cleanup list) + upload_file_other = UploadFile( + tenant_id=other_tenant_id, + storage_type="local", + key="other_tenant_key", + name="other_file.txt", + size=1024, + extension=".txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=self.user_id, + created_at=datetime.now(UTC), + used=False, + ) + upload_file_other.id = str(uuid4()) + self.session.add(upload_file_other) + self.session.flush() + + # Create file for other tenant but try to load with current tenant's loader + file_other = self._create_file( + related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id + ) + + # Should raise ValueError due to tenant mismatch + with 
pytest.raises(ValueError) as context: + self.loader.load_storage_keys([file_other]) + + assert "invalid file, expected tenant_id" in str(context.value) + + # Current tenant's file should still work + self.loader.load_storage_keys([file_current]) + assert file_current._storage_key == upload_file_current.key + + def test_load_storage_keys_mixed_tenant_batch(self): + """Test batch with mixed tenant files (should fail on first mismatch).""" + # Create files for current tenant + upload_file_current = self._create_upload_file() + file_current = self._create_file( + related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE + ) + + # Create file for different tenant + other_tenant_id = str(uuid4()) + file_other = self._create_file( + related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id + ) + + # Should raise ValueError on tenant mismatch + with pytest.raises(ValueError) as context: + self.loader.load_storage_keys([file_current, file_other]) + + assert "invalid file, expected tenant_id" in str(context.value) + + def test_load_storage_keys_duplicate_file_ids(self): + """Test handling of duplicate file IDs in the batch.""" + # Create upload file + upload_file = self._create_upload_file() + + # Create two File objects with same related_id + file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) + file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) + + # Should handle duplicates gracefully + self.loader.load_storage_keys([file1, file2]) + + # Both files should have the same storage key + assert file1._storage_key == upload_file.key + assert file2._storage_key == upload_file.key + + def test_load_storage_keys_session_isolation(self): + """Test that the loader uses the provided session correctly.""" + # Create test data + upload_file = self._create_upload_file() + file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) + + # Create loader with different session (same underlying connection) + + with Session(bind=db.engine) as other_session: + other_loader = StorageKeyLoader(other_session, self.tenant_id) + with pytest.raises(ValueError): + other_loader.load_storage_keys([file]) diff --git a/api/tests/integration_tests/services/__init__.py b/api/tests/integration_tests/services/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py new file mode 100644 index 000000000..30cd2e60c --- /dev/null +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -0,0 +1,501 @@ +import json +import unittest +import uuid + +import pytest +from sqlalchemy.orm import Session + +from core.variables.variables import StringVariable +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.nodes import NodeType +from factories.variable_factory import build_segment +from libs import datetime_utils +from models import db +from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel +from services.workflow_draft_variable_service import DraftVarLoader, VariableResetError, WorkflowDraftVariableService + + +@pytest.mark.usefixtures("flask_req_ctx") +class TestWorkflowDraftVariableService(unittest.TestCase): + _test_app_id: str + _session: Session + 
_node1_id = "test_node_1" + _node2_id = "test_node_2" + _node_exec_id = str(uuid.uuid4()) + + def setUp(self): + self._test_app_id = str(uuid.uuid4()) + self._session: Session = db.session() + sys_var = WorkflowDraftVariable.new_sys_variable( + app_id=self._test_app_id, + name="sys_var", + value=build_segment("sys_value"), + node_execution_id=self._node_exec_id, + ) + conv_var = WorkflowDraftVariable.new_conversation_variable( + app_id=self._test_app_id, + name="conv_var", + value=build_segment("conv_value"), + ) + node2_vars = [ + WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node2_id, + name="int_var", + value=build_segment(1), + visible=False, + node_execution_id=self._node_exec_id, + ), + WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node2_id, + name="str_var", + value=build_segment("str_value"), + visible=True, + node_execution_id=self._node_exec_id, + ), + ] + node1_var = WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node1_id, + name="str_var", + value=build_segment("str_value"), + visible=True, + node_execution_id=self._node_exec_id, + ) + _variables = list(node2_vars) + _variables.extend( + [ + node1_var, + sys_var, + conv_var, + ] + ) + + db.session.add_all(_variables) + db.session.flush() + self._variable_ids = [v.id for v in _variables] + self._node1_str_var_id = node1_var.id + self._sys_var_id = sys_var.id + self._conv_var_id = conv_var.id + self._node2_var_ids = [v.id for v in node2_vars] + + def _get_test_srv(self) -> WorkflowDraftVariableService: + return WorkflowDraftVariableService(session=self._session) + + def tearDown(self): + self._session.rollback() + + def test_list_variables(self): + srv = self._get_test_srv() + var_list = srv.list_variables_without_values(self._test_app_id, page=1, limit=2) + assert var_list.total == 5 + assert len(var_list.variables) == 2 + page1_var_ids = {v.id for v in var_list.variables} + assert page1_var_ids.issubset(self._variable_ids) + + var_list_2 = srv.list_variables_without_values(self._test_app_id, page=2, limit=2) + assert var_list_2.total is None + assert len(var_list_2.variables) == 2 + page2_var_ids = {v.id for v in var_list_2.variables} + assert page2_var_ids.isdisjoint(page1_var_ids) + assert page2_var_ids.issubset(self._variable_ids) + + def test_get_node_variable(self): + srv = self._get_test_srv() + node_var = srv.get_node_variable(self._test_app_id, self._node1_id, "str_var") + assert node_var is not None + assert node_var.id == self._node1_str_var_id + assert node_var.name == "str_var" + assert node_var.get_value() == build_segment("str_value") + + def test_get_system_variable(self): + srv = self._get_test_srv() + sys_var = srv.get_system_variable(self._test_app_id, "sys_var") + assert sys_var is not None + assert sys_var.id == self._sys_var_id + assert sys_var.name == "sys_var" + assert sys_var.get_value() == build_segment("sys_value") + + def test_get_conversation_variable(self): + srv = self._get_test_srv() + conv_var = srv.get_conversation_variable(self._test_app_id, "conv_var") + assert conv_var is not None + assert conv_var.id == self._conv_var_id + assert conv_var.name == "conv_var" + assert conv_var.get_value() == build_segment("conv_value") + + def test_delete_node_variables(self): + srv = self._get_test_srv() + srv.delete_node_variables(self._test_app_id, self._node2_id) + node2_var_count = ( + self._session.query(WorkflowDraftVariable) + .where( + WorkflowDraftVariable.app_id == self._test_app_id, + 
WorkflowDraftVariable.node_id == self._node2_id, + ) + .count() + ) + assert node2_var_count == 0 + + def test_delete_variable(self): + srv = self._get_test_srv() + node_1_var = ( + self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).one() + ) + srv.delete_variable(node_1_var) + exists = bool( + self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).first() + ) + assert exists is False + + def test__list_node_variables(self): + srv = self._get_test_srv() + node_vars = srv._list_node_variables(self._test_app_id, self._node2_id) + assert len(node_vars.variables) == 2 + assert {v.id for v in node_vars.variables} == set(self._node2_var_ids) + + def test_get_draft_variables_by_selectors(self): + srv = self._get_test_srv() + selectors = [ + [self._node1_id, "str_var"], + [self._node2_id, "str_var"], + [self._node2_id, "int_var"], + ] + variables = srv.get_draft_variables_by_selectors(self._test_app_id, selectors) + assert len(variables) == 3 + assert {v.id for v in variables} == {self._node1_str_var_id} | set(self._node2_var_ids) + + +@pytest.mark.usefixtures("flask_req_ctx") +class TestDraftVariableLoader(unittest.TestCase): + _test_app_id: str + _test_tenant_id: str + + _node1_id = "test_loader_node_1" + _node_exec_id = str(uuid.uuid4()) + + def setUp(self): + self._test_app_id = str(uuid.uuid4()) + self._test_tenant_id = str(uuid.uuid4()) + sys_var = WorkflowDraftVariable.new_sys_variable( + app_id=self._test_app_id, + name="sys_var", + value=build_segment("sys_value"), + node_execution_id=self._node_exec_id, + ) + conv_var = WorkflowDraftVariable.new_conversation_variable( + app_id=self._test_app_id, + name="conv_var", + value=build_segment("conv_value"), + ) + node_var = WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node1_id, + name="str_var", + value=build_segment("str_value"), + visible=True, + node_execution_id=self._node_exec_id, + ) + _variables = [ + node_var, + sys_var, + conv_var, + ] + + with Session(bind=db.engine, expire_on_commit=False) as session: + session.add_all(_variables) + session.flush() + session.commit() + self._variable_ids = [v.id for v in _variables] + self._node_var_id = node_var.id + self._sys_var_id = sys_var.id + self._conv_var_id = conv_var.id + + def tearDown(self): + with Session(bind=db.engine, expire_on_commit=False) as session: + session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.app_id == self._test_app_id).delete( + synchronize_session=False + ) + session.commit() + + def test_variable_loader_with_empty_selector(self): + var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id) + variables = var_loader.load_variables([]) + assert len(variables) == 0 + + def test_variable_loader_with_non_empty_selector(self): + var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id) + variables = var_loader.load_variables( + [ + [SYSTEM_VARIABLE_NODE_ID, "sys_var"], + [CONVERSATION_VARIABLE_NODE_ID, "conv_var"], + [self._node1_id, "str_var"], + ] + ) + assert len(variables) == 3 + conv_var = next(v for v in variables if v.selector[0] == CONVERSATION_VARIABLE_NODE_ID) + assert conv_var.id == self._conv_var_id + sys_var = next(v for v in variables if v.selector[0] == SYSTEM_VARIABLE_NODE_ID) + assert sys_var.id == self._sys_var_id + node1_var = next(v for v in variables if v.selector[0] == self._node1_id) + assert node1_var.id == 
self._node_var_id + + +@pytest.mark.usefixtures("flask_req_ctx") +class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase): + """Integration tests for reset_variable functionality using real database""" + + _test_app_id: str + _test_tenant_id: str + _test_workflow_id: str + _session: Session + _node_id = "test_reset_node" + _node_exec_id: str + _workflow_node_exec_id: str + + def setUp(self): + self._test_app_id = str(uuid.uuid4()) + self._test_tenant_id = str(uuid.uuid4()) + self._test_workflow_id = str(uuid.uuid4()) + self._node_exec_id = str(uuid.uuid4()) + self._workflow_node_exec_id = str(uuid.uuid4()) + self._session: Session = db.session() + + # Create a workflow node execution record with outputs + # Note: The WorkflowNodeExecutionModel.id should match the node_execution_id in WorkflowDraftVariable + self._workflow_node_execution = WorkflowNodeExecutionModel( + id=self._node_exec_id, # This should match the node_execution_id in the variable + tenant_id=self._test_tenant_id, + app_id=self._test_app_id, + workflow_id=self._test_workflow_id, + triggered_from="workflow-run", + workflow_run_id=str(uuid.uuid4()), + index=1, + node_execution_id=self._node_exec_id, + node_id=self._node_id, + node_type=NodeType.LLM.value, + title="Test Node", + inputs='{"input": "test input"}', + process_data='{"test_var": "process_value", "other_var": "other_process"}', + outputs='{"test_var": "output_value", "other_var": "other_output"}', + status="succeeded", + elapsed_time=1.5, + created_by_role="account", + created_by=str(uuid.uuid4()), + ) + + # Create conversation variables for the workflow + self._conv_variables = [ + StringVariable( + id=str(uuid.uuid4()), + name="conv_var_1", + description="Test conversation variable 1", + value="default_value_1", + ), + StringVariable( + id=str(uuid.uuid4()), + name="conv_var_2", + description="Test conversation variable 2", + value="default_value_2", + ), + ] + + # Create test variables + self._node_var_with_exec = WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node_id, + name="test_var", + value=build_segment("old_value"), + node_execution_id=self._node_exec_id, + ) + self._node_var_with_exec.last_edited_at = datetime_utils.naive_utc_now() + + self._node_var_without_exec = WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node_id, + name="no_exec_var", + value=build_segment("some_value"), + node_execution_id="temp_exec_id", + ) + # Manually set node_execution_id to None after creation + self._node_var_without_exec.node_execution_id = None + + self._node_var_missing_exec = WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node_id, + name="missing_exec_var", + value=build_segment("some_value"), + node_execution_id=str(uuid.uuid4()), # Use a valid UUID that doesn't exist in database + ) + + self._conv_var = WorkflowDraftVariable.new_conversation_variable( + app_id=self._test_app_id, + name="conv_var_1", + value=build_segment("old_conv_value"), + ) + self._conv_var.last_edited_at = datetime_utils.naive_utc_now() + + # Add all to database + db.session.add_all( + [ + self._workflow_node_execution, + self._node_var_with_exec, + self._node_var_without_exec, + self._node_var_missing_exec, + self._conv_var, + ] + ) + db.session.flush() + + # Store IDs for assertions + self._node_var_with_exec_id = self._node_var_with_exec.id + self._node_var_without_exec_id = self._node_var_without_exec.id + self._node_var_missing_exec_id = 
self._node_var_missing_exec.id + self._conv_var_id = self._conv_var.id + + def _get_test_srv(self) -> WorkflowDraftVariableService: + return WorkflowDraftVariableService(session=self._session) + + def _create_mock_workflow(self) -> Workflow: + """Create a real workflow with conversation variables and graph""" + conversation_vars = self._conv_variables + + # Create a simple graph with the test node + graph = { + "nodes": [{"id": "test_reset_node", "type": "llm", "title": "Test Node", "data": {"type": "llm"}}], + "edges": [], + } + + workflow = Workflow.new( + tenant_id=str(uuid.uuid4()), + app_id=self._test_app_id, + type="workflow", + version="1.0", + graph=json.dumps(graph), + features="{}", + created_by=str(uuid.uuid4()), + environment_variables=[], + conversation_variables=conversation_vars, + ) + return workflow + + def tearDown(self): + self._session.rollback() + + def test_reset_node_variable_with_valid_execution_record(self): + """Test resetting a node variable with valid execution record - should restore from execution""" + srv = self._get_test_srv() + mock_workflow = self._create_mock_workflow() + + # Get the variable before reset + variable = srv.get_variable(self._node_var_with_exec_id) + assert variable is not None + assert variable.get_value().value == "old_value" + assert variable.last_edited_at is not None + + # Reset the variable + result = srv.reset_variable(mock_workflow, variable) + + # Should return the updated variable + assert result is not None + assert result.id == self._node_var_with_exec_id + assert result.node_execution_id == self._workflow_node_execution.id + assert result.last_edited_at is None # Should be reset to None + + # The returned variable should have the updated value from execution record + assert result.get_value().value == "output_value" + + # Verify the variable was updated in database + updated_variable = srv.get_variable(self._node_var_with_exec_id) + assert updated_variable is not None + # The value should be updated from the execution record's outputs + assert updated_variable.get_value().value == "output_value" + assert updated_variable.last_edited_at is None + assert updated_variable.node_execution_id == self._workflow_node_execution.id + + def test_reset_node_variable_with_no_execution_id(self): + """Test resetting a node variable with no execution ID - should delete variable""" + srv = self._get_test_srv() + mock_workflow = self._create_mock_workflow() + + # Get the variable before reset + variable = srv.get_variable(self._node_var_without_exec_id) + assert variable is not None + + # Reset the variable + result = srv.reset_variable(mock_workflow, variable) + + # Should return None (variable deleted) + assert result is None + + # Verify the variable was deleted + deleted_variable = srv.get_variable(self._node_var_without_exec_id) + assert deleted_variable is None + + def test_reset_node_variable_with_missing_execution_record(self): + """Test resetting a node variable when execution record doesn't exist""" + srv = self._get_test_srv() + mock_workflow = self._create_mock_workflow() + + # Get the variable before reset + variable = srv.get_variable(self._node_var_missing_exec_id) + assert variable is not None + + # Reset the variable + result = srv.reset_variable(mock_workflow, variable) + + # Should return None (variable deleted) + assert result is None + + # Verify the variable was deleted + deleted_variable = srv.get_variable(self._node_var_missing_exec_id) + assert deleted_variable is None + + def test_reset_conversation_variable(self): + 
"""Test resetting a conversation variable""" + srv = self._get_test_srv() + mock_workflow = self._create_mock_workflow() + + # Get the variable before reset + variable = srv.get_variable(self._conv_var_id) + assert variable is not None + assert variable.get_value().value == "old_conv_value" + assert variable.last_edited_at is not None + + # Reset the variable + result = srv.reset_variable(mock_workflow, variable) + + # Should return the updated variable + assert result is not None + assert result.id == self._conv_var_id + assert result.last_edited_at is None # Should be reset to None + + # Verify the variable was updated with default value from workflow + updated_variable = srv.get_variable(self._conv_var_id) + assert updated_variable is not None + # The value should be updated from the workflow's conversation variable default + assert updated_variable.get_value().value == "default_value_1" + assert updated_variable.last_edited_at is None + + def test_reset_system_variable_raises_error(self): + """Test that resetting a system variable raises an error""" + srv = self._get_test_srv() + mock_workflow = self._create_mock_workflow() + + # Create a system variable + sys_var = WorkflowDraftVariable.new_sys_variable( + app_id=self._test_app_id, + name="sys_var", + value=build_segment("sys_value"), + node_execution_id=self._node_exec_id, + ) + db.session.add(sys_var) + db.session.flush() + + # Attempt to reset the system variable + with pytest.raises(VariableResetError) as exc_info: + srv.reset_variable(mock_workflow, sys_var) + + assert "cannot reset system variable" in str(exc_info.value) + assert sys_var.id in str(exc_info.value) diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 6aa48b1cb..a3b2fdc37 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -8,8 +8,6 @@ from unittest.mock import MagicMock, patch import pytest -from app_factory import create_app -from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.message_entities import AssistantPromptMessage @@ -30,21 +28,6 @@ from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_mod from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock -@pytest.fixture(scope="session") -def app(): - # Set up storage configuration - os.environ["STORAGE_TYPE"] = "opendal" - os.environ["OPENDAL_SCHEME"] = "fs" - os.environ["OPENDAL_FS_ROOT"] = "storage" - - # Ensure storage directory exists - os.makedirs("storage", exist_ok=True) - - app = create_app() - dify_config.LOGIN_DISABLED = True - return app - - def init_llm_node(config: dict) -> LLMNode: graph_config = { "edges": [ @@ -102,197 +85,195 @@ def init_llm_node(config: dict) -> LLMNode: return node -def test_execute_llm(app): - with app.app_context(): - node = init_llm_node( - config={ - "id": "llm", - "data": { - "title": "123", - "type": "llm", - "model": { - "provider": "langgenius/openai/openai", - "name": "gpt-3.5-turbo", - "mode": "chat", - "completion_params": {}, - }, - "prompt_template": [ - { - "role": "system", - "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.", - }, - {"role": "user", "text": "{{#sys.query#}}"}, - ], - "memory": None, - "context": {"enabled": False}, - "vision": {"enabled": False}, +def 
test_execute_llm(flask_req_ctx): + node = init_llm_node( + config={ + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": { + "provider": "langgenius/openai/openai", + "name": "gpt-3.5-turbo", + "mode": "chat", + "completion_params": {}, }, + "prompt_template": [ + { + "role": "system", + "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.", + }, + {"role": "user", "text": "{{#sys.query#}}"}, + ], + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, }, - ) + }, + ) - credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} + credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} - # Create a proper LLM result with real entities - mock_usage = LLMUsage( - prompt_tokens=30, - prompt_unit_price=Decimal("0.001"), - prompt_price_unit=Decimal("1000"), - prompt_price=Decimal("0.00003"), - completion_tokens=20, - completion_unit_price=Decimal("0.002"), - completion_price_unit=Decimal("1000"), - completion_price=Decimal("0.00004"), - total_tokens=50, - total_price=Decimal("0.00007"), - currency="USD", - latency=0.5, - ) + # Create a proper LLM result with real entities + mock_usage = LLMUsage( + prompt_tokens=30, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("1000"), + prompt_price=Decimal("0.00003"), + completion_tokens=20, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("1000"), + completion_price=Decimal("0.00004"), + total_tokens=50, + total_price=Decimal("0.00007"), + currency="USD", + latency=0.5, + ) - mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.") + mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.") - mock_llm_result = LLMResult( - model="gpt-3.5-turbo", - prompt_messages=[], - message=mock_message, - usage=mock_usage, - ) + mock_llm_result = LLMResult( + model="gpt-3.5-turbo", + prompt_messages=[], + message=mock_message, + usage=mock_usage, + ) - # Create a simple mock model instance that doesn't call real providers - mock_model_instance = MagicMock() - mock_model_instance.invoke_llm.return_value = mock_llm_result + # Create a simple mock model instance that doesn't call real providers + mock_model_instance = MagicMock() + mock_model_instance.invoke_llm.return_value = mock_llm_result - # Create a simple mock model config with required attributes - mock_model_config = MagicMock() - mock_model_config.mode = "chat" - mock_model_config.provider = "langgenius/openai/openai" - mock_model_config.model = "gpt-3.5-turbo" - mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" + # Create a simple mock model config with required attributes + mock_model_config = MagicMock() + mock_model_config.mode = "chat" + mock_model_config.provider = "langgenius/openai/openai" + mock_model_config.model = "gpt-3.5-turbo" + mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" - # Mock the _fetch_model_config method - def mock_fetch_model_config_func(_node_data_model): - return mock_model_instance, mock_model_config + # Mock the _fetch_model_config method + def mock_fetch_model_config_func(_node_data_model): + return mock_model_instance, mock_model_config - # Also mock ModelManager.get_model_instance to avoid database calls - def mock_get_model_instance(_self, **kwargs): - return mock_model_instance + # Also mock ModelManager.get_model_instance to avoid database calls + def 
mock_get_model_instance(_self, **kwargs): + return mock_model_instance - with ( - patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), - patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), - ): - # execute node - result = node._run() - assert isinstance(result, Generator) + with ( + patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), + patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), + ): + # execute node + result = node._run() + assert isinstance(result, Generator) - for item in result: - if isinstance(item, RunCompletedEvent): - assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.process_data is not None - assert item.run_result.outputs is not None - assert item.run_result.outputs.get("text") is not None - assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 + for item in result: + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.process_data is not None + assert item.run_result.outputs is not None + assert item.run_result.outputs.get("text") is not None + assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) -def test_execute_llm_with_jinja2(app, setup_code_executor_mock): +def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock): """ Test execute LLM node with jinja2 """ - with app.app_context(): - node = init_llm_node( - config={ - "id": "llm", - "data": { - "title": "123", - "type": "llm", - "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, - "prompt_config": { - "jinja2_variables": [ - {"variable": "sys_query", "value_selector": ["sys", "query"]}, - {"variable": "output", "value_selector": ["abc", "output"]}, - ] - }, - "prompt_template": [ - { - "role": "system", - "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", - "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", - "edition_type": "jinja2", - }, - { - "role": "user", - "text": "{{#sys.query#}}", - "jinja2_text": "{{sys_query}}", - "edition_type": "basic", - }, - ], - "memory": None, - "context": {"enabled": False}, - "vision": {"enabled": False}, + node = init_llm_node( + config={ + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "prompt_config": { + "jinja2_variables": [ + {"variable": "sys_query", "value_selector": ["sys", "query"]}, + {"variable": "output", "value_selector": ["abc", "output"]}, + ] }, + "prompt_template": [ + { + "role": "system", + "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", + "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", + "edition_type": "jinja2", + }, + { + "role": "user", + "text": "{{#sys.query#}}", + "jinja2_text": "{{sys_query}}", + "edition_type": "basic", + }, + ], + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, }, - ) + }, + ) - # Mock db.session.close() - db.session.close = MagicMock() + # Mock db.session.close() + db.session.close = MagicMock() - # Create a proper LLM result with real entities - mock_usage = LLMUsage( - prompt_tokens=30, - prompt_unit_price=Decimal("0.001"), - prompt_price_unit=Decimal("1000"), - 
prompt_price=Decimal("0.00003"), - completion_tokens=20, - completion_unit_price=Decimal("0.002"), - completion_price_unit=Decimal("1000"), - completion_price=Decimal("0.00004"), - total_tokens=50, - total_price=Decimal("0.00007"), - currency="USD", - latency=0.5, - ) + # Create a proper LLM result with real entities + mock_usage = LLMUsage( + prompt_tokens=30, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("1000"), + prompt_price=Decimal("0.00003"), + completion_tokens=20, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("1000"), + completion_price=Decimal("0.00004"), + total_tokens=50, + total_price=Decimal("0.00007"), + currency="USD", + latency=0.5, + ) - mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?") + mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?") - mock_llm_result = LLMResult( - model="gpt-3.5-turbo", - prompt_messages=[], - message=mock_message, - usage=mock_usage, - ) + mock_llm_result = LLMResult( + model="gpt-3.5-turbo", + prompt_messages=[], + message=mock_message, + usage=mock_usage, + ) - # Create a simple mock model instance that doesn't call real providers - mock_model_instance = MagicMock() - mock_model_instance.invoke_llm.return_value = mock_llm_result + # Create a simple mock model instance that doesn't call real providers + mock_model_instance = MagicMock() + mock_model_instance.invoke_llm.return_value = mock_llm_result - # Create a simple mock model config with required attributes - mock_model_config = MagicMock() - mock_model_config.mode = "chat" - mock_model_config.provider = "openai" - mock_model_config.model = "gpt-3.5-turbo" - mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" + # Create a simple mock model config with required attributes + mock_model_config = MagicMock() + mock_model_config.mode = "chat" + mock_model_config.provider = "openai" + mock_model_config.model = "gpt-3.5-turbo" + mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" - # Mock the _fetch_model_config method - def mock_fetch_model_config_func(_node_data_model): - return mock_model_instance, mock_model_config + # Mock the _fetch_model_config method + def mock_fetch_model_config_func(_node_data_model): + return mock_model_instance, mock_model_config - # Also mock ModelManager.get_model_instance to avoid database calls - def mock_get_model_instance(_self, **kwargs): - return mock_model_instance + # Also mock ModelManager.get_model_instance to avoid database calls + def mock_get_model_instance(_self, **kwargs): + return mock_model_instance - with ( - patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), - patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), - ): - # execute node - result = node._run() + with ( + patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), + patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), + ): + # execute node + result = node._run() - for item in result: - if isinstance(item, RunCompletedEvent): - assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.process_data is not None - assert "sunny" in json.dumps(item.run_result.process_data) - assert "what's the weather today?" 
in json.dumps(item.run_result.process_data) + for item in result: + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.process_data is not None + assert "sunny" in json.dumps(item.run_result.process_data) + assert "what's the weather today?" in json.dumps(item.run_result.process_data) def test_extract_json(): diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py new file mode 100644 index 000000000..f26be6702 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -0,0 +1,302 @@ +import datetime +import uuid +from collections import OrderedDict +from typing import Any, NamedTuple + +from flask_restful import marshal + +from controllers.console.app.workflow_draft_variable import ( + _WORKFLOW_DRAFT_VARIABLE_FIELDS, + _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS, + _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS, + _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, +) +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from factories.variable_factory import build_segment +from models.workflow import WorkflowDraftVariable +from services.workflow_draft_variable_service import WorkflowDraftVariableList + +_TEST_APP_ID = "test_app_id" +_TEST_NODE_EXEC_ID = str(uuid.uuid4()) + + +class TestWorkflowDraftVariableFields: + def test_conversation_variable(self): + conv_var = WorkflowDraftVariable.new_conversation_variable( + app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1) + ) + + conv_var.id = str(uuid.uuid4()) + conv_var.visible = True + + expected_without_value: OrderedDict[str, Any] = OrderedDict( + { + "id": str(conv_var.id), + "type": conv_var.get_variable_type().value, + "name": "conv_var", + "description": "", + "selector": [CONVERSATION_VARIABLE_NODE_ID, "conv_var"], + "value_type": "number", + "edited": False, + "visible": True, + } + ) + + assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value + expected_with_value = expected_without_value.copy() + expected_with_value["value"] = 1 + assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value + + def test_create_sys_variable(self): + sys_var = WorkflowDraftVariable.new_sys_variable( + app_id=_TEST_APP_ID, + name="sys_var", + value=build_segment("a"), + editable=True, + node_execution_id=_TEST_NODE_EXEC_ID, + ) + + sys_var.id = str(uuid.uuid4()) + sys_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + sys_var.visible = True + + expected_without_value = OrderedDict( + { + "id": str(sys_var.id), + "type": sys_var.get_variable_type().value, + "name": "sys_var", + "description": "", + "selector": [SYSTEM_VARIABLE_NODE_ID, "sys_var"], + "value_type": "string", + "edited": True, + "visible": True, + } + ) + assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value + expected_with_value = expected_without_value.copy() + expected_with_value["value"] = "a" + assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value + + def test_node_variable(self): + node_var = WorkflowDraftVariable.new_node_variable( + app_id=_TEST_APP_ID, + node_id="test_node", + name="node_var", + value=build_segment([1, "a"]), + visible=False, + node_execution_id=_TEST_NODE_EXEC_ID, + ) + + node_var.id = str(uuid.uuid4()) + 
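# Setting last_edited_at marks the variable as user-edited, so "edited" is expected to marshal as True below. +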
node_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + + expected_without_value: OrderedDict[str, Any] = OrderedDict( + { + "id": str(node_var.id), + "type": node_var.get_variable_type().value, + "name": "node_var", + "description": "", + "selector": ["test_node", "node_var"], + "value_type": "array[any]", + "edited": True, + "visible": False, + } + ) + + assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value + expected_with_value = expected_without_value.copy() + expected_with_value["value"] = [1, "a"] + assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value + + +class TestWorkflowDraftVariableList: + def test_workflow_draft_variable_list(self): + class TestCase(NamedTuple): + name: str + var_list: WorkflowDraftVariableList + expected: dict + + node_var = WorkflowDraftVariable.new_node_variable( + app_id=_TEST_APP_ID, + node_id="test_node", + name="test_var", + value=build_segment("a"), + visible=True, + node_execution_id=_TEST_NODE_EXEC_ID, + ) + node_var.id = str(uuid.uuid4()) + node_var_dict = OrderedDict( + { + "id": str(node_var.id), + "type": node_var.get_variable_type().value, + "name": "test_var", + "description": "", + "selector": ["test_node", "test_var"], + "value_type": "string", + "edited": False, + "visible": True, + } + ) + + cases = [ + TestCase( + name="empty variable list", + var_list=WorkflowDraftVariableList(variables=[]), + expected=OrderedDict( + { + "items": [], + "total": None, + } + ), + ), + TestCase( + name="empty variable list with total", + var_list=WorkflowDraftVariableList(variables=[], total=10), + expected=OrderedDict( + { + "items": [], + "total": 10, + } + ), + ), + TestCase( + name="non-empty variable list", + var_list=WorkflowDraftVariableList(variables=[node_var], total=None), + expected=OrderedDict( + { + "items": [node_var_dict], + "total": None, + } + ), + ), + TestCase( + name="non-empty variable list with total", + var_list=WorkflowDraftVariableList(variables=[node_var], total=10), + expected=OrderedDict( + { + "items": [node_var_dict], + "total": 10, + } + ), + ), + ] + + for idx, case in enumerate(cases, 1): + assert marshal(case.var_list, _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) == case.expected, ( + f"Test case {idx} failed, {case.name=}" + ) + + +def test_workflow_node_variables_fields(): + conv_var = WorkflowDraftVariable.new_conversation_variable( + app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1) + ) + resp = marshal(WorkflowDraftVariableList(variables=[conv_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + assert isinstance(resp, dict) + assert len(resp["items"]) == 1 + item_dict = resp["items"][0] + assert item_dict["name"] == "conv_var" + assert item_dict["value"] == 1 + + +def test_workflow_file_variable_with_signed_url(): + """Test that File type variables include signed URLs in API responses.""" + from core.file.enums import FileTransferMethod, FileType + from core.file.models import File + + # Create a File object with LOCAL_FILE transfer method (which generates signed URLs) + test_file = File( + id="test_file_id", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="test_upload_file_id", + filename="test.jpg", + extension=".jpg", + mime_type="image/jpeg", + size=12345, + ) + + # Create a WorkflowDraftVariable with the File + file_var = WorkflowDraftVariable.new_node_variable( + app_id=_TEST_APP_ID, + node_id="test_node", + name="file_var", + 
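# build_segment wraps the File in a FileSegment, so the marshalled value below is File.to_dict(). +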
value=build_segment(test_file), + node_execution_id=_TEST_NODE_EXEC_ID, + ) + + # Marshal the variable using the API fields + resp = marshal(WorkflowDraftVariableList(variables=[file_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + + # Verify the response structure + assert isinstance(resp, dict) + assert len(resp["items"]) == 1 + item_dict = resp["items"][0] + assert item_dict["name"] == "file_var" + + # Verify the value is a dict (File.to_dict() result) and contains expected fields + value = item_dict["value"] + assert isinstance(value, dict) + + # Verify the File fields are preserved + assert value["id"] == test_file.id + assert value["filename"] == test_file.filename + assert value["type"] == test_file.type.value + assert value["transfer_method"] == test_file.transfer_method.value + assert value["size"] == test_file.size + + # Verify the URL is present (it should be a signed URL for LOCAL_FILE transfer method) + remote_url = value["remote_url"] + assert remote_url is not None + + assert isinstance(remote_url, str) + # For LOCAL_FILE, the URL should contain signature parameters + assert "timestamp=" in remote_url + assert "nonce=" in remote_url + assert "sign=" in remote_url + + +def test_workflow_file_variable_remote_url(): + """Test that File type variables with REMOTE_URL transfer method return the remote URL.""" + from core.file.enums import FileTransferMethod, FileType + from core.file.models import File + + # Create a File object with REMOTE_URL transfer method + test_file = File( + id="test_file_id", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/test.jpg", + filename="test.jpg", + extension=".jpg", + mime_type="image/jpeg", + size=12345, + ) + + # Create a WorkflowDraftVariable with the File + file_var = WorkflowDraftVariable.new_node_variable( + app_id=_TEST_APP_ID, + node_id="test_node", + name="file_var", + value=build_segment(test_file), + node_execution_id=_TEST_NODE_EXEC_ID, + ) + + # Marshal the variable using the API fields + resp = marshal(WorkflowDraftVariableList(variables=[file_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + + # Verify the response structure + assert isinstance(resp, dict) + assert len(resp["items"]) == 1 + item_dict = resp["items"][0] + assert item_dict["name"] == "file_var" + + # Verify the value is a dict (File.to_dict() result) and contains expected fields + value = item_dict["value"] + assert isinstance(value, dict) + remote_url = value["remote_url"] + + # For REMOTE_URL, the URL should be the original remote URL + assert remote_url == test_file.remote_url diff --git a/api/tests/unit_tests/core/app/segments/test_factory.py b/api/tests/unit_tests/core/app/segments/test_factory.py deleted file mode 100644 index e6e289c12..000000000 --- a/api/tests/unit_tests/core/app/segments/test_factory.py +++ /dev/null @@ -1,165 +0,0 @@ -from uuid import uuid4 - -import pytest - -from core.variables import ( - ArrayNumberVariable, - ArrayObjectVariable, - ArrayStringVariable, - FloatVariable, - IntegerVariable, - ObjectSegment, - SecretVariable, - StringVariable, -) -from core.variables.exc import VariableError -from core.variables.segments import ArrayAnySegment -from factories import variable_factory - - -def test_string_variable(): - test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"} - result = variable_factory.build_conversation_variable_from_mapping(test_data) - assert isinstance(result, StringVariable) - - -def test_integer_variable(): - 
test_data = {"value_type": "number", "name": "test_int", "value": 42} - result = variable_factory.build_conversation_variable_from_mapping(test_data) - assert isinstance(result, IntegerVariable) - - -def test_float_variable(): - test_data = {"value_type": "number", "name": "test_float", "value": 3.14} - result = variable_factory.build_conversation_variable_from_mapping(test_data) - assert isinstance(result, FloatVariable) - - -def test_secret_variable(): - test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"} - result = variable_factory.build_conversation_variable_from_mapping(test_data) - assert isinstance(result, SecretVariable) - - -def test_invalid_value_type(): - test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"} - with pytest.raises(VariableError): - variable_factory.build_conversation_variable_from_mapping(test_data) - - -def test_build_a_blank_string(): - result = variable_factory.build_conversation_variable_from_mapping( - { - "value_type": "string", - "name": "blank", - "value": "", - } - ) - assert isinstance(result, StringVariable) - assert result.value == "" - - -def test_build_a_object_variable_with_none_value(): - var = variable_factory.build_segment( - { - "key1": None, - } - ) - assert isinstance(var, ObjectSegment) - assert var.value["key1"] is None - - -def test_object_variable(): - mapping = { - "id": str(uuid4()), - "value_type": "object", - "name": "test_object", - "description": "Description of the variable.", - "value": { - "key1": "text", - "key2": 2, - }, - } - variable = variable_factory.build_conversation_variable_from_mapping(mapping) - assert isinstance(variable, ObjectSegment) - assert isinstance(variable.value["key1"], str) - assert isinstance(variable.value["key2"], int) - - -def test_array_string_variable(): - mapping = { - "id": str(uuid4()), - "value_type": "array[string]", - "name": "test_array", - "description": "Description of the variable.", - "value": [ - "text", - "text", - ], - } - variable = variable_factory.build_conversation_variable_from_mapping(mapping) - assert isinstance(variable, ArrayStringVariable) - assert isinstance(variable.value[0], str) - assert isinstance(variable.value[1], str) - - -def test_array_number_variable(): - mapping = { - "id": str(uuid4()), - "value_type": "array[number]", - "name": "test_array", - "description": "Description of the variable.", - "value": [ - 1, - 2.0, - ], - } - variable = variable_factory.build_conversation_variable_from_mapping(mapping) - assert isinstance(variable, ArrayNumberVariable) - assert isinstance(variable.value[0], int) - assert isinstance(variable.value[1], float) - - -def test_array_object_variable(): - mapping = { - "id": str(uuid4()), - "value_type": "array[object]", - "name": "test_array", - "description": "Description of the variable.", - "value": [ - { - "key1": "text", - "key2": 1, - }, - { - "key1": "text", - "key2": 1, - }, - ], - } - variable = variable_factory.build_conversation_variable_from_mapping(mapping) - assert isinstance(variable, ArrayObjectVariable) - assert isinstance(variable.value[0], dict) - assert isinstance(variable.value[1], dict) - assert isinstance(variable.value[0]["key1"], str) - assert isinstance(variable.value[0]["key2"], int) - assert isinstance(variable.value[1]["key1"], str) - assert isinstance(variable.value[1]["key2"], int) - - -def test_variable_cannot_large_than_200_kb(): - with pytest.raises(VariableError): - variable_factory.build_conversation_variable_from_mapping( - { - "id": 
str(uuid4()), - "value_type": "string", - "name": "test_text", - "value": "a" * 1024 * 201, - } - ) - - -def test_array_none_variable(): - var = variable_factory.build_segment([None, None, None, None]) - assert isinstance(var, ArrayAnySegment) - assert var.value == [None, None, None, None] diff --git a/api/tests/unit_tests/core/file/test_models.py b/api/tests/unit_tests/core/file/test_models.py new file mode 100644 index 000000000..3ada2087c --- /dev/null +++ b/api/tests/unit_tests/core/file/test_models.py @@ -0,0 +1,25 @@ +from core.file import File, FileTransferMethod, FileType + + +def test_file(): + file = File( + id="test-file", + tenant_id="test-tenant-id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id="test-related-id", + filename="image.png", + extension=".png", + mime_type="image/png", + size=67, + storage_key="test-storage-key", + url="https://example.com/image.png", + ) + assert file.tenant_id == "test-tenant-id" + assert file.type == FileType.IMAGE + assert file.transfer_method == FileTransferMethod.TOOL_FILE + assert file.related_id == "test-related-id" + assert file.filename == "image.png" + assert file.extension == ".png" + assert file.mime_type == "image/png" + assert file.size == 67 diff --git a/api/tests/unit_tests/core/app/segments/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py similarity index 100% rename from api/tests/unit_tests/core/app/segments/test_segment.py rename to api/tests/unit_tests/core/variables/test_segment.py diff --git a/api/tests/unit_tests/core/app/segments/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py similarity index 100% rename from api/tests/unit_tests/core/app/segments/test_variables.py rename to api/tests/unit_tests/core/variables/test_variables.py diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py index 6d854c950..362072a3d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -3,6 +3,7 @@ import uuid from unittest.mock import patch from core.app.entities.app_invoke_entities import InvokeFrom +from core.variables.segments import ArrayAnySegment, ArrayStringSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -197,7 +198,7 @@ def test_run(): count += 1 if isinstance(item, RunCompletedEvent): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])} assert count == 20 @@ -413,7 +414,7 @@ def test_run_parallel(): count += 1 if isinstance(item, RunCompletedEvent): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])} assert count == 32 @@ -654,7 +655,7 @@ def test_iteration_run_in_parallel_mode(): parallel_arr.append(item) if isinstance(item, RunCompletedEvent): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + assert 
item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])} assert count == 32 for item in sequential_result: @@ -662,7 +663,7 @@ def test_iteration_run_in_parallel_mode(): count += 1 if isinstance(item, RunCompletedEvent): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])} assert count == 64 @@ -846,7 +847,7 @@ def test_iteration_run_error_handle(): count += 1 if isinstance(item, RunCompletedEvent): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.outputs == {"output": [None, None]} + assert item.run_result.outputs == {"output": ArrayAnySegment(value=[None, None])} assert count == 14 # execute remove abnormal output @@ -857,5 +858,5 @@ def test_iteration_run_error_handle(): count += 1 if isinstance(item, RunCompletedEvent): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.outputs == {"output": []} + assert item.run_result.outputs == {"output": ArrayAnySegment(value=[])} assert count == 14 diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 4cb1aa93f..76bb640d1 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -7,6 +7,7 @@ from docx.oxml.text.paragraph import CT_P from core.file import File, FileTransferMethod from core.variables import ArrayFileSegment +from core.variables.segments import ArrayStringSegment from core.variables.variables import StringVariable from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -69,7 +70,13 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s @pytest.mark.parametrize( ("mime_type", "file_content", "expected_text", "transfer_method", "extension"), [ - ("text/plain", b"Hello, world!", ["Hello, world!"], FileTransferMethod.LOCAL_FILE, ".txt"), + ( + "text/plain", + b"Hello, world!", + ["Hello, world!"], + FileTransferMethod.LOCAL_FILE, + ".txt", + ), ( "application/pdf", b"%PDF-1.5\n%Test PDF content", @@ -84,7 +91,13 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s FileTransferMethod.REMOTE_URL, "", ), - ("text/plain", b"Remote content", ["Remote content"], FileTransferMethod.REMOTE_URL, None), + ( + "text/plain", + b"Remote content", + ["Remote content"], + FileTransferMethod.REMOTE_URL, + None, + ), ], ) def test_run_extract_text( @@ -131,7 +144,7 @@ def test_run_extract_text( assert isinstance(result, NodeRunResult) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error assert result.outputs is not None - assert result.outputs["text"] == expected_text + assert result.outputs["text"] == ArrayStringSegment(value=expected_text) if transfer_method == FileTransferMethod.REMOTE_URL: mock_ssrf_proxy_get.assert_called_once_with("https://example.com/file.txt") diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 77d42e269..7d3a1d6a2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ 
b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -115,7 +115,7 @@ def test_filter_files_by_type(list_operator_node): }, ] assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - for expected_file, result_file in zip(expected_files, result.outputs["result"]): + for expected_file, result_file in zip(expected_files, result.outputs["result"].value): assert expected_file["filename"] == result_file.filename assert expected_file["type"] == result_file.type assert expected_file["tenant_id"] == result_file.tenant_id diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index 9793da129..deb3e29b8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -5,6 +5,7 @@ from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import ArrayStringVariable, StringVariable +from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph @@ -63,10 +64,11 @@ def test_overwrite_string_variable(): name="test_string_variable", value="the second value", ) + conversation_id = str(uuid.uuid4()) # construct variable pool variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -77,6 +79,9 @@ def test_overwrite_string_variable(): input_variable, ) + mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) + mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) + node = VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, @@ -91,11 +96,20 @@ def test_overwrite_string_variable(): "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], }, }, + conv_var_updater_factory=mock_conv_var_updater_factory, ) - with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run: - list(node.run()) - mock_run.assert_called_once() + list(node.run()) + expected_var = StringVariable( + id=conversation_variable.id, + name=conversation_variable.name, + description=conversation_variable.description, + selector=conversation_variable.selector, + value_type=conversation_variable.value_type, + value=input_variable.value, + ) + mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var) + mock_conv_var_updater.flush.assert_called_once() got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None @@ -148,9 +162,10 @@ def test_append_variable_to_array(): name="test_string_variable", value="the second value", ) + conversation_id = str(uuid.uuid4()) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -160,6 +175,9 @@ def test_append_variable_to_array(): input_variable, ) + 
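# Stub the conversation-variable updater so update() and flush() can be asserted without touching the database. +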
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) + mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) + node = VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, @@ -174,11 +192,22 @@ def test_append_variable_to_array(): "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], }, }, + conv_var_updater_factory=mock_conv_var_updater_factory, ) - with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run: - list(node.run()) - mock_run.assert_called_once() + list(node.run()) + expected_value = list(conversation_variable.value) + expected_value.append(input_variable.value) + expected_var = ArrayStringVariable( + id=conversation_variable.id, + name=conversation_variable.name, + description=conversation_variable.description, + selector=conversation_variable.selector, + value_type=conversation_variable.value_type, + value=expected_value, + ) + mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var) + mock_conv_var_updater.flush.assert_called_once() got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None @@ -225,13 +254,17 @@ def test_clear_array(): value=["the first value"], ) + conversation_id = str(uuid.uuid4()) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], ) + mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) + mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) + node = VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, @@ -246,11 +279,20 @@ def test_clear_array(): "input_variable_selector": [], }, }, + conv_var_updater_factory=mock_conv_var_updater_factory, ) - with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run: - list(node.run()) - mock_run.assert_called_once() + list(node.run()) + expected_var = ArrayStringVariable( + id=conversation_variable.id, + name=conversation_variable.name, + description=conversation_variable.description, + selector=conversation_variable.selector, + value_type=conversation_variable.value_type, + value=[], + ) + mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var) + mock_conv_var_updater.flush.assert_called_once() got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index efbcdc760..bb8d34fad 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -1,8 +1,12 @@ import pytest +from pydantic import ValidationError from core.file import File, FileTransferMethod, FileType from core.variables import FileSegment, StringSegment +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from factories.variable_factory import build_segment, segment_to_variable @pytest.fixture @@ -44,3 +48,38 @@ def test_use_long_selector(pool): result = pool.get(("node_1", 
"part_1", "part_2")) assert result is not None assert result.value == "test_value" + + +class TestVariablePool: + def test_constructor(self): + pool = VariablePool() + pool = VariablePool( + variable_dictionary={}, + user_inputs={}, + system_variables={}, + environment_variables=[], + conversation_variables=[], + ) + + pool = VariablePool( + user_inputs={"key": "value"}, + system_variables={SystemVariableKey.WORKFLOW_ID: "test_workflow_id"}, + environment_variables=[ + segment_to_variable( + segment=build_segment(1), + selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_var_1"], + name="env_var_1", + ) + ], + conversation_variables=[ + segment_to_variable( + segment=build_segment("1"), + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var_1"], + name="conv_var_1", + ) + ], + ) + + def test_constructor_with_invalid_system_variable_key(self): + with pytest.raises(ValidationError): + VariablePool(system_variables={"invalid_key": "value"}) # type: ignore diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py b/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py index 2f90afcf8..28ef05edd 100644 --- a/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py +++ b/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py @@ -1,22 +1,10 @@ -from core.variables import SecretVariable +import dataclasses + from core.workflow.entities.variable_entities import VariableSelector -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey from core.workflow.utils import variable_template_parser def test_extract_selectors_from_template(): - variable_pool = VariablePool( - system_variables={ - SystemVariableKey("user_id"): "fake-user-id", - }, - user_inputs={}, - environment_variables=[ - SecretVariable(name="secret_key", value="fake-secret-key"), - ], - conversation_variables=[], - ) - variable_pool.add(("node_id", "custom_query"), "fake-user-query") template = ( "Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}." 
) @@ -26,3 +14,35 @@ def test_extract_selectors_from_template(): VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]), VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]), ] + + +def test_invalid_references(): + @dataclasses.dataclass + class TestCase: + name: str + template: str + + cases = [ + TestCase( + name="lack of closing brace", + template="Hello, {{#sys.user_id#", + ), + TestCase( + name="lack of opening brace", + template="Hello, #sys.user_id#}}", + ), + TestCase( + name="lack of selector name", + template="Hello, {{#sys#}}", + ), + TestCase( + name="empty node name part", + template="Hello, {{#.user_id#}}", + ), + ] + for idx, c in enumerate(cases, 1): + fail_msg = f"Test case {c.name} failed, index={idx}" + selectors = variable_template_parser.extract_selectors_from_template(c.template) + assert selectors == [], fail_msg + parser = variable_template_parser.VariableTemplateParser(c.template) + assert parser.extract_variable_selectors() == [], fail_msg diff --git a/api/tests/unit_tests/factories/test_variable_factory.py b/api/tests/unit_tests/factories/test_variable_factory.py new file mode 100644 index 000000000..481fbdc91 --- /dev/null +++ b/api/tests/unit_tests/factories/test_variable_factory.py @@ -0,0 +1,865 @@ +import math +from dataclasses import dataclass +from typing import Any +from uuid import uuid4 + +import pytest +from hypothesis import given +from hypothesis import strategies as st + +from core.file import File, FileTransferMethod, FileType +from core.variables import ( + ArrayNumberVariable, + ArrayObjectVariable, + ArrayStringVariable, + FloatVariable, + IntegerVariable, + ObjectSegment, + SecretVariable, + SegmentType, + StringVariable, +) +from core.variables.exc import VariableError +from core.variables.segments import ( + ArrayAnySegment, + ArrayFileSegment, + ArrayNumberSegment, + ArrayObjectSegment, + ArrayStringSegment, + FileSegment, + FloatSegment, + IntegerSegment, + NoneSegment, + StringSegment, +) +from factories import variable_factory +from factories.variable_factory import TypeMismatchError, build_segment_with_type + + +def test_string_variable(): + test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"} + result = variable_factory.build_conversation_variable_from_mapping(test_data) + assert isinstance(result, StringVariable) + + +def test_integer_variable(): + test_data = {"value_type": "number", "name": "test_int", "value": 42} + result = variable_factory.build_conversation_variable_from_mapping(test_data) + assert isinstance(result, IntegerVariable) + + +def test_float_variable(): + test_data = {"value_type": "number", "name": "test_float", "value": 3.14} + result = variable_factory.build_conversation_variable_from_mapping(test_data) + assert isinstance(result, FloatVariable) + + +def test_secret_variable(): + test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"} + result = variable_factory.build_conversation_variable_from_mapping(test_data) + assert isinstance(result, SecretVariable) + + +def test_invalid_value_type(): + test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"} + with pytest.raises(VariableError): + variable_factory.build_conversation_variable_from_mapping(test_data) + + +def test_build_a_blank_string(): + result = variable_factory.build_conversation_variable_from_mapping( + { + "value_type": "string", + "name": 
"blank", + "value": "", + } + ) + assert isinstance(result, StringVariable) + assert result.value == "" + + +def test_build_a_object_variable_with_none_value(): + var = variable_factory.build_segment( + { + "key1": None, + } + ) + assert isinstance(var, ObjectSegment) + assert var.value["key1"] is None + + +def test_object_variable(): + mapping = { + "id": str(uuid4()), + "value_type": "object", + "name": "test_object", + "description": "Description of the variable.", + "value": { + "key1": "text", + "key2": 2, + }, + } + variable = variable_factory.build_conversation_variable_from_mapping(mapping) + assert isinstance(variable, ObjectSegment) + assert isinstance(variable.value["key1"], str) + assert isinstance(variable.value["key2"], int) + + +def test_array_string_variable(): + mapping = { + "id": str(uuid4()), + "value_type": "array[string]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ + "text", + "text", + ], + } + variable = variable_factory.build_conversation_variable_from_mapping(mapping) + assert isinstance(variable, ArrayStringVariable) + assert isinstance(variable.value[0], str) + assert isinstance(variable.value[1], str) + + +def test_array_number_variable(): + mapping = { + "id": str(uuid4()), + "value_type": "array[number]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ + 1, + 2.0, + ], + } + variable = variable_factory.build_conversation_variable_from_mapping(mapping) + assert isinstance(variable, ArrayNumberVariable) + assert isinstance(variable.value[0], int) + assert isinstance(variable.value[1], float) + + +def test_array_object_variable(): + mapping = { + "id": str(uuid4()), + "value_type": "array[object]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ + { + "key1": "text", + "key2": 1, + }, + { + "key1": "text", + "key2": 1, + }, + ], + } + variable = variable_factory.build_conversation_variable_from_mapping(mapping) + assert isinstance(variable, ArrayObjectVariable) + assert isinstance(variable.value[0], dict) + assert isinstance(variable.value[1], dict) + assert isinstance(variable.value[0]["key1"], str) + assert isinstance(variable.value[0]["key2"], int) + assert isinstance(variable.value[1]["key1"], str) + assert isinstance(variable.value[1]["key2"], int) + + +def test_variable_cannot_large_than_200_kb(): + with pytest.raises(VariableError): + variable_factory.build_conversation_variable_from_mapping( + { + "id": str(uuid4()), + "value_type": "string", + "name": "test_text", + "value": "a" * 1024 * 201, + } + ) + + +def test_array_none_variable(): + var = variable_factory.build_segment([None, None, None, None]) + assert isinstance(var, ArrayAnySegment) + assert var.value == [None, None, None, None] + + +def test_build_segment_none_type(): + """Test building NoneSegment from None value.""" + segment = variable_factory.build_segment(None) + assert isinstance(segment, NoneSegment) + assert segment.value is None + assert segment.value_type == SegmentType.NONE + + +def test_build_segment_none_type_properties(): + """Test NoneSegment properties and methods.""" + segment = variable_factory.build_segment(None) + assert segment.text == "" + assert segment.log == "" + assert segment.markdown == "" + assert segment.to_object() is None + + +def test_build_segment_array_file_single_file(): + """Test building ArrayFileSegment from list with single file.""" + file = File( + id="test_file_id", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + 
transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/test-file.png", + filename="test-file", + extension=".png", + mime_type="image/png", + size=1000, + ) + segment = variable_factory.build_segment([file]) + assert isinstance(segment, ArrayFileSegment) + assert len(segment.value) == 1 + assert segment.value[0] == file + assert segment.value_type == SegmentType.ARRAY_FILE + + +def test_build_segment_array_file_multiple_files(): + """Test building ArrayFileSegment from list with multiple files.""" + file1 = File( + id="test_file_id_1", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/test-file1.png", + filename="test-file1", + extension=".png", + mime_type="image/png", + size=1000, + ) + file2 = File( + id="test_file_id_2", + tenant_id="test_tenant_id", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="test_relation_id", + filename="test-file2", + extension=".txt", + mime_type="text/plain", + size=500, + ) + segment = variable_factory.build_segment([file1, file2]) + assert isinstance(segment, ArrayFileSegment) + assert len(segment.value) == 2 + assert segment.value[0] == file1 + assert segment.value[1] == file2 + assert segment.value_type == SegmentType.ARRAY_FILE + + +def test_build_segment_array_file_empty_list(): + """Test building ArrayFileSegment from empty list should create ArrayAnySegment.""" + segment = variable_factory.build_segment([]) + assert isinstance(segment, ArrayAnySegment) + assert segment.value == [] + assert segment.value_type == SegmentType.ARRAY_ANY + + +def test_build_segment_array_any_mixed_types(): + """Test building ArrayAnySegment from list with mixed types.""" + mixed_values = ["string", 42, 3.14, {"key": "value"}, None] + segment = variable_factory.build_segment(mixed_values) + assert isinstance(segment, ArrayAnySegment) + assert segment.value == mixed_values + assert segment.value_type == SegmentType.ARRAY_ANY + + +def test_build_segment_array_any_with_nested_arrays(): + """Test building ArrayAnySegment from list containing arrays.""" + nested_values = [["nested", "array"], [1, 2, 3], "string"] + segment = variable_factory.build_segment(nested_values) + assert isinstance(segment, ArrayAnySegment) + assert segment.value == nested_values + assert segment.value_type == SegmentType.ARRAY_ANY + + +def test_build_segment_array_any_mixed_with_files(): + """Test building ArrayAnySegment from list with files and other types.""" + file = File( + id="test_file_id", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/test-file.png", + filename="test-file", + extension=".png", + mime_type="image/png", + size=1000, + ) + mixed_values = [file, "string", 42] + segment = variable_factory.build_segment(mixed_values) + assert isinstance(segment, ArrayAnySegment) + assert segment.value == mixed_values + assert segment.value_type == SegmentType.ARRAY_ANY + + +def test_build_segment_array_any_all_none_values(): + """Test building ArrayAnySegment from list with all None values.""" + none_values = [None, None, None] + segment = variable_factory.build_segment(none_values) + assert isinstance(segment, ArrayAnySegment) + assert segment.value == none_values + assert segment.value_type == SegmentType.ARRAY_ANY + + +def test_build_segment_array_file_properties(): + """Test ArrayFileSegment properties and methods.""" + file1 = File( + 
id="test_file_id_1", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/test-file1.png", + filename="test-file1", + extension=".png", + mime_type="image/png", + size=1000, + ) + file2 = File( + id="test_file_id_2", + tenant_id="test_tenant_id", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/test-file2.txt", + filename="test-file2", + extension=".txt", + mime_type="text/plain", + size=500, + ) + segment = variable_factory.build_segment([file1, file2]) + + # Test properties + assert segment.text == "" # ArrayFileSegment text property returns empty string + assert segment.log == "" # ArrayFileSegment log property returns empty string + assert segment.markdown == f"{file1.markdown}\n{file2.markdown}" + assert segment.to_object() == [file1, file2] + + +def test_build_segment_array_any_properties(): + """Test ArrayAnySegment properties and methods.""" + mixed_values = ["string", 42, None] + segment = variable_factory.build_segment(mixed_values) + + # Test properties + assert segment.text == str(mixed_values) + assert segment.log == str(mixed_values) + assert segment.markdown == "string\n42\nNone" + assert segment.to_object() == mixed_values + + +def test_build_segment_edge_cases(): + """Test edge cases for build_segment function.""" + # Test with complex nested structures + complex_structure = [{"nested": {"deep": [1, 2, 3]}}, [{"inner": "value"}], "mixed"] + segment = variable_factory.build_segment(complex_structure) + assert isinstance(segment, ArrayAnySegment) + assert segment.value == complex_structure + + # Test with single None in list + single_none = [None] + segment = variable_factory.build_segment(single_none) + assert isinstance(segment, ArrayAnySegment) + assert segment.value == single_none + + +def test_build_segment_file_array_with_different_file_types(): + """Test ArrayFileSegment with different file types.""" + image_file = File( + id="image_id", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/image.png", + filename="image", + extension=".png", + mime_type="image/png", + size=1000, + ) + + video_file = File( + id="video_id", + tenant_id="test_tenant_id", + type=FileType.VIDEO, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="video_relation_id", + filename="video", + extension=".mp4", + mime_type="video/mp4", + size=5000, + ) + + audio_file = File( + id="audio_id", + tenant_id="test_tenant_id", + type=FileType.AUDIO, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="audio_relation_id", + filename="audio", + extension=".mp3", + mime_type="audio/mpeg", + size=3000, + ) + + segment = variable_factory.build_segment([image_file, video_file, audio_file]) + assert isinstance(segment, ArrayFileSegment) + assert len(segment.value) == 3 + assert segment.value[0].type == FileType.IMAGE + assert segment.value[1].type == FileType.VIDEO + assert segment.value[2].type == FileType.AUDIO + + +@st.composite +def _generate_file(draw) -> File: + file_id = draw(st.text(min_size=1, max_size=10)) + tenant_id = draw(st.text(min_size=1, max_size=10)) + file_type, mime_type, extension = draw( + st.sampled_from( + [ + (FileType.IMAGE, "image/png", ".png"), + (FileType.VIDEO, "video/mp4", ".mp4"), + (FileType.DOCUMENT, "text/plain", ".txt"), + (FileType.AUDIO, "audio/mpeg", ".mp3"), + ] + ) + ) + filename = "test-file" + size = 
draw(st.integers(min_value=0, max_value=1024 * 1024))
+
+    transfer_method = draw(st.sampled_from(list(FileTransferMethod)))
+    if transfer_method == FileTransferMethod.REMOTE_URL:
+        url = "https://test.example.com/test-file"
+        file = File(
+            # Use the drawn id and tenant_id so Hypothesis can actually vary them.
+            id=file_id,
+            tenant_id=tenant_id,
+            type=file_type,
+            transfer_method=transfer_method,
+            remote_url=url,
+            related_id=None,
+            filename=filename,
+            extension=extension,
+            mime_type=mime_type,
+            size=size,
+        )
+    else:
+        relation_id = draw(st.uuids(version=4))
+
+        file = File(
+            id=file_id,
+            tenant_id=tenant_id,
+            type=file_type,
+            transfer_method=transfer_method,
+            related_id=str(relation_id),
+            filename=filename,
+            extension=extension,
+            mime_type=mime_type,
+            size=size,
+        )
+    return file
+
+
+def _scalar_value() -> st.SearchStrategy[int | float | str | File | None]:
+    return st.one_of(
+        st.none(),
+        st.integers(),
+        st.floats(),
+        st.text(),
+        _generate_file(),
+    )
+
+
+@given(_scalar_value())
+def test_build_segment_and_extract_values_for_scalar_types(value):
+    seg = variable_factory.build_segment(value)
+    # nan == nan yields False, so we need to use `math.isnan` to check `seg.value` here.
+    if isinstance(value, float) and math.isnan(value):
+        assert math.isnan(seg.value)
+    else:
+        assert seg.value == value
+
+
+@given(st.lists(_scalar_value()))
+def test_build_segment_and_extract_values_for_array_types(values):
+    seg = variable_factory.build_segment(values)
+    assert seg.value == values
+
+
+def test_build_segment_type_for_scalar():
+    @dataclass(frozen=True)
+    class TestCase:
+        value: int | float | str | File
+        expected_type: SegmentType
+
+    file = File(
+        id="test_file_id",
+        tenant_id="test_tenant_id",
+        type=FileType.IMAGE,
+        transfer_method=FileTransferMethod.REMOTE_URL,
+        remote_url="https://test.example.com/test-file.png",
+        filename="test-file",
+        extension=".png",
+        mime_type="image/png",
+        size=1000,
+    )
+    cases = [
+        TestCase(0, SegmentType.NUMBER),
+        TestCase(0.0, SegmentType.NUMBER),
+        TestCase("", SegmentType.STRING),
+        TestCase(file, SegmentType.FILE),
+    ]
+
+    for idx, c in enumerate(cases, 1):
+        segment = variable_factory.build_segment(c.value)
+        assert segment.value_type == c.expected_type, f"test case {idx} failed."
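+
+
+# Editor's sketch: the table-driven test above pins down how `build_segment` infers a
+# scalar's SegmentType. The helper below restates those rules in one place, using only
+# names already imported in this module; `_infer_scalar_segment_type` is a hypothetical
+# illustration for readers, not part of the factories.variable_factory API, and the
+# real factory may dispatch differently.
+def _infer_scalar_segment_type(value: int | float | str | File | None) -> SegmentType:
+    if value is None:
+        return SegmentType.NONE
+    if isinstance(value, str):
+        return SegmentType.STRING
+    # bool is a subclass of int, so True/False are also inferred as NUMBER
+    # (see test_build_segment_boolean_type_note below).
+    if isinstance(value, (int, float)):
+        return SegmentType.NUMBER
+    if isinstance(value, File):
+        return SegmentType.FILE
+    raise ValueError(f"not supported value {value!r}")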
+ + +class TestBuildSegmentWithType: + """Test cases for build_segment_with_type function.""" + + def test_string_type(self): + """Test building a string segment with correct type.""" + result = build_segment_with_type(SegmentType.STRING, "hello") + assert isinstance(result, StringSegment) + assert result.value == "hello" + assert result.value_type == SegmentType.STRING + + def test_number_type_integer(self): + """Test building a number segment with integer value.""" + result = build_segment_with_type(SegmentType.NUMBER, 42) + assert isinstance(result, IntegerSegment) + assert result.value == 42 + assert result.value_type == SegmentType.NUMBER + + def test_number_type_float(self): + """Test building a number segment with float value.""" + result = build_segment_with_type(SegmentType.NUMBER, 3.14) + assert isinstance(result, FloatSegment) + assert result.value == 3.14 + assert result.value_type == SegmentType.NUMBER + + def test_object_type(self): + """Test building an object segment with correct type.""" + test_obj = {"key": "value", "nested": {"inner": 123}} + result = build_segment_with_type(SegmentType.OBJECT, test_obj) + assert isinstance(result, ObjectSegment) + assert result.value == test_obj + assert result.value_type == SegmentType.OBJECT + + def test_file_type(self): + """Test building a file segment with correct type.""" + test_file = File( + id="test_file_id", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/test-file.png", + filename="test-file", + extension=".png", + mime_type="image/png", + size=1000, + storage_key="test_storage_key", + ) + result = build_segment_with_type(SegmentType.FILE, test_file) + assert isinstance(result, FileSegment) + assert result.value == test_file + assert result.value_type == SegmentType.FILE + + def test_none_type(self): + """Test building a none segment with None value.""" + result = build_segment_with_type(SegmentType.NONE, None) + assert isinstance(result, NoneSegment) + assert result.value is None + assert result.value_type == SegmentType.NONE + + def test_empty_array_string(self): + """Test building an empty array[string] segment.""" + result = build_segment_with_type(SegmentType.ARRAY_STRING, []) + assert isinstance(result, ArrayStringSegment) + assert result.value == [] + assert result.value_type == SegmentType.ARRAY_STRING + + def test_empty_array_number(self): + """Test building an empty array[number] segment.""" + result = build_segment_with_type(SegmentType.ARRAY_NUMBER, []) + assert isinstance(result, ArrayNumberSegment) + assert result.value == [] + assert result.value_type == SegmentType.ARRAY_NUMBER + + def test_empty_array_object(self): + """Test building an empty array[object] segment.""" + result = build_segment_with_type(SegmentType.ARRAY_OBJECT, []) + assert isinstance(result, ArrayObjectSegment) + assert result.value == [] + assert result.value_type == SegmentType.ARRAY_OBJECT + + def test_empty_array_file(self): + """Test building an empty array[file] segment.""" + result = build_segment_with_type(SegmentType.ARRAY_FILE, []) + assert isinstance(result, ArrayFileSegment) + assert result.value == [] + assert result.value_type == SegmentType.ARRAY_FILE + + def test_empty_array_any(self): + """Test building an empty array[any] segment.""" + result = build_segment_with_type(SegmentType.ARRAY_ANY, []) + assert isinstance(result, ArrayAnySegment) + assert result.value == [] + assert result.value_type == SegmentType.ARRAY_ANY + + def 
test_array_with_values(self): + """Test building array segments with actual values.""" + # Array of strings + result = build_segment_with_type(SegmentType.ARRAY_STRING, ["hello", "world"]) + assert isinstance(result, ArrayStringSegment) + assert result.value == ["hello", "world"] + assert result.value_type == SegmentType.ARRAY_STRING + + # Array of numbers + result = build_segment_with_type(SegmentType.ARRAY_NUMBER, [1, 2, 3.14]) + assert isinstance(result, ArrayNumberSegment) + assert result.value == [1, 2, 3.14] + assert result.value_type == SegmentType.ARRAY_NUMBER + + # Array of objects + result = build_segment_with_type(SegmentType.ARRAY_OBJECT, [{"a": 1}, {"b": 2}]) + assert isinstance(result, ArrayObjectSegment) + assert result.value == [{"a": 1}, {"b": 2}] + assert result.value_type == SegmentType.ARRAY_OBJECT + + def test_type_mismatch_string_to_number(self): + """Test type mismatch when expecting number but getting string.""" + with pytest.raises(TypeMismatchError) as exc_info: + build_segment_with_type(SegmentType.NUMBER, "not_a_number") + + assert "Type mismatch" in str(exc_info.value) + assert "expected number" in str(exc_info.value) + assert "str" in str(exc_info.value) + + def test_type_mismatch_number_to_string(self): + """Test type mismatch when expecting string but getting number.""" + with pytest.raises(TypeMismatchError) as exc_info: + build_segment_with_type(SegmentType.STRING, 123) + + assert "Type mismatch" in str(exc_info.value) + assert "expected string" in str(exc_info.value) + assert "int" in str(exc_info.value) + + def test_type_mismatch_none_to_string(self): + """Test type mismatch when expecting string but getting None.""" + with pytest.raises(TypeMismatchError) as exc_info: + build_segment_with_type(SegmentType.STRING, None) + + assert "Expected string, but got None" in str(exc_info.value) + + def test_type_mismatch_empty_list_to_non_array(self): + """Test type mismatch when expecting non-array type but getting empty list.""" + with pytest.raises(TypeMismatchError) as exc_info: + build_segment_with_type(SegmentType.STRING, []) + + assert "Expected string, but got empty list" in str(exc_info.value) + + def test_type_mismatch_object_to_array(self): + """Test type mismatch when expecting array but getting object.""" + with pytest.raises(TypeMismatchError) as exc_info: + build_segment_with_type(SegmentType.ARRAY_STRING, {"key": "value"}) + + assert "Type mismatch" in str(exc_info.value) + assert "expected array[string]" in str(exc_info.value) + + def test_compatible_number_types(self): + """Test that int and float are both compatible with NUMBER type.""" + # Integer should work + result_int = build_segment_with_type(SegmentType.NUMBER, 42) + assert isinstance(result_int, IntegerSegment) + assert result_int.value_type == SegmentType.NUMBER + + # Float should work + result_float = build_segment_with_type(SegmentType.NUMBER, 3.14) + assert isinstance(result_float, FloatSegment) + assert result_float.value_type == SegmentType.NUMBER + + @pytest.mark.parametrize( + ("segment_type", "value", "expected_class"), + [ + (SegmentType.STRING, "test", StringSegment), + (SegmentType.NUMBER, 42, IntegerSegment), + (SegmentType.NUMBER, 3.14, FloatSegment), + (SegmentType.OBJECT, {}, ObjectSegment), + (SegmentType.NONE, None, NoneSegment), + (SegmentType.ARRAY_STRING, [], ArrayStringSegment), + (SegmentType.ARRAY_NUMBER, [], ArrayNumberSegment), + (SegmentType.ARRAY_OBJECT, [], ArrayObjectSegment), + (SegmentType.ARRAY_ANY, [], ArrayAnySegment), + ], + ) + def 
test_parametrized_valid_types(self, segment_type, value, expected_class): + """Parametrized test for valid type combinations.""" + result = build_segment_with_type(segment_type, value) + assert isinstance(result, expected_class) + assert result.value == value + assert result.value_type == segment_type + + @pytest.mark.parametrize( + ("segment_type", "value"), + [ + (SegmentType.STRING, 123), + (SegmentType.NUMBER, "not_a_number"), + (SegmentType.OBJECT, "not_an_object"), + (SegmentType.ARRAY_STRING, "not_an_array"), + (SegmentType.STRING, None), + (SegmentType.NUMBER, None), + ], + ) + def test_parametrized_type_mismatches(self, segment_type, value): + """Parametrized test for type mismatches that should raise TypeMismatchError.""" + with pytest.raises(TypeMismatchError): + build_segment_with_type(segment_type, value) + + +# Test cases for ValueError scenarios in build_segment function +class TestBuildSegmentValueErrors: + """Test cases for ValueError scenarios in the build_segment function.""" + + @dataclass(frozen=True) + class ValueErrorTestCase: + """Test case data for ValueError scenarios.""" + + name: str + description: str + test_value: Any + + def _get_test_cases(self): + """Get all test cases for ValueError scenarios.""" + + # Define inline classes for complex test cases + class CustomType: + pass + + def unsupported_function(): + return "test" + + def gen(): + yield 1 + yield 2 + + return [ + self.ValueErrorTestCase( + name="unsupported_custom_type", + description="custom class that doesn't match any supported type", + test_value=CustomType(), + ), + self.ValueErrorTestCase( + name="unsupported_set_type", + description="set (unsupported collection type)", + test_value={1, 2, 3}, + ), + self.ValueErrorTestCase( + name="unsupported_tuple_type", description="tuple (unsupported type)", test_value=(1, 2, 3) + ), + self.ValueErrorTestCase( + name="unsupported_bytes_type", + description="bytes (unsupported type)", + test_value=b"hello world", + ), + self.ValueErrorTestCase( + name="unsupported_function_type", + description="function (unsupported type)", + test_value=unsupported_function, + ), + self.ValueErrorTestCase( + name="unsupported_module_type", description="module (unsupported type)", test_value=math + ), + self.ValueErrorTestCase( + name="array_with_unsupported_element_types", + description="array with unsupported element types", + test_value=[CustomType()], + ), + self.ValueErrorTestCase( + name="mixed_array_with_unsupported_types", + description="array with mix of supported and unsupported types", + test_value=["valid_string", 42, CustomType()], + ), + self.ValueErrorTestCase( + name="nested_unsupported_types", + description="nested structures containing unsupported types", + test_value=[{"valid": "data"}, CustomType()], + ), + self.ValueErrorTestCase( + name="complex_number_type", + description="complex number (unsupported type)", + test_value=3 + 4j, + ), + self.ValueErrorTestCase( + name="range_type", description="range object (unsupported type)", test_value=range(10) + ), + self.ValueErrorTestCase( + name="generator_type", + description="generator (unsupported type)", + test_value=gen(), + ), + self.ValueErrorTestCase( + name="exception_message_contains_value", + description="set to verify error message contains the actual unsupported value", + test_value={1, 2, 3}, + ), + self.ValueErrorTestCase( + name="array_with_mixed_unsupported_segment_types", + description="array processing with unsupported segment types in match", + test_value=[CustomType()], + ), + 
self.ValueErrorTestCase(
+                name="frozenset_type",
+                description="frozenset (unsupported type)",
+                test_value=frozenset([1, 2, 3]),
+            ),
+            self.ValueErrorTestCase(
+                name="memoryview_type",
+                description="memoryview (unsupported type)",
+                test_value=memoryview(b"hello"),
+            ),
+            self.ValueErrorTestCase(
+                name="slice_type", description="slice object (unsupported type)", test_value=slice(1, 10, 2)
+            ),
+            self.ValueErrorTestCase(name="type_object", description="type object (unsupported type)", test_value=type),
+            self.ValueErrorTestCase(
+                name="generic_object", description="generic object (unsupported type)", test_value=object()
+            ),
+        ]
+
+    def test_build_segment_unsupported_types(self):
+        """Table-driven test for all ValueError scenarios in build_segment function."""
+        test_cases = self._get_test_cases()
+
+        for index, test_case in enumerate(test_cases, 1):
+            # Use test value directly
+            test_value = test_case.test_value
+
+            with pytest.raises(ValueError) as exc_info:  # noqa: PT012
+                segment = variable_factory.build_segment(test_value)
+                pytest.fail(f"Test case {index} ({test_case.name}) should raise ValueError but did not, result={segment}")
+
+            error_message = str(exc_info.value)
+            assert "not supported value" in error_message, (
+                f"Test case {index} ({test_case.name}): Expected 'not supported value' in error message, "
+                f"but got: {error_message}"
+            )
+
+    def test_build_segment_boolean_type_note(self):
+        """Note: Boolean values are actually handled as integers in Python, so they don't raise ValueError."""
+        # Boolean values in Python are subclasses of int, so they get processed as integers
+        # True becomes IntegerSegment(value=1) and False becomes IntegerSegment(value=0)
+        true_segment = variable_factory.build_segment(True)
+        false_segment = variable_factory.build_segment(False)
+
+        # Verify they are processed as integers, not as errors
+        assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1"
+        assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0"
+        assert true_segment.value_type == SegmentType.NUMBER
+        assert false_segment.value_type == SegmentType.NUMBER
diff --git a/api/tests/unit_tests/libs/test_datetime_utils.py b/api/tests/unit_tests/libs/test_datetime_utils.py
new file mode 100644
index 000000000..e7781a582
--- /dev/null
+++ b/api/tests/unit_tests/libs/test_datetime_utils.py
@@ -0,0 +1,20 @@
+import datetime
+
+from libs.datetime_utils import naive_utc_now
+
+
+def test_naive_utc_now(monkeypatch):
+    tz_aware_utc_now = datetime.datetime.now(tz=datetime.UTC)
+
+    def _now_func(tz: datetime.timezone | None) -> datetime.datetime:
+        return tz_aware_utc_now.astimezone(tz)
+
+    monkeypatch.setattr("libs.datetime_utils._now_func", _now_func)
+
+    naive_datetime = naive_utc_now()
+
+    assert naive_datetime.tzinfo is None
+    assert naive_datetime.date() == tz_aware_utc_now.date()
+    naive_time = naive_datetime.time()
+    utc_time = tz_aware_utc_now.time()
+    assert naive_time == utc_time
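# Editor's note: libs/datetime_utils.py itself is not shown in this excerpt. A minimal
# sketch consistent with the test above, and with the `_now_func` seam it monkeypatches,
# is given below; this is an assumption about the implementation, not the patch's code.
import datetime

_now_func = datetime.datetime.now  # indirection so tests can monkeypatch the clock

def naive_utc_now() -> datetime.datetime:
    # Take the current moment in UTC, then strip tzinfo so callers get a naive datetime.
    return _now_func(datetime.UTC).replace(tzinfo=None)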
diff --git a/api/tests/unit_tests/models/__init__.py b/api/tests/unit_tests/models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py
index b79e95c7e..69163d48b 100644
--- a/api/tests/unit_tests/models/test_workflow.py
+++ b/api/tests/unit_tests/models/test_workflow.py
@@ -1,10 +1,15 @@
+import dataclasses
 import json
 from unittest import mock
 from uuid import uuid4
 
 from constants import HIDDEN_VALUE
+from core.file.enums import FileTransferMethod, FileType
+from core.file.models import File
 from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
-from models.workflow import Workflow, WorkflowNodeExecutionModel
+from core.variables.segments import IntegerSegment, Segment
+from factories.variable_factory import build_segment
+from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
 
 
 def test_environment_variables():
@@ -163,3 +168,147 @@ class TestWorkflowNodeExecution:
         original = {"a": 1, "b": ["2"]}
         node_exec.execution_metadata = json.dumps(original)
         assert node_exec.execution_metadata_dict == original
+
+
+class TestIsSystemVariableEditable:
+    def test_is_system_variable(self):
+        cases = [
+            ("query", True),
+            ("files", True),
+            ("dialogue_count", False),
+            ("conversation_id", False),
+            ("user_id", False),
+            ("app_id", False),
+            ("workflow_id", False),
+            ("workflow_run_id", False),
+        ]
+        for name, editable in cases:
+            assert editable == is_system_variable_editable(name)
+
+        assert not is_system_variable_editable("invalid_or_new_system_variable")
+
+
+class TestWorkflowDraftVariableGetValue:
+    def test_get_value_by_case(self):
+        @dataclasses.dataclass
+        class TestCase:
+            name: str
+            value: Segment
+
+        tenant_id = "test_tenant_id"
+
+        test_file = File(
+            tenant_id=tenant_id,
+            type=FileType.IMAGE,
+            transfer_method=FileTransferMethod.REMOTE_URL,
+            remote_url="https://example.com/example.jpg",
+            filename="example.jpg",
+            extension=".jpg",
+            mime_type="image/jpeg",
+            size=100,
+        )
+        cases: list[TestCase] = [
+            TestCase(
+                name="number/int",
+                value=build_segment(1),
+            ),
+            TestCase(
+                name="number/float",
+                value=build_segment(1.0),
+            ),
+            TestCase(
+                name="string",
+                value=build_segment("a"),
+            ),
+            TestCase(
+                name="object",
+                value=build_segment({}),
+            ),
+            TestCase(
+                name="file",
+                value=build_segment(test_file),
+            ),
+            TestCase(
+                name="array[any]",
+                value=build_segment([1, "a"]),
+            ),
+            TestCase(
+                name="array[string]",
+                value=build_segment(["a", "b"]),
+            ),
+            TestCase(
+                name="array[number]/int",
+                value=build_segment([1, 2]),
+            ),
+            TestCase(
+                name="array[number]/float",
+                value=build_segment([1.0, 2.0]),
+            ),
+            TestCase(
+                name="array[number]/mixed",
+                value=build_segment([1, 2.0]),
+            ),
+            TestCase(
+                name="array[object]",
+                value=build_segment([{}, {"a": 1}]),
+            ),
+            TestCase(
+                name="none",
+                value=build_segment(None),
+            ),
+        ]
+
+        for idx, c in enumerate(cases, 1):
+            fail_msg = f"test case {c.name} failed, index={idx}"
+            draft_var = WorkflowDraftVariable()
+            draft_var.set_value(c.value)
+            assert c.value == draft_var.get_value(), fail_msg
+
+    def test_file_variable_preserves_all_fields(self):
+        """Test that File type variables preserve all fields during encoding/decoding."""
+        tenant_id = "test_tenant_id"
+
+        # Create a File with specific field values
+        test_file = File(
+            id="test_file_id",
+            tenant_id=tenant_id,
+            type=FileType.IMAGE,
+            transfer_method=FileTransferMethod.REMOTE_URL,
+            remote_url="https://example.com/test.jpg",
+            filename="test.jpg",
+            extension=".jpg",
+            mime_type="image/jpeg",
+            size=12345,  # Specific size to test preservation
+            storage_key="test_storage_key",
+        )
+
+        # Create a FileSegment and WorkflowDraftVariable
+        file_segment = build_segment(test_file)
+        draft_var = WorkflowDraftVariable()
+        draft_var.set_value(file_segment)
+
+        # Retrieve the value and verify all fields are preserved
+        retrieved_segment = draft_var.get_value()
+        retrieved_file =
retrieved_segment.value
+
+        # Verify all important fields are preserved
+        assert retrieved_file.id == test_file.id
+        assert retrieved_file.tenant_id == test_file.tenant_id
+        assert retrieved_file.type == test_file.type
+        assert retrieved_file.transfer_method == test_file.transfer_method
+        assert retrieved_file.remote_url == test_file.remote_url
+        assert retrieved_file.filename == test_file.filename
+        assert retrieved_file.extension == test_file.extension
+        assert retrieved_file.mime_type == test_file.mime_type
+        assert retrieved_file.size == test_file.size  # This was the main issue being fixed
+        # Note: storage_key is not serialized in model_dump() so it won't be preserved
+
+        # Verify the segments have the same type and the important fields match
+        assert file_segment.value_type == retrieved_segment.value_type
+
+    def test_get_and_set_value(self):
+        draft_var = WorkflowDraftVariable()
+        int_var = IntegerSegment(value=1)
+        draft_var.set_value(int_var)
+        value = draft_var.get_value()
+        assert value == int_var
diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py
new file mode 100644
index 000000000..8ae69c8d6
--- /dev/null
+++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py
@@ -0,0 +1,222 @@
+import dataclasses
+import secrets
+from unittest import mock
+from unittest.mock import Mock, patch
+
+import pytest
+from sqlalchemy.orm import Session
+
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.variables.types import SegmentType
+from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
+from core.workflow.nodes import NodeType
+from models.enums import DraftVariableType
+from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel
+from services.workflow_draft_variable_service import (
+    DraftVariableSaver,
+    VariableResetError,
+    WorkflowDraftVariableService,
+)
+
+
+class TestDraftVariableSaver:
+    def _get_test_app_id(self):
+        suffix = secrets.token_hex(6)
+        return f"test_app_id_{suffix}"
+
+    def test__should_variable_be_visible(self):
+        mock_session = mock.MagicMock(spec=Session)
+        test_app_id = self._get_test_app_id()
+        saver = DraftVariableSaver(
+            session=mock_session,
+            app_id=test_app_id,
+            node_id="test_node_id",
+            node_type=NodeType.START,
+            invoke_from=InvokeFrom.DEBUGGER,
+            node_execution_id="test_execution_id",
+        )
+        assert not saver._should_variable_be_visible("123_456", NodeType.IF_ELSE, "output")
+        assert saver._should_variable_be_visible("123", NodeType.START, "output")
+
+    def test__normalize_variable_for_start_node(self):
+        @dataclasses.dataclass(frozen=True)
+        class TestCase:
+            name: str
+            input_node_id: str
+            input_name: str
+            expected_node_id: str
+            expected_name: str
+
+        _NODE_ID = "1747228642872"
+        cases = [
+            TestCase(
+                name="name with `sys.` prefix should return the system node_id",
+                input_node_id=_NODE_ID,
+                input_name="sys.workflow_id",
+                expected_node_id=SYSTEM_VARIABLE_NODE_ID,
+                expected_name="workflow_id",
+            ),
+            TestCase(
+                name="name without `sys.` prefix should return the original input node_id",
+                input_node_id=_NODE_ID,
+                input_name="start_input",
+                expected_node_id=_NODE_ID,
+                expected_name="start_input",
+            ),
+            TestCase(
+                name="dummy_variable should return the original input node_id",
+                input_node_id=_NODE_ID,
+                input_name="__dummy__",
+                expected_node_id=_NODE_ID,
+                expected_name="__dummy__",
+            ),
+        ]
+
+        mock_session =
mock.MagicMock(spec=Session) + test_app_id = self._get_test_app_id() + saver = DraftVariableSaver( + session=mock_session, + app_id=test_app_id, + node_id=_NODE_ID, + node_type=NodeType.START, + invoke_from=InvokeFrom.DEBUGGER, + node_execution_id="test_execution_id", + ) + for idx, c in enumerate(cases, 1): + fail_msg = f"Test case {c.name} failed, index={idx}" + node_id, name = saver._normalize_variable_for_start_node(c.input_name) + assert node_id == c.expected_node_id, fail_msg + assert name == c.expected_name, fail_msg + + +class TestWorkflowDraftVariableService: + def _get_test_app_id(self): + suffix = secrets.token_hex(6) + return f"test_app_id_{suffix}" + + def test_reset_conversation_variable(self): + """Test resetting a conversation variable""" + mock_session = Mock(spec=Session) + service = WorkflowDraftVariableService(mock_session) + mock_workflow = Mock(spec=Workflow) + mock_workflow.app_id = self._get_test_app_id() + + # Create mock variable + mock_variable = Mock(spec=WorkflowDraftVariable) + mock_variable.get_variable_type.return_value = DraftVariableType.CONVERSATION + mock_variable.id = "var-id" + mock_variable.name = "test_var" + + # Mock the _reset_conv_var method + expected_result = Mock(spec=WorkflowDraftVariable) + with patch.object(service, "_reset_conv_var", return_value=expected_result) as mock_reset_conv: + result = service.reset_variable(mock_workflow, mock_variable) + + mock_reset_conv.assert_called_once_with(mock_workflow, mock_variable) + assert result == expected_result + + def test_reset_node_variable_with_no_execution_id(self): + """Test resetting a node variable with no execution ID - should delete variable""" + mock_session = Mock(spec=Session) + service = WorkflowDraftVariableService(mock_session) + mock_workflow = Mock(spec=Workflow) + mock_workflow.app_id = self._get_test_app_id() + + # Create mock variable with no execution ID + mock_variable = Mock(spec=WorkflowDraftVariable) + mock_variable.get_variable_type.return_value = DraftVariableType.NODE + mock_variable.node_execution_id = None + mock_variable.id = "var-id" + mock_variable.name = "test_var" + + result = service._reset_node_var(mock_workflow, mock_variable) + + # Should delete the variable and return None + mock_session.delete.assert_called_once_with(instance=mock_variable) + mock_session.flush.assert_called_once() + assert result is None + + def test_reset_node_variable_with_missing_execution_record(self): + """Test resetting a node variable when execution record doesn't exist""" + mock_session = Mock(spec=Session) + service = WorkflowDraftVariableService(mock_session) + mock_workflow = Mock(spec=Workflow) + mock_workflow.app_id = self._get_test_app_id() + + # Create mock variable with execution ID + mock_variable = Mock(spec=WorkflowDraftVariable) + mock_variable.get_variable_type.return_value = DraftVariableType.NODE + mock_variable.node_execution_id = "exec-id" + mock_variable.id = "var-id" + mock_variable.name = "test_var" + + # Mock session.scalars to return None (no execution record found) + mock_scalars = Mock() + mock_scalars.first.return_value = None + mock_session.scalars.return_value = mock_scalars + + result = service._reset_node_var(mock_workflow, mock_variable) + + # Should delete the variable and return None + mock_session.delete.assert_called_once_with(instance=mock_variable) + mock_session.flush.assert_called_once() + assert result is None + + def test_reset_node_variable_with_valid_execution_record(self): + """Test resetting a node variable with valid execution record - 
should restore from execution""" + mock_session = Mock(spec=Session) + service = WorkflowDraftVariableService(mock_session) + mock_workflow = Mock(spec=Workflow) + mock_workflow.app_id = self._get_test_app_id() + + # Create mock variable with execution ID + mock_variable = Mock(spec=WorkflowDraftVariable) + mock_variable.get_variable_type.return_value = DraftVariableType.NODE + mock_variable.node_execution_id = "exec-id" + mock_variable.id = "var-id" + mock_variable.name = "test_var" + mock_variable.node_id = "node-id" + mock_variable.value_type = SegmentType.STRING + + # Create mock execution record + mock_execution = Mock(spec=WorkflowNodeExecutionModel) + mock_execution.process_data_dict = {"test_var": "process_value"} + mock_execution.outputs_dict = {"test_var": "output_value"} + + # Mock session.scalars to return the execution record + mock_scalars = Mock() + mock_scalars.first.return_value = mock_execution + mock_session.scalars.return_value = mock_scalars + + # Mock workflow methods + mock_node_config = {"type": "test_node"} + mock_workflow.get_node_config_by_id.return_value = mock_node_config + mock_workflow.get_node_type_from_node_config.return_value = NodeType.LLM + + result = service._reset_node_var(mock_workflow, mock_variable) + + # Verify variable.set_value was called with the correct value + mock_variable.set_value.assert_called_once() + # Verify last_edited_at was reset + assert mock_variable.last_edited_at is None + # Verify session.flush was called + mock_session.flush.assert_called() + + # Should return the updated variable + assert result == mock_variable + + def test_reset_system_variable_raises_error(self): + """Test that resetting a system variable raises an error""" + mock_session = Mock(spec=Session) + service = WorkflowDraftVariableService(mock_session) + mock_workflow = Mock(spec=Workflow) + mock_workflow.app_id = self._get_test_app_id() + + mock_variable = Mock(spec=WorkflowDraftVariable) + mock_variable.get_variable_type.return_value = DraftVariableType.SYS # Not a valid enum value for this test + mock_variable.id = "var-id" + + with pytest.raises(VariableResetError) as exc_info: + service.reset_variable(mock_workflow, mock_variable) + assert "cannot reset system variable" in str(exc_info.value) + assert "variable_id=var-id" in str(exc_info.value) diff --git a/api/uv.lock b/api/uv.lock index a03929510..66bfdcef3 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1284,6 +1284,7 @@ dev = [ { name = "coverage" }, { name = "dotenv-linter" }, { name = "faker" }, + { name = "hypothesis" }, { name = "lxml-stubs" }, { name = "mypy" }, { name = "pandas-stubs" }, @@ -1461,6 +1462,7 @@ dev = [ { name = "coverage", specifier = "~=7.2.4" }, { name = "dotenv-linter", specifier = "~=0.5.0" }, { name = "faker", specifier = "~=32.1.0" }, + { name = "hypothesis", specifier = ">=6.131.15" }, { name = "lxml-stubs", specifier = "~=0.5.1" }, { name = "mypy", specifier = "~=1.16.0" }, { name = "pandas-stubs", specifier = "~=2.2.3" }, @@ -2556,6 +2558,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/30/47d0bf6072f7252e6521f3447ccfa40b421b6824517f82854703d0f5a98b/hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5", size = 13007 }, ] +[[package]] +name = "hypothesis" +version = "6.131.15" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "sortedcontainers" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/f1/6f/1e291f80627f3e043b19a86f9f6b172b910e3575577917d3122a6558410d/hypothesis-6.131.15.tar.gz", hash = "sha256:11849998ae5eecc8c586c6c98e47677fcc02d97475065f62768cfffbcc15ef7a", size = 436596, upload_time = "2025-05-07T23:04:25.127Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b6/c7/78597bcec48e1585ea9029deb2bf2341516e90dd615a3db498413d68a4cc/hypothesis-6.131.15-py3-none-any.whl", hash = "sha256:e02e67e9f3cfd4cd4a67ccc03bf7431beccc1a084c5e90029799ddd36ce006d7", size = 501128, upload_time = "2025-05-07T23:04:22.045Z" }, +] + [[package]] name = "idna" version = "3.10" @@ -5241,6 +5256,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/37/c3/6eeb6034408dac0fa653d126c9204ade96b819c936e136c5e8a6897eee9c/socksio-1.0.0-py3-none-any.whl", hash = "sha256:95dc1f15f9b34e8d7b16f06d74b8ccf48f609af32ab33c608d08761c5dcbb1f3", size = 12763 }, ] +[[package]] +name = "sortedcontainers" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/c4/ba2f8066cceb6f23394729afe52f3bf7adec04bf9ed2c820b39e19299111/sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88", size = 30594, upload_time = "2021-05-16T22:03:42.897Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575, upload_time = "2021-05-16T22:03:41.177Z" }, +] + [[package]] name = "soupsieve" version = "2.7" diff --git a/web/app/components/base/markdown-blocks/code-block.tsx b/web/app/components/base/markdown-blocks/code-block.tsx index 87dbd834d..c88cfde9e 100644 --- a/web/app/components/base/markdown-blocks/code-block.tsx +++ b/web/app/components/base/markdown-blocks/code-block.tsx @@ -63,7 +63,7 @@ const getCorrectCapitalizationLanguageName = (language: string) => { // or use the non-minified dev environment for full errors and additional helpful warnings. // Define ECharts event parameter types -interface EChartsEventParams { +type EChartsEventParams = { type: string; seriesIndex?: number; dataIndex?: number;