From 935e72d4495530d7ee8725718c19655591582cb1 Mon Sep 17 00:00:00 2001 From: KVOJJJin Date: Tue, 13 Aug 2024 14:44:10 +0800 Subject: [PATCH] Feat: conversation variable & variable assigner node (#7222) Signed-off-by: -LAN- Co-authored-by: Joel Co-authored-by: -LAN- --- api/configs/app_config.py | 8 +- api/controllers/console/__init__.py | 1 + .../console/app/conversation_variables.py | 61 +++ api/controllers/console/app/workflow.py | 7 +- api/core/app/app_config/entities.py | 8 +- .../features/file_upload/manager.py | 2 +- .../app/apps/advanced_chat/app_generator.py | 6 +- api/core/app/apps/advanced_chat/app_runner.py | 190 +++++---- api/core/app/apps/workflow/app_runner.py | 78 ++-- api/core/app/segments/__init__.py | 2 + api/core/app/segments/exc.py | 2 + api/core/app/segments/factory.py | 53 +-- api/core/app/segments/segments.py | 16 +- api/core/file/file_obj.py | 12 +- api/core/file/message_file_parser.py | 3 +- api/core/helper/encrypter.py | 2 +- api/core/workflow/entities/node_entities.py | 2 + api/core/workflow/entities/variable_pool.py | 8 +- api/core/workflow/nodes/base_node.py | 20 +- .../nodes/variable_assigner/__init__.py | 109 +++++ api/core/workflow/workflow_engine_manager.py | 22 +- api/fields/conversation_variable_fields.py | 21 + api/fields/workflow_fields.py | 6 +- ...3fcf12ba_support_conversation_variables.py | 51 +++ api/models/__init__.py | 58 +-- api/models/account.py | 3 +- api/models/api_based_extension.py | 3 +- api/models/dataset.py | 7 +- api/models/model.py | 2 +- api/models/provider.py | 3 +- api/models/source.py | 3 +- api/models/tool.py | 3 +- api/models/tools.py | 5 +- api/models/types.py | 26 ++ api/models/web.py | 5 +- api/models/workflow.py | 64 ++- api/services/app_dsl_service.py | 3 + api/services/workflow/workflow_converter.py | 2 +- api/services/workflow_service.py | 12 +- api/tasks/remove_app_and_related_data_task.py | 14 +- .../core/app/segments/test_factory.py | 195 +++------ .../prompt/test_advanced_prompt_transform.py | 4 +- .../workflow/nodes/test_variable_assigner.py | 150 +++++++ .../models/test_conversation_variable.py | 25 ++ web/app/components/base/badge.tsx | 5 +- .../assets/vender/line/others/bubble-x.svg | 8 + .../vender/line/others/long-arrow-left.svg | 3 + .../vender/line/others/long-arrow-right.svg | 3 + .../icons/assets/vender/workflow/assigner.svg | 9 + .../icons/src/vender/line/others/BubbleX.json | 57 +++ .../icons/src/vender/line/others/BubbleX.tsx | 16 + .../src/vender/line/others/LongArrowLeft.json | 27 ++ .../src/vender/line/others/LongArrowLeft.tsx | 16 + .../vender/line/others/LongArrowRight.json | 27 ++ .../src/vender/line/others/LongArrowRight.tsx | 16 + .../icons/src/vender/line/others/index.ts | 3 + .../icons/src/vender/workflow/Assigner.json | 68 +++ .../icons/src/vender/workflow/Assigner.tsx | 16 + .../base/icons/src/vender/workflow/index.ts | 1 + web/app/components/base/input/index.tsx | 6 +- .../components/base/input/style.module.css | 7 - .../components/base/prompt-editor/index.tsx | 2 +- .../workflow-variable-block/component.tsx | 18 +- web/app/components/workflow/block-icon.tsx | 3 + .../workflow/block-selector/constants.tsx | 5 + web/app/components/workflow/constants.ts | 16 + .../workflow/header/chat-variable-button.tsx | 24 ++ .../components/workflow/header/env-button.tsx | 10 +- web/app/components/workflow/header/index.tsx | 10 +- .../workflow/hooks/use-nodes-sync-draft.ts | 2 + .../hooks/use-workflow-interactions.ts | 3 + .../workflow/hooks/use-workflow-start-run.tsx | 2 + .../workflow/hooks/use-workflow-variables.ts | 9 +- .../components/workflow/hooks/use-workflow.ts | 3 + .../add-variable-popup-with-position.tsx | 1 + .../components/before-run-form/form-item.tsx | 35 +- .../components/editor/code-editor/index.tsx | 4 +- .../components/input-support-select-var.tsx | 6 +- .../nodes/_base/components/option-card.tsx | 2 +- .../readonly-input-with-select-var.tsx | 12 +- .../nodes/_base/components/selector.tsx | 9 +- .../nodes/_base/components/variable-tag.tsx | 10 +- .../components/variable/constant-field.tsx | 6 +- .../nodes/_base/components/variable/utils.ts | 51 ++- .../variable/var-reference-picker.tsx | 116 ++++-- .../variable/var-reference-vars.tsx | 23 +- .../nodes/_base/hooks/use-node-help-link.ts | 2 + .../nodes/_base/hooks/use-one-step-run.ts | 11 +- .../workflow/nodes/assigner/default.ts | 46 +++ .../workflow/nodes/assigner/node.tsx | 47 +++ .../workflow/nodes/assigner/panel.tsx | 87 ++++ .../workflow/nodes/assigner/types.ts | 13 + .../workflow/nodes/assigner/use-config.ts | 144 +++++++ .../workflow/nodes/assigner/utils.ts | 5 + .../components/workflow/nodes/constants.ts | 4 + .../components/workflow/nodes/end/node.tsx | 13 +- .../key-value/key-value-edit/index.tsx | 9 + .../key-value/key-value-edit/input-item.tsx | 4 + .../key-value/key-value-edit/item.tsx | 35 +- .../if-else/components/condition-value.tsx | 13 +- .../components/node-group-item.tsx | 5 +- .../components/node-variable-item.tsx | 14 +- .../workflow/nodes/variable-assigner/hooks.ts | 1 + .../components/array-value-list.tsx | 72 ++++ .../components/object-value-item.tsx | 135 ++++++ .../components/object-value-list.tsx | 36 ++ .../components/variable-item.tsx | 49 +++ .../components/variable-modal-trigger.tsx | 69 ++++ .../components/variable-modal.tsx | 388 ++++++++++++++++++ .../components/variable-type-select.tsx | 66 +++ .../panel/chat-variable-panel/index.tsx | 202 +++++++++ .../panel/chat-variable-panel/type.ts | 8 + .../panel/debug-and-preview/chat-wrapper.tsx | 69 ++-- .../conversation-variable-modal.tsx | 155 +++++++ .../panel/debug-and-preview/index.tsx | 68 ++- .../panel/debug-and-preview/user-input.tsx | 61 +-- .../workflow/panel/env-panel/env-item.tsx | 53 +++ .../workflow/panel/env-panel/index.tsx | 43 +- .../panel/env-panel/variable-modal.tsx | 10 +- .../panel/env-panel/variable-trigger.tsx | 1 - web/app/components/workflow/panel/index.tsx | 7 + web/app/components/workflow/store.ts | 27 ++ web/app/components/workflow/types.ts | 13 + web/app/styles/globals.css | 6 + web/i18n/en-US/workflow.ts | 41 ++ web/i18n/zh-Hans/workflow.ts | 41 ++ web/service/workflow.ts | 7 +- web/types/workflow.ts | 10 + 128 files changed, 3354 insertions(+), 683 deletions(-) create mode 100644 api/controllers/console/app/conversation_variables.py create mode 100644 api/core/app/segments/exc.py create mode 100644 api/core/workflow/nodes/variable_assigner/__init__.py create mode 100644 api/fields/conversation_variable_fields.py create mode 100644 api/migrations/versions/2024_08_13_0633-63a83fcf12ba_support_conversation_variables.py create mode 100644 api/models/types.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py create mode 100644 api/tests/unit_tests/models/test_conversation_variable.py create mode 100644 web/app/components/base/icons/assets/vender/line/others/bubble-x.svg create mode 100644 web/app/components/base/icons/assets/vender/line/others/long-arrow-left.svg create mode 100644 web/app/components/base/icons/assets/vender/line/others/long-arrow-right.svg create mode 100644 web/app/components/base/icons/assets/vender/workflow/assigner.svg create mode 100644 web/app/components/base/icons/src/vender/line/others/BubbleX.json create mode 100644 web/app/components/base/icons/src/vender/line/others/BubbleX.tsx create mode 100644 web/app/components/base/icons/src/vender/line/others/LongArrowLeft.json create mode 100644 web/app/components/base/icons/src/vender/line/others/LongArrowLeft.tsx create mode 100644 web/app/components/base/icons/src/vender/line/others/LongArrowRight.json create mode 100644 web/app/components/base/icons/src/vender/line/others/LongArrowRight.tsx create mode 100644 web/app/components/base/icons/src/vender/workflow/Assigner.json create mode 100644 web/app/components/base/icons/src/vender/workflow/Assigner.tsx delete mode 100644 web/app/components/base/input/style.module.css create mode 100644 web/app/components/workflow/header/chat-variable-button.tsx create mode 100644 web/app/components/workflow/nodes/assigner/default.ts create mode 100644 web/app/components/workflow/nodes/assigner/node.tsx create mode 100644 web/app/components/workflow/nodes/assigner/panel.tsx create mode 100644 web/app/components/workflow/nodes/assigner/types.ts create mode 100644 web/app/components/workflow/nodes/assigner/use-config.ts create mode 100644 web/app/components/workflow/nodes/assigner/utils.ts create mode 100644 web/app/components/workflow/panel/chat-variable-panel/components/array-value-list.tsx create mode 100644 web/app/components/workflow/panel/chat-variable-panel/components/object-value-item.tsx create mode 100644 web/app/components/workflow/panel/chat-variable-panel/components/object-value-list.tsx create mode 100644 web/app/components/workflow/panel/chat-variable-panel/components/variable-item.tsx create mode 100644 web/app/components/workflow/panel/chat-variable-panel/components/variable-modal-trigger.tsx create mode 100644 web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx create mode 100644 web/app/components/workflow/panel/chat-variable-panel/components/variable-type-select.tsx create mode 100644 web/app/components/workflow/panel/chat-variable-panel/index.tsx create mode 100644 web/app/components/workflow/panel/chat-variable-panel/type.ts create mode 100644 web/app/components/workflow/panel/debug-and-preview/conversation-variable-modal.tsx create mode 100644 web/app/components/workflow/panel/env-panel/env-item.tsx diff --git a/api/configs/app_config.py b/api/configs/app_config.py index a5a4fc788..b277760ed 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -12,19 +12,14 @@ from configs.packaging import PackagingInfo class DifyConfig( # Packaging info PackagingInfo, - # Deployment configs DeploymentConfig, - # Feature configs FeatureConfig, - # Middleware configs MiddlewareConfig, - # Extra service configs ExtraServiceConfig, - # Enterprise feature configs # **Before using, please contact business@dify.ai by email to inquire about licensing matters.** EnterpriseFeatureConfig, @@ -36,7 +31,6 @@ class DifyConfig( env_file='.env', env_file_encoding='utf-8', frozen=True, - # ignore extra attributes extra='ignore', ) @@ -67,3 +61,5 @@ class DifyConfig( SSRF_PROXY_HTTPS_URL: str | None = None MODERATION_BUFFER_SIZE: int = Field(default=300, description='The buffer size for moderation.') + + MAX_VARIABLE_SIZE: int = Field(default=5 * 1024, description='The maximum size of a variable. default is 5KB.') diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index bef40bea7..b2b9d8d49 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -17,6 +17,7 @@ from .app import ( audio, completion, conversation, + conversation_variables, generator, message, model_config, diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py new file mode 100644 index 000000000..aa0722ea3 --- /dev/null +++ b/api/controllers/console/app/conversation_variables.py @@ -0,0 +1,61 @@ +from flask_restful import Resource, marshal_with, reqparse +from sqlalchemy import select +from sqlalchemy.orm import Session + +from controllers.console import api +from controllers.console.app.wraps import get_app_model +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from extensions.ext_database import db +from fields.conversation_variable_fields import paginated_conversation_variable_fields +from libs.login import login_required +from models import ConversationVariable +from models.model import AppMode + + +class ConversationVariablesApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=AppMode.ADVANCED_CHAT) + @marshal_with(paginated_conversation_variable_fields) + def get(self, app_model): + parser = reqparse.RequestParser() + parser.add_argument('conversation_id', type=str, location='args') + args = parser.parse_args() + + stmt = ( + select(ConversationVariable) + .where(ConversationVariable.app_id == app_model.id) + .order_by(ConversationVariable.created_at) + ) + if args['conversation_id']: + stmt = stmt.where(ConversationVariable.conversation_id == args['conversation_id']) + else: + raise ValueError('conversation_id is required') + + # NOTE: This is a temporary solution to avoid performance issues. + page = 1 + page_size = 100 + stmt = stmt.limit(page_size).offset((page - 1) * page_size) + + with Session(db.engine) as session: + rows = session.scalars(stmt).all() + + return { + 'page': page, + 'limit': page_size, + 'total': len(rows), + 'has_more': False, + 'data': [ + { + 'created_at': row.created_at, + 'updated_at': row.updated_at, + **row.to_variable().model_dump(), + } + for row in rows + ], + } + + +api.add_resource(ConversationVariablesApi, '/apps//conversation-variables') diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 686ef7b4b..6eb97b6c8 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -74,6 +74,7 @@ class DraftWorkflowApi(Resource): parser.add_argument('hash', type=str, required=False, location='json') # TODO: set this to required=True after frontend is updated parser.add_argument('environment_variables', type=list, required=False, location='json') + parser.add_argument('conversation_variables', type=list, required=False, location='json') args = parser.parse_args() elif 'text/plain' in content_type: try: @@ -88,7 +89,8 @@ class DraftWorkflowApi(Resource): 'graph': data.get('graph'), 'features': data.get('features'), 'hash': data.get('hash'), - 'environment_variables': data.get('environment_variables') + 'environment_variables': data.get('environment_variables'), + 'conversation_variables': data.get('conversation_variables'), } except json.JSONDecodeError: return {'message': 'Invalid JSON data'}, 400 @@ -100,6 +102,8 @@ class DraftWorkflowApi(Resource): try: environment_variables_list = args.get('environment_variables') or [] environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] + conversation_variables_list = args.get('conversation_variables') or [] + conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] workflow = workflow_service.sync_draft_workflow( app_model=app_model, graph=args['graph'], @@ -107,6 +111,7 @@ class DraftWorkflowApi(Resource): unique_hash=args.get('hash'), account=current_user, environment_variables=environment_variables, + conversation_variables=conversation_variables, ) except WorkflowHashNotEqualError: raise DraftWorkflowNotSync() diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index a490ddd67..05a42a898 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -3,8 +3,9 @@ from typing import Any, Optional from pydantic import BaseModel +from core.file.file_obj import FileExtraConfig from core.model_runtime.entities.message_entities import PromptMessageRole -from models.model import AppMode +from models import AppMode class ModelConfigEntity(BaseModel): @@ -200,11 +201,6 @@ class TracingConfigEntity(BaseModel): tracing_provider: str -class FileExtraConfig(BaseModel): - """ - File Upload Entity. - """ - image_config: Optional[dict[str, Any]] = None class AppAdditionalFeatures(BaseModel): diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 86799fb1a..3da3c2edd 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -1,7 +1,7 @@ from collections.abc import Mapping from typing import Any, Optional -from core.app.app_config.entities import FileExtraConfig +from core.file.file_obj import FileExtraConfig class FileUploadConfigManager: diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index e854ea18b..0cde65999 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -113,7 +113,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) return self._generate( - app_model=app_model, workflow=workflow, user=user, invoke_from=invoke_from, @@ -180,7 +179,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) return self._generate( - app_model=app_model, workflow=workflow, user=user, invoke_from=InvokeFrom.DEBUGGER, @@ -189,12 +187,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): stream=stream ) - def _generate(self, app_model: App, + def _generate(self, *, workflow: Workflow, user: Union[Account, EndUser], invoke_from: InvokeFrom, application_generate_entity: AdvancedChatAppGenerateEntity, - conversation: Conversation = None, + conversation: Conversation | None = None, stream: bool = True) \ -> Union[dict, Generator[dict, None, None]]: is_first_conversation = False diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 18db0ab22..47c53531f 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -4,6 +4,9 @@ import time from collections.abc import Mapping from typing import Any, Optional, cast +from sqlalchemy import select +from sqlalchemy.orm import Session + from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -17,11 +20,12 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueSto from core.moderation.base import ModerationException from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.entities.node_entities import SystemVariable +from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.model import App, Conversation, EndUser, Message -from models.workflow import Workflow +from models.workflow import ConversationVariable, Workflow logger = logging.getLogger(__name__) @@ -31,10 +35,13 @@ class AdvancedChatAppRunner(AppRunner): AdvancedChat Application Runner """ - def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message) -> None: + def run( + self, + application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + ) -> None: """ Run application :param application_generate_entity: application generate entity @@ -48,11 +55,11 @@ class AdvancedChatAppRunner(AppRunner): app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: - raise ValueError("App not found") + raise ValueError('App not found') workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) if not workflow: - raise ValueError("Workflow not initialized") + raise ValueError('Workflow not initialized') inputs = application_generate_entity.inputs query = application_generate_entity.query @@ -68,35 +75,66 @@ class AdvancedChatAppRunner(AppRunner): # moderation if self.handle_input_moderation( - queue_manager=queue_manager, - app_record=app_record, - app_generate_entity=application_generate_entity, - inputs=inputs, - query=query, - message_id=message.id + queue_manager=queue_manager, + app_record=app_record, + app_generate_entity=application_generate_entity, + inputs=inputs, + query=query, + message_id=message.id, ): return # annotation reply if self.handle_annotation_reply( - app_record=app_record, - message=message, - query=query, - queue_manager=queue_manager, - app_generate_entity=application_generate_entity + app_record=app_record, + message=message, + query=query, + queue_manager=queue_manager, + app_generate_entity=application_generate_entity, ): return db.session.close() - workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback( - queue_manager=queue_manager, - workflow=workflow - )] + workflow_callbacks: list[WorkflowCallback] = [ + WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow) + ] - if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): + if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): workflow_callbacks.append(WorkflowLoggingCallback()) + # Init conversation variables + stmt = select(ConversationVariable).where( + ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id + ) + with Session(db.engine) as session: + conversation_variables = session.scalars(stmt).all() + if not conversation_variables: + conversation_variables = [ + ConversationVariable.from_variable( + app_id=conversation.app_id, conversation_id=conversation.id, variable=variable + ) + for variable in workflow.conversation_variables + ] + session.add_all(conversation_variables) + session.commit() + # Convert database entities to variables + conversation_variables = [item.to_variable() for item in conversation_variables] + + # Create a variable pool. + system_inputs = { + SystemVariable.QUERY: query, + SystemVariable.FILES: files, + SystemVariable.CONVERSATION_ID: conversation.id, + SystemVariable.USER_ID: user_id, + } + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=inputs, + environment_variables=workflow.environment_variables, + conversation_variables=conversation_variables, + ) + # RUN WORKFLOW workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( @@ -106,43 +144,30 @@ class AdvancedChatAppRunner(AppRunner): if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else UserFrom.END_USER, invoke_from=application_generate_entity.invoke_from, - user_inputs=inputs, - system_inputs={ - SystemVariable.QUERY: query, - SystemVariable.FILES: files, - SystemVariable.CONVERSATION_ID: conversation.id, - SystemVariable.USER_ID: user_id - }, callbacks=workflow_callbacks, - call_depth=application_generate_entity.call_depth + call_depth=application_generate_entity.call_depth, + variable_pool=variable_pool, ) - def single_iteration_run(self, app_id: str, workflow_id: str, - queue_manager: AppQueueManager, - inputs: dict, node_id: str, user_id: str) -> None: + def single_iteration_run( + self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str + ) -> None: """ Single iteration run """ app_record: App = db.session.query(App).filter(App.id == app_id).first() if not app_record: - raise ValueError("App not found") - + raise ValueError('App not found') + workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id) if not workflow: - raise ValueError("Workflow not initialized") - - workflow_callbacks = [WorkflowEventTriggerCallback( - queue_manager=queue_manager, - workflow=workflow - )] + raise ValueError('Workflow not initialized') + + workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)] workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.single_step_run_iteration_workflow_node( - workflow=workflow, - node_id=node_id, - user_id=user_id, - user_inputs=inputs, - callbacks=workflow_callbacks + workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks ) def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: @@ -150,22 +175,25 @@ class AdvancedChatAppRunner(AppRunner): Get workflow """ # fetch workflow by workflow_id - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.id == workflow_id - ).first() + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id + ) + .first() + ) # return workflow return workflow def handle_input_moderation( - self, queue_manager: AppQueueManager, - app_record: App, - app_generate_entity: AdvancedChatAppGenerateEntity, - inputs: Mapping[str, Any], - query: str, - message_id: str + self, + queue_manager: AppQueueManager, + app_record: App, + app_generate_entity: AdvancedChatAppGenerateEntity, + inputs: Mapping[str, Any], + query: str, + message_id: str, ) -> bool: """ Handle input moderation @@ -192,17 +220,20 @@ class AdvancedChatAppRunner(AppRunner): queue_manager=queue_manager, text=str(e), stream=app_generate_entity.stream, - stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION + stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION, ) return True return False - def handle_annotation_reply(self, app_record: App, - message: Message, - query: str, - queue_manager: AppQueueManager, - app_generate_entity: AdvancedChatAppGenerateEntity) -> bool: + def handle_annotation_reply( + self, + app_record: App, + message: Message, + query: str, + queue_manager: AppQueueManager, + app_generate_entity: AdvancedChatAppGenerateEntity, + ) -> bool: """ Handle annotation reply :param app_record: app record @@ -217,29 +248,27 @@ class AdvancedChatAppRunner(AppRunner): message=message, query=query, user_id=app_generate_entity.user_id, - invoke_from=app_generate_entity.invoke_from + invoke_from=app_generate_entity.invoke_from, ) if annotation_reply: queue_manager.publish( - QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), - PublishFrom.APPLICATION_MANAGER + QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER ) self._stream_output( queue_manager=queue_manager, text=annotation_reply.content, stream=app_generate_entity.stream, - stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY + stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY, ) return True return False - def _stream_output(self, queue_manager: AppQueueManager, - text: str, - stream: bool, - stopped_by: QueueStopEvent.StopBy) -> None: + def _stream_output( + self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy + ) -> None: """ Direct output :param queue_manager: application queue manager @@ -250,21 +279,10 @@ class AdvancedChatAppRunner(AppRunner): if stream: index = 0 for token in text: - queue_manager.publish( - QueueTextChunkEvent( - text=token - ), PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER) index += 1 time.sleep(0.01) else: - queue_manager.publish( - QueueTextChunkEvent( - text=text - ), PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER) - queue_manager.publish( - QueueStopEvent(stopped_by=stopped_by), - PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER) diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 24f4a8321..17a99cf1c 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -12,6 +12,7 @@ from core.app.entities.app_invoke_entities import ( ) from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.entities.node_entities import SystemVariable +from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db @@ -26,8 +27,7 @@ class WorkflowAppRunner: Workflow Application Runner """ - def run(self, application_generate_entity: WorkflowAppGenerateEntity, - queue_manager: AppQueueManager) -> None: + def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None: """ Run application :param application_generate_entity: application generate entity @@ -47,25 +47,36 @@ class WorkflowAppRunner: app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: - raise ValueError("App not found") + raise ValueError('App not found') workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) if not workflow: - raise ValueError("Workflow not initialized") + raise ValueError('Workflow not initialized') inputs = application_generate_entity.inputs files = application_generate_entity.files db.session.close() - workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback( - queue_manager=queue_manager, - workflow=workflow - )] + workflow_callbacks: list[WorkflowCallback] = [ + WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow) + ] - if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): + if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): workflow_callbacks.append(WorkflowLoggingCallback()) + # Create a variable pool. + system_inputs = { + SystemVariable.FILES: files, + SystemVariable.USER_ID: user_id, + } + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=inputs, + environment_variables=workflow.environment_variables, + conversation_variables=[], + ) + # RUN WORKFLOW workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( @@ -75,44 +86,33 @@ class WorkflowAppRunner: if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else UserFrom.END_USER, invoke_from=application_generate_entity.invoke_from, - user_inputs=inputs, - system_inputs={ - SystemVariable.FILES: files, - SystemVariable.USER_ID: user_id - }, callbacks=workflow_callbacks, - call_depth=application_generate_entity.call_depth + call_depth=application_generate_entity.call_depth, + variable_pool=variable_pool, ) - def single_iteration_run(self, app_id: str, workflow_id: str, - queue_manager: AppQueueManager, - inputs: dict, node_id: str, user_id: str) -> None: + def single_iteration_run( + self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str + ) -> None: """ Single iteration run """ - app_record: App = db.session.query(App).filter(App.id == app_id).first() + app_record = db.session.query(App).filter(App.id == app_id).first() if not app_record: - raise ValueError("App not found") - + raise ValueError('App not found') + if not app_record.workflow_id: - raise ValueError("Workflow not initialized") + raise ValueError('Workflow not initialized') workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id) if not workflow: - raise ValueError("Workflow not initialized") - - workflow_callbacks = [WorkflowEventTriggerCallback( - queue_manager=queue_manager, - workflow=workflow - )] + raise ValueError('Workflow not initialized') + + workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)] workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.single_step_run_iteration_workflow_node( - workflow=workflow, - node_id=node_id, - user_id=user_id, - user_inputs=inputs, - callbacks=workflow_callbacks + workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks ) def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: @@ -120,11 +120,13 @@ class WorkflowAppRunner: Get workflow """ # fetch workflow by workflow_id - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.id == workflow_id - ).first() + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id + ) + .first() + ) # return workflow return workflow diff --git a/api/core/app/segments/__init__.py b/api/core/app/segments/__init__.py index d5cd0a589..174e24126 100644 --- a/api/core/app/segments/__init__.py +++ b/api/core/app/segments/__init__.py @@ -1,6 +1,7 @@ from .segment_group import SegmentGroup from .segments import ( ArrayAnySegment, + ArraySegment, FileSegment, FloatSegment, IntegerSegment, @@ -50,4 +51,5 @@ __all__ = [ 'ArrayNumberVariable', 'ArrayObjectVariable', 'ArrayFileVariable', + 'ArraySegment', ] diff --git a/api/core/app/segments/exc.py b/api/core/app/segments/exc.py new file mode 100644 index 000000000..d15d6d500 --- /dev/null +++ b/api/core/app/segments/exc.py @@ -0,0 +1,2 @@ +class VariableError(Exception): + pass diff --git a/api/core/app/segments/factory.py b/api/core/app/segments/factory.py index 1196284b1..91ff1fdb3 100644 --- a/api/core/app/segments/factory.py +++ b/api/core/app/segments/factory.py @@ -1,8 +1,10 @@ from collections.abc import Mapping from typing import Any +from configs import dify_config from core.file.file_obj import FileVar +from .exc import VariableError from .segments import ( ArrayAnySegment, FileSegment, @@ -29,39 +31,43 @@ from .variables import ( ) -def build_variable_from_mapping(m: Mapping[str, Any], /) -> Variable: - if (value_type := m.get('value_type')) is None: - raise ValueError('missing value type') - if not m.get('name'): - raise ValueError('missing name') - if (value := m.get('value')) is None: - raise ValueError('missing value') +def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: + if (value_type := mapping.get('value_type')) is None: + raise VariableError('missing value type') + if not mapping.get('name'): + raise VariableError('missing name') + if (value := mapping.get('value')) is None: + raise VariableError('missing value') match value_type: case SegmentType.STRING: - return StringVariable.model_validate(m) + result = StringVariable.model_validate(mapping) case SegmentType.SECRET: - return SecretVariable.model_validate(m) + result = SecretVariable.model_validate(mapping) case SegmentType.NUMBER if isinstance(value, int): - return IntegerVariable.model_validate(m) + result = IntegerVariable.model_validate(mapping) case SegmentType.NUMBER if isinstance(value, float): - return FloatVariable.model_validate(m) + result = FloatVariable.model_validate(mapping) case SegmentType.NUMBER if not isinstance(value, float | int): - raise ValueError(f'invalid number value {value}') + raise VariableError(f'invalid number value {value}') case SegmentType.FILE: - return FileVariable.model_validate(m) + result = FileVariable.model_validate(mapping) case SegmentType.OBJECT if isinstance(value, dict): - return ObjectVariable.model_validate( - {**m, 'value': {k: build_variable_from_mapping(v) for k, v in value.items()}} - ) + result = ObjectVariable.model_validate(mapping) case SegmentType.ARRAY_STRING if isinstance(value, list): - return ArrayStringVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]}) + result = ArrayStringVariable.model_validate(mapping) case SegmentType.ARRAY_NUMBER if isinstance(value, list): - return ArrayNumberVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]}) + result = ArrayNumberVariable.model_validate(mapping) case SegmentType.ARRAY_OBJECT if isinstance(value, list): - return ArrayObjectVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]}) + result = ArrayObjectVariable.model_validate(mapping) case SegmentType.ARRAY_FILE if isinstance(value, list): - return ArrayFileVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]}) - raise ValueError(f'not supported value type {value_type}') + mapping = dict(mapping) + mapping['value'] = [{'value': v} for v in value] + result = ArrayFileVariable.model_validate(mapping) + case _: + raise VariableError(f'not supported value type {value_type}') + if result.size > dify_config.MAX_VARIABLE_SIZE: + raise VariableError(f'variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}') + return result def build_segment(value: Any, /) -> Segment: @@ -74,12 +80,9 @@ def build_segment(value: Any, /) -> Segment: if isinstance(value, float): return FloatSegment(value=value) if isinstance(value, dict): - # TODO: Limit the depth of the object return ObjectSegment(value=value) if isinstance(value, list): - # TODO: Limit the depth of the array - elements = [build_segment(v) for v in value] - return ArrayAnySegment(value=elements) + return ArrayAnySegment(value=value) if isinstance(value, FileVar): return FileSegment(value=value) raise ValueError(f'not supported value {value}') diff --git a/api/core/app/segments/segments.py b/api/core/app/segments/segments.py index 0001c5300..7653e1085 100644 --- a/api/core/app/segments/segments.py +++ b/api/core/app/segments/segments.py @@ -1,4 +1,5 @@ import json +import sys from collections.abc import Mapping, Sequence from typing import Any @@ -37,6 +38,10 @@ class Segment(BaseModel): def markdown(self) -> str: return str(self.value) + @property + def size(self) -> int: + return sys.getsizeof(self.value) + def to_object(self) -> Any: return self.value @@ -105,28 +110,25 @@ class ArraySegment(Segment): def markdown(self) -> str: return '\n'.join(['- ' + item.markdown for item in self.value]) - def to_object(self): - return [v.to_object() for v in self.value] - class ArrayAnySegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_ANY - value: Sequence[Segment] + value: Sequence[Any] class ArrayStringSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_STRING - value: Sequence[StringSegment] + value: Sequence[str] class ArrayNumberSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_NUMBER - value: Sequence[FloatSegment | IntegerSegment] + value: Sequence[float | int] class ArrayObjectSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_OBJECT - value: Sequence[ObjectSegment] + value: Sequence[Mapping[str, Any]] class ArrayFileSegment(ArraySegment): diff --git a/api/core/file/file_obj.py b/api/core/file/file_obj.py index 268ef5df8..3959f4b4a 100644 --- a/api/core/file/file_obj.py +++ b/api/core/file/file_obj.py @@ -1,14 +1,19 @@ import enum -from typing import Optional +from typing import Any, Optional from pydantic import BaseModel -from core.app.app_config.entities import FileExtraConfig from core.file.tool_file_parser import ToolFileParser from core.file.upload_file_parser import UploadFileParser from core.model_runtime.entities.message_entities import ImagePromptMessageContent from extensions.ext_database import db -from models.model import UploadFile + + +class FileExtraConfig(BaseModel): + """ + File Upload Entity. + """ + image_config: Optional[dict[str, Any]] = None class FileType(enum.Enum): @@ -114,6 +119,7 @@ class FileVar(BaseModel): ) def _get_data(self, force_url: bool = False) -> Optional[str]: + from models.model import UploadFile if self.type == FileType.IMAGE: if self.transfer_method == FileTransferMethod.REMOTE_URL: return self.url diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py index ec502b5e0..01b89907d 100644 --- a/api/core/file/message_file_parser.py +++ b/api/core/file/message_file_parser.py @@ -5,8 +5,7 @@ from urllib.parse import parse_qs, urlparse import requests -from core.app.app_config.entities import FileExtraConfig -from core.file.file_obj import FileBelongsTo, FileTransferMethod, FileType, FileVar +from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar from extensions.ext_database import db from models.account import Account from models.model import EndUser, MessageFile, UploadFile diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index bf87a842c..5e5deb86b 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -2,7 +2,6 @@ import base64 from extensions.ext_database import db from libs import rsa -from models.account import Tenant def obfuscated_token(token: str): @@ -14,6 +13,7 @@ def obfuscated_token(token: str): def encrypt_token(tenant_id: str, token: str): + from models.account import Tenant if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()): raise ValueError(f'Tenant with id {tenant_id} not found') encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 996aae94c..0978b09b9 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -23,10 +23,12 @@ class NodeType(Enum): HTTP_REQUEST = 'http-request' TOOL = 'tool' VARIABLE_AGGREGATOR = 'variable-aggregator' + # TODO: merge this into VARIABLE_AGGREGATOR VARIABLE_ASSIGNER = 'variable-assigner' LOOP = 'loop' ITERATION = 'iteration' PARAMETER_EXTRACTOR = 'parameter-extractor' + CONVERSATION_VARIABLE_ASSIGNER = 'assigner' @classmethod def value_of(cls, value: str) -> 'NodeType': diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index a27b4261e..a96a26f79 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -13,6 +13,7 @@ VariableValue = Union[str, int, float, dict, list, FileVar] SYSTEM_VARIABLE_NODE_ID = 'sys' ENVIRONMENT_VARIABLE_NODE_ID = 'env' +CONVERSATION_VARIABLE_NODE_ID = 'conversation' class VariablePool: @@ -21,6 +22,7 @@ class VariablePool: system_variables: Mapping[SystemVariable, Any], user_inputs: Mapping[str, Any], environment_variables: Sequence[Variable], + conversation_variables: Sequence[Variable] | None = None, ) -> None: # system variables # for example: @@ -44,9 +46,13 @@ class VariablePool: self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) # Add environment variables to the variable pool - for var in environment_variables or []: + for var in environment_variables: self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) + # Add conversation variables to the variable pool + for var in conversation_variables or []: + self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var) + def add(self, selector: Sequence[str], value: Any, /) -> None: """ Adds a variable to the variable pool. diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index d8c812e7e..3d9cf5277 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -8,6 +8,7 @@ from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool +from models import WorkflowNodeExecutionStatus class UserFrom(Enum): @@ -91,14 +92,19 @@ class BaseNode(ABC): :param variable_pool: variable pool :return: """ - result = self._run( - variable_pool=variable_pool - ) + try: + result = self._run( + variable_pool=variable_pool + ) + self.node_run_result = result + return result + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + ) - self.node_run_result = result - return result - - def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None: + def publish_text_chunk(self, text: str, value_selector: list[str] | None = None) -> None: """ Publish text chunk :param text: chunk text diff --git a/api/core/workflow/nodes/variable_assigner/__init__.py b/api/core/workflow/nodes/variable_assigner/__init__.py new file mode 100644 index 000000000..552cc367f --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/__init__.py @@ -0,0 +1,109 @@ +from collections.abc import Sequence +from enum import Enum +from typing import Optional, cast + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.segments import SegmentType, Variable, factory +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import BaseNode +from extensions.ext_database import db +from models import ConversationVariable, WorkflowNodeExecutionStatus + + +class VariableAssignerNodeError(Exception): + pass + + +class WriteMode(str, Enum): + OVER_WRITE = 'over-write' + APPEND = 'append' + CLEAR = 'clear' + + +class VariableAssignerData(BaseNodeData): + title: str = 'Variable Assigner' + desc: Optional[str] = 'Assign a value to a variable' + assigned_variable_selector: Sequence[str] + write_mode: WriteMode + input_variable_selector: Sequence[str] + + +class VariableAssignerNode(BaseNode): + _node_data_cls: type[BaseNodeData] = VariableAssignerData + _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + data = cast(VariableAssignerData, self.node_data) + + # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject + original_variable = variable_pool.get(data.assigned_variable_selector) + if not isinstance(original_variable, Variable): + raise VariableAssignerNodeError('assigned variable not found') + + match data.write_mode: + case WriteMode.OVER_WRITE: + income_value = variable_pool.get(data.input_variable_selector) + if not income_value: + raise VariableAssignerNodeError('input value not found') + updated_variable = original_variable.model_copy(update={'value': income_value.value}) + + case WriteMode.APPEND: + income_value = variable_pool.get(data.input_variable_selector) + if not income_value: + raise VariableAssignerNodeError('input value not found') + updated_value = original_variable.value + [income_value.value] + updated_variable = original_variable.model_copy(update={'value': updated_value}) + + case WriteMode.CLEAR: + income_value = get_zero_value(original_variable.value_type) + updated_variable = original_variable.model_copy(update={'value': income_value.to_object()}) + + case _: + raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}') + + # Over write the variable. + variable_pool.add(data.assigned_variable_selector, updated_variable) + + # Update conversation variable. + # TODO: Find a better way to use the database. + conversation_id = variable_pool.get(['sys', 'conversation_id']) + if not conversation_id: + raise VariableAssignerNodeError('conversation_id not found') + update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={ + 'value': income_value.to_object(), + }, + ) + + +def update_conversation_variable(conversation_id: str, variable: Variable): + stmt = select(ConversationVariable).where( + ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id + ) + with Session(db.engine) as session: + row = session.scalar(stmt) + if not row: + raise VariableAssignerNodeError('conversation variable not found in the database') + row.data = variable.model_dump_json() + session.commit() + + +def get_zero_value(t: SegmentType): + match t: + case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: + return factory.build_segment([]) + case SegmentType.OBJECT: + return factory.build_segment({}) + case SegmentType.STRING: + return factory.build_segment('') + case SegmentType.NUMBER: + return factory.build_segment(0) + case _: + raise VariableAssignerNodeError(f'unsupported variable type: {t}') diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index bd2b3eafa..f299f84ef 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -4,12 +4,11 @@ from collections.abc import Mapping, Sequence from typing import Any, Optional, cast from configs import dify_config -from core.app.app_config.entities import FileExtraConfig from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.app.entities.app_invoke_entities import InvokeFrom -from core.file.file_obj import FileTransferMethod, FileType, FileVar +from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState from core.workflow.errors import WorkflowNodeRunFailedError @@ -30,6 +29,7 @@ from core.workflow.nodes.start.start_node import StartNode from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from core.workflow.nodes.tool.tool_node import ToolNode from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode +from core.workflow.nodes.variable_assigner import VariableAssignerNode from extensions.ext_database import db from models.workflow import ( Workflow, @@ -51,7 +51,8 @@ node_classes: Mapping[NodeType, type[BaseNode]] = { NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, NodeType.ITERATION: IterationNode, - NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode + NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, + NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode, } logger = logging.getLogger(__name__) @@ -94,10 +95,9 @@ class WorkflowEngineManager: user_id: str, user_from: UserFrom, invoke_from: InvokeFrom, - user_inputs: Mapping[str, Any], - system_inputs: Mapping[SystemVariable, Any], callbacks: Sequence[WorkflowCallback], - call_depth: int = 0 + call_depth: int = 0, + variable_pool: VariablePool, ) -> None: """ :param workflow: Workflow instance @@ -122,12 +122,6 @@ class WorkflowEngineManager: if not isinstance(graph.get('edges'), list): raise ValueError('edges in workflow graph must be a list') - # init variable pool - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=user_inputs, - environment_variables=workflow.environment_variables, - ) workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH if call_depth > workflow_call_max_depth: @@ -403,6 +397,7 @@ class WorkflowEngineManager: system_variables={}, user_inputs={}, environment_variables=workflow.environment_variables, + conversation_variables=workflow.conversation_variables, ) if node_cls is None: @@ -468,6 +463,7 @@ class WorkflowEngineManager: system_variables={}, user_inputs={}, environment_variables=workflow.environment_variables, + conversation_variables=workflow.conversation_variables, ) # variable selector to variable mapping diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py new file mode 100644 index 000000000..782a848c1 --- /dev/null +++ b/api/fields/conversation_variable_fields.py @@ -0,0 +1,21 @@ +from flask_restful import fields + +from libs.helper import TimestampField + +conversation_variable_fields = { + 'id': fields.String, + 'name': fields.String, + 'value_type': fields.String(attribute='value_type.value'), + 'value': fields.String, + 'description': fields.String, + 'created_at': TimestampField, + 'updated_at': TimestampField, +} + +paginated_conversation_variable_fields = { + 'page': fields.Integer, + 'limit': fields.Integer, + 'total': fields.Integer, + 'has_more': fields.Boolean, + 'data': fields.List(fields.Nested(conversation_variable_fields), attribute='data'), +} diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index ff33a97ff..c1dd0e184 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -32,11 +32,12 @@ class EnvironmentVariableField(fields.Raw): return value -environment_variable_fields = { +conversation_variable_fields = { 'id': fields.String, 'name': fields.String, - 'value': fields.Raw, 'value_type': fields.String(attribute='value_type.value'), + 'value': fields.Raw, + 'description': fields.String, } workflow_fields = { @@ -50,4 +51,5 @@ workflow_fields = { 'updated_at': TimestampField, 'tool_published': fields.Boolean, 'environment_variables': fields.List(EnvironmentVariableField()), + 'conversation_variables': fields.List(fields.Nested(conversation_variable_fields)), } diff --git a/api/migrations/versions/2024_08_13_0633-63a83fcf12ba_support_conversation_variables.py b/api/migrations/versions/2024_08_13_0633-63a83fcf12ba_support_conversation_variables.py new file mode 100644 index 000000000..16e1efd4e --- /dev/null +++ b/api/migrations/versions/2024_08_13_0633-63a83fcf12ba_support_conversation_variables.py @@ -0,0 +1,51 @@ +"""support conversation variables + +Revision ID: 63a83fcf12ba +Revises: 1787fbae959a +Create Date: 2024-08-13 06:33:07.950379 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '63a83fcf12ba' +down_revision = '1787fbae959a' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('workflow__conversation_variables', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('conversation_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('data', sa.Text(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', 'conversation_id', name=op.f('workflow__conversation_variables_pkey')) + ) + with op.batch_alter_table('workflow__conversation_variables', schema=None) as batch_op: + batch_op.create_index(batch_op.f('workflow__conversation_variables_app_id_idx'), ['app_id'], unique=False) + batch_op.create_index(batch_op.f('workflow__conversation_variables_created_at_idx'), ['created_at'], unique=False) + + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.add_column(sa.Column('conversation_variables', sa.Text(), server_default='{}', nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.drop_column('conversation_variables') + + with op.batch_alter_table('workflow__conversation_variables', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('workflow__conversation_variables_created_at_idx')) + batch_op.drop_index(batch_op.f('workflow__conversation_variables_app_id_idx')) + + op.drop_table('workflow__conversation_variables') + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index 3b832cd22..f83135684 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -1,15 +1,19 @@ from enum import Enum -from sqlalchemy import CHAR, TypeDecorator -from sqlalchemy.dialects.postgresql import UUID +from .model import AppMode +from .types import StringUUID +from .workflow import ConversationVariable, WorkflowNodeExecutionStatus + +__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus'] class CreatedByRole(Enum): """ Enum class for createdByRole """ - ACCOUNT = "account" - END_USER = "end_user" + + ACCOUNT = 'account' + END_USER = 'end_user' @classmethod def value_of(cls, value: str) -> 'CreatedByRole': @@ -23,49 +27,3 @@ class CreatedByRole(Enum): if role.value == value: return role raise ValueError(f'invalid createdByRole value {value}') - - -class CreatedFrom(Enum): - """ - Enum class for createdFrom - """ - SERVICE_API = "service-api" - WEB_APP = "web-app" - EXPLORE = "explore" - - @classmethod - def value_of(cls, value: str) -> 'CreatedFrom': - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for role in cls: - if role.value == value: - return role - raise ValueError(f'invalid createdFrom value {value}') - - -class StringUUID(TypeDecorator): - impl = CHAR - cache_ok = True - - def process_bind_param(self, value, dialect): - if value is None: - return value - elif dialect.name == 'postgresql': - return str(value) - else: - return value.hex - - def load_dialect_impl(self, dialect): - if dialect.name == 'postgresql': - return dialect.type_descriptor(UUID()) - else: - return dialect.type_descriptor(CHAR(36)) - - def process_result_value(self, value, dialect): - if value is None: - return value - return str(value) diff --git a/api/models/account.py b/api/models/account.py index d36b2b9fd..67d940b7b 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -4,7 +4,8 @@ import json from flask_login import UserMixin from extensions.ext_database import db -from models import StringUUID + +from .types import StringUUID class AccountStatus(str, enum.Enum): diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index d1f9cd78a..7f6932362 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -1,7 +1,8 @@ import enum from extensions.ext_database import db -from models import StringUUID + +from .types import StringUUID class APIBasedExtensionPoint(enum.Enum): diff --git a/api/models/dataset.py b/api/models/dataset.py index 40f9f4cf8..0d48177eb 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -16,9 +16,10 @@ from configs import dify_config from core.rag.retrieval.retrival_methods import RetrievalMethod from extensions.ext_database import db from extensions.ext_storage import storage -from models import StringUUID -from models.account import Account -from models.model import App, Tag, TagBinding, UploadFile + +from .account import Account +from .model import App, Tag, TagBinding, UploadFile +from .types import StringUUID class Dataset(db.Model): diff --git a/api/models/model.py b/api/models/model.py index 5afaf60b2..9909b10dc 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -14,8 +14,8 @@ from core.file.upload_file_parser import UploadFileParser from extensions.ext_database import db from libs.helper import generate_string -from . import StringUUID from .account import Account, Tenant +from .types import StringUUID class DifySetup(db.Model): diff --git a/api/models/provider.py b/api/models/provider.py index 4c14c33f0..5d92ee6eb 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,7 +1,8 @@ from enum import Enum from extensions.ext_database import db -from models import StringUUID + +from .types import StringUUID class ProviderType(Enum): diff --git a/api/models/source.py b/api/models/source.py index 265e68f01..adc00028b 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -3,7 +3,8 @@ import json from sqlalchemy.dialects.postgresql import JSONB from extensions.ext_database import db -from models import StringUUID + +from .types import StringUUID class DataSourceOauthBinding(db.Model): diff --git a/api/models/tool.py b/api/models/tool.py index f322944f5..79a70c6b1 100644 --- a/api/models/tool.py +++ b/api/models/tool.py @@ -2,7 +2,8 @@ import json from enum import Enum from extensions.ext_database import db -from models import StringUUID + +from .types import StringUUID class ToolProviderName(Enum): diff --git a/api/models/tools.py b/api/models/tools.py index 695ec26fb..069dc5bad 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -6,8 +6,9 @@ from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from extensions.ext_database import db -from models import StringUUID -from models.model import Account, App, Tenant + +from .model import Account, App, Tenant +from .types import StringUUID class BuiltinToolProvider(db.Model): diff --git a/api/models/types.py b/api/models/types.py new file mode 100644 index 000000000..1614ec201 --- /dev/null +++ b/api/models/types.py @@ -0,0 +1,26 @@ +from sqlalchemy import CHAR, TypeDecorator +from sqlalchemy.dialects.postgresql import UUID + + +class StringUUID(TypeDecorator): + impl = CHAR + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is None: + return value + elif dialect.name == 'postgresql': + return str(value) + else: + return value.hex + + def load_dialect_impl(self, dialect): + if dialect.name == 'postgresql': + return dialect.type_descriptor(UUID()) + else: + return dialect.type_descriptor(CHAR(36)) + + def process_result_value(self, value, dialect): + if value is None: + return value + return str(value) \ No newline at end of file diff --git a/api/models/web.py b/api/models/web.py index 6fd27206a..0e901d5f8 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,7 +1,8 @@ from extensions.ext_database import db -from models import StringUUID -from models.model import Message + +from .model import Message +from .types import StringUUID class SavedMessage(db.Model): diff --git a/api/models/workflow.py b/api/models/workflow.py index df2269cd0..759e07c71 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -3,18 +3,18 @@ from collections.abc import Mapping, Sequence from enum import Enum from typing import Any, Optional, Union +from sqlalchemy import func +from sqlalchemy.orm import Mapped + import contexts from constants import HIDDEN_VALUE -from core.app.segments import ( - SecretVariable, - Variable, - factory, -) +from core.app.segments import SecretVariable, Variable, factory from core.helper import encrypter from extensions.ext_database import db from libs import helper -from models import StringUUID -from models.account import Account + +from .account import Account +from .types import StringUUID class CreatedByRole(Enum): @@ -122,6 +122,7 @@ class Workflow(db.Model): updated_by = db.Column(StringUUID) updated_at = db.Column(db.DateTime) _environment_variables = db.Column('environment_variables', db.Text, nullable=False, server_default='{}') + _conversation_variables = db.Column('conversation_variables', db.Text, nullable=False, server_default='{}') @property def created_by_account(self): @@ -249,9 +250,27 @@ class Workflow(db.Model): 'graph': self.graph_dict, 'features': self.features_dict, 'environment_variables': [var.model_dump(mode='json') for var in environment_variables], + 'conversation_variables': [var.model_dump(mode='json') for var in self.conversation_variables], } return result + @property + def conversation_variables(self) -> Sequence[Variable]: + # TODO: find some way to init `self._conversation_variables` when instance created. + if self._conversation_variables is None: + self._conversation_variables = '{}' + + variables_dict: dict[str, Any] = json.loads(self._conversation_variables) + results = [factory.build_variable_from_mapping(v) for v in variables_dict.values()] + return results + + @conversation_variables.setter + def conversation_variables(self, value: Sequence[Variable]) -> None: + self._conversation_variables = json.dumps( + {var.name: var.model_dump() for var in value}, + ensure_ascii=False, + ) + class WorkflowRunTriggeredFrom(Enum): """ @@ -702,3 +721,34 @@ class WorkflowAppLog(db.Model): created_by_role = CreatedByRole.value_of(self.created_by_role) return db.session.get(EndUser, self.created_by) \ if created_by_role == CreatedByRole.END_USER else None + + +class ConversationVariable(db.Model): + __tablename__ = 'workflow__conversation_variables' + + id: Mapped[str] = db.Column(StringUUID, primary_key=True) + conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True) + app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True) + data = db.Column(db.Text, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()) + + def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None: + self.id = id + self.app_id = app_id + self.conversation_id = conversation_id + self.data = data + + @classmethod + def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> 'ConversationVariable': + obj = cls( + id=variable.id, + app_id=app_id, + conversation_id=conversation_id, + data=variable.model_dump_json(), + ) + return obj + + def to_variable(self) -> Variable: + mapping = json.loads(self.data) + return factory.build_variable_from_mapping(mapping) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 9d78037c3..e16e5c715 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -238,6 +238,8 @@ class AppDslService: # init draft workflow environment_variables_list = workflow_data.get('environment_variables') or [] environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] + conversation_variables_list = workflow_data.get('conversation_variables') or [] + conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] workflow_service = WorkflowService() draft_workflow = workflow_service.sync_draft_workflow( app_model=app, @@ -246,6 +248,7 @@ class AppDslService: unique_hash=None, account=account, environment_variables=environment_variables, + conversation_variables=conversation_variables, ) workflow_service.publish_workflow( app_model=app, diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 06b129be6..f99360829 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -6,7 +6,6 @@ from core.app.app_config.entities import ( DatasetRetrieveConfigEntity, EasyUIBasedAppConfig, ExternalDataVariableEntity, - FileExtraConfig, ModelConfigEntity, PromptTemplateEntity, VariableEntity, @@ -14,6 +13,7 @@ from core.app.app_config.entities import ( from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.completion.app_config_manager import CompletionAppConfigManager +from core.file.file_obj import FileExtraConfig from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils.encoders import jsonable_encoder diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 38de538a1..2defb4cd6 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -72,6 +72,7 @@ class WorkflowService: unique_hash: Optional[str], account: Account, environment_variables: Sequence[Variable], + conversation_variables: Sequence[Variable], ) -> Workflow: """ Sync draft workflow @@ -99,7 +100,8 @@ class WorkflowService: graph=json.dumps(graph), features=json.dumps(features), created_by=account.id, - environment_variables=environment_variables + environment_variables=environment_variables, + conversation_variables=conversation_variables, ) db.session.add(workflow) # update draft workflow if found @@ -109,6 +111,7 @@ class WorkflowService: workflow.updated_by = account.id workflow.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow.environment_variables = environment_variables + workflow.conversation_variables = conversation_variables # commit db session changes db.session.commit() @@ -145,7 +148,8 @@ class WorkflowService: graph=draft_workflow.graph, features=draft_workflow.features, created_by=account.id, - environment_variables=draft_workflow.environment_variables + environment_variables=draft_workflow.environment_variables, + conversation_variables=draft_workflow.conversation_variables, ) # commit db session changes @@ -336,8 +340,8 @@ class WorkflowService: ) if not workflow_nodes: return elapsed_time - + for node in workflow_nodes: elapsed_time += node.elapsed_time - return elapsed_time \ No newline at end of file + return elapsed_time diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 378756e68..4efe7ee38 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -1,8 +1,10 @@ import logging import time +from collections.abc import Callable import click from celery import shared_task +from sqlalchemy import delete from sqlalchemy.exc import SQLAlchemyError from extensions.ext_database import db @@ -28,7 +30,7 @@ from models.model import ( ) from models.tools import WorkflowToolProvider from models.web import PinnedConversation, SavedMessage -from models.workflow import Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun +from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun @shared_task(queue='app_deletion', bind=True, max_retries=3) @@ -54,6 +56,7 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): _delete_app_tag_bindings(tenant_id, app_id) _delete_end_users(tenant_id, app_id) _delete_trace_app_configs(tenant_id, app_id) + _delete_conversation_variables(app_id=app_id) end_at = time.perf_counter() logging.info(click.style(f'App and related data deleted: {app_id} latency: {end_at - start_at}', fg='green')) @@ -225,6 +228,13 @@ def _delete_app_conversations(tenant_id: str, app_id: str): "conversation" ) +def _delete_conversation_variables(*, app_id: str): + stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id) + with db.engine.connect() as conn: + conn.execute(stmt) + conn.commit() + logging.info(click.style(f"Deleted conversation variables for app {app_id}", fg='green')) + def _delete_app_messages(tenant_id: str, app_id: str): def del_message(message_id: str): @@ -299,7 +309,7 @@ def _delete_trace_app_configs(tenant_id: str, app_id: str): ) -def _delete_records(query_sql: str, params: dict, delete_func: callable, name: str) -> None: +def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None: while True: with db.engine.begin() as conn: rs = conn.execute(db.text(query_sql), params) diff --git a/api/tests/unit_tests/core/app/segments/test_factory.py b/api/tests/unit_tests/core/app/segments/test_factory.py index a88dd939b..a8429b9c1 100644 --- a/api/tests/unit_tests/core/app/segments/test_factory.py +++ b/api/tests/unit_tests/core/app/segments/test_factory.py @@ -7,15 +7,16 @@ from core.app.segments import ( ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, + FileSegment, FileVariable, FloatVariable, IntegerVariable, - NoneSegment, ObjectSegment, SecretVariable, StringVariable, factory, ) +from core.app.segments.exc import VariableError def test_string_variable(): @@ -44,7 +45,7 @@ def test_secret_variable(): def test_invalid_value_type(): test_data = {'value_type': 'unknown', 'name': 'test_invalid', 'value': 'value'} - with pytest.raises(ValueError): + with pytest.raises(VariableError): factory.build_variable_from_mapping(test_data) @@ -77,26 +78,14 @@ def test_object_variable(): 'name': 'test_object', 'description': 'Description of the variable.', 'value': { - 'key1': { - 'id': str(uuid4()), - 'value_type': 'string', - 'name': 'text', - 'value': 'text', - 'description': 'Description of the variable.', - }, - 'key2': { - 'id': str(uuid4()), - 'value_type': 'number', - 'name': 'number', - 'value': 1, - 'description': 'Description of the variable.', - }, + 'key1': 'text', + 'key2': 2, }, } variable = factory.build_variable_from_mapping(mapping) assert isinstance(variable, ObjectSegment) - assert isinstance(variable.value['key1'], StringVariable) - assert isinstance(variable.value['key2'], IntegerVariable) + assert isinstance(variable.value['key1'], str) + assert isinstance(variable.value['key2'], int) def test_array_string_variable(): @@ -106,26 +95,14 @@ def test_array_string_variable(): 'name': 'test_array', 'description': 'Description of the variable.', 'value': [ - { - 'id': str(uuid4()), - 'value_type': 'string', - 'name': 'text', - 'value': 'text', - 'description': 'Description of the variable.', - }, - { - 'id': str(uuid4()), - 'value_type': 'string', - 'name': 'text', - 'value': 'text', - 'description': 'Description of the variable.', - }, + 'text', + 'text', ], } variable = factory.build_variable_from_mapping(mapping) assert isinstance(variable, ArrayStringVariable) - assert isinstance(variable.value[0], StringVariable) - assert isinstance(variable.value[1], StringVariable) + assert isinstance(variable.value[0], str) + assert isinstance(variable.value[1], str) def test_array_number_variable(): @@ -135,26 +112,14 @@ def test_array_number_variable(): 'name': 'test_array', 'description': 'Description of the variable.', 'value': [ - { - 'id': str(uuid4()), - 'value_type': 'number', - 'name': 'number', - 'value': 1, - 'description': 'Description of the variable.', - }, - { - 'id': str(uuid4()), - 'value_type': 'number', - 'name': 'number', - 'value': 2.0, - 'description': 'Description of the variable.', - }, + 1, + 2.0, ], } variable = factory.build_variable_from_mapping(mapping) assert isinstance(variable, ArrayNumberVariable) - assert isinstance(variable.value[0], IntegerVariable) - assert isinstance(variable.value[1], FloatVariable) + assert isinstance(variable.value[0], int) + assert isinstance(variable.value[1], float) def test_array_object_variable(): @@ -165,59 +130,23 @@ def test_array_object_variable(): 'description': 'Description of the variable.', 'value': [ { - 'id': str(uuid4()), - 'value_type': 'object', - 'name': 'object', - 'description': 'Description of the variable.', - 'value': { - 'key1': { - 'id': str(uuid4()), - 'value_type': 'string', - 'name': 'text', - 'value': 'text', - 'description': 'Description of the variable.', - }, - 'key2': { - 'id': str(uuid4()), - 'value_type': 'number', - 'name': 'number', - 'value': 1, - 'description': 'Description of the variable.', - }, - }, + 'key1': 'text', + 'key2': 1, }, { - 'id': str(uuid4()), - 'value_type': 'object', - 'name': 'object', - 'description': 'Description of the variable.', - 'value': { - 'key1': { - 'id': str(uuid4()), - 'value_type': 'string', - 'name': 'text', - 'value': 'text', - 'description': 'Description of the variable.', - }, - 'key2': { - 'id': str(uuid4()), - 'value_type': 'number', - 'name': 'number', - 'value': 1, - 'description': 'Description of the variable.', - }, - }, + 'key1': 'text', + 'key2': 1, }, ], } variable = factory.build_variable_from_mapping(mapping) assert isinstance(variable, ArrayObjectVariable) - assert isinstance(variable.value[0], ObjectSegment) - assert isinstance(variable.value[1], ObjectSegment) - assert isinstance(variable.value[0].value['key1'], StringVariable) - assert isinstance(variable.value[0].value['key2'], IntegerVariable) - assert isinstance(variable.value[1].value['key1'], StringVariable) - assert isinstance(variable.value[1].value['key2'], IntegerVariable) + assert isinstance(variable.value[0], dict) + assert isinstance(variable.value[1], dict) + assert isinstance(variable.value[0]['key1'], str) + assert isinstance(variable.value[0]['key2'], int) + assert isinstance(variable.value[1]['key1'], str) + assert isinstance(variable.value[1]['key2'], int) def test_file_variable(): @@ -257,51 +186,53 @@ def test_array_file_variable(): 'value': [ { 'id': str(uuid4()), - 'name': 'file', - 'value_type': 'file', - 'value': { - 'id': str(uuid4()), - 'tenant_id': 'tenant_id', - 'type': 'image', - 'transfer_method': 'local_file', - 'url': 'url', - 'related_id': 'related_id', - 'extra_config': { - 'image_config': { - 'width': 100, - 'height': 100, - }, + 'tenant_id': 'tenant_id', + 'type': 'image', + 'transfer_method': 'local_file', + 'url': 'url', + 'related_id': 'related_id', + 'extra_config': { + 'image_config': { + 'width': 100, + 'height': 100, }, - 'filename': 'filename', - 'extension': 'extension', - 'mime_type': 'mime_type', }, + 'filename': 'filename', + 'extension': 'extension', + 'mime_type': 'mime_type', }, { 'id': str(uuid4()), - 'name': 'file', - 'value_type': 'file', - 'value': { - 'id': str(uuid4()), - 'tenant_id': 'tenant_id', - 'type': 'image', - 'transfer_method': 'local_file', - 'url': 'url', - 'related_id': 'related_id', - 'extra_config': { - 'image_config': { - 'width': 100, - 'height': 100, - }, + 'tenant_id': 'tenant_id', + 'type': 'image', + 'transfer_method': 'local_file', + 'url': 'url', + 'related_id': 'related_id', + 'extra_config': { + 'image_config': { + 'width': 100, + 'height': 100, }, - 'filename': 'filename', - 'extension': 'extension', - 'mime_type': 'mime_type', }, + 'filename': 'filename', + 'extension': 'extension', + 'mime_type': 'mime_type', }, ], } variable = factory.build_variable_from_mapping(mapping) assert isinstance(variable, ArrayFileVariable) - assert isinstance(variable.value[0], FileVariable) - assert isinstance(variable.value[1], FileVariable) + assert isinstance(variable.value[0], FileSegment) + assert isinstance(variable.value[1], FileSegment) + + +def test_variable_cannot_large_than_5_kb(): + with pytest.raises(VariableError): + factory.build_variable_from_mapping( + { + 'id': str(uuid4()), + 'value_type': 'string', + 'name': 'test_text', + 'value': 'a' * 1024 * 6, + } + ) diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index fd284488b..d24cd4aae 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -2,8 +2,8 @@ from unittest.mock import MagicMock import pytest -from core.app.app_config.entities import FileExtraConfig, ModelConfigEntity -from core.file.file_obj import FileTransferMethod, FileType, FileVar +from core.app.app_config.entities import ModelConfigEntity +from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, UserPromptMessage from core.prompt.advanced_prompt_transform import AdvancedPromptTransform diff --git a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py new file mode 100644 index 000000000..8706ba05c --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py @@ -0,0 +1,150 @@ +from unittest import mock +from uuid import uuid4 + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.segments import ArrayStringVariable, StringVariable +from core.workflow.entities.node_entities import SystemVariable +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import UserFrom +from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode + +DEFAULT_NODE_ID = 'node_id' + + +def test_overwrite_string_variable(): + conversation_variable = StringVariable( + id=str(uuid4()), + name='test_conversation_variable', + value='the first value', + ) + + input_variable = StringVariable( + id=str(uuid4()), + name='test_string_variable', + value='the second value', + ) + + node = VariableAssignerNode( + tenant_id='tenant_id', + app_id='app_id', + workflow_id='workflow_id', + user_id='user_id', + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + config={ + 'id': 'node_id', + 'data': { + 'assigned_variable_selector': ['conversation', conversation_variable.name], + 'write_mode': WriteMode.OVER_WRITE.value, + 'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name], + }, + }, + ) + + variable_pool = VariablePool( + system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'}, + user_inputs={}, + environment_variables=[], + conversation_variables=[conversation_variable], + ) + variable_pool.add( + [DEFAULT_NODE_ID, input_variable.name], + input_variable, + ) + + with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run: + node.run(variable_pool) + mock_run.assert_called_once() + + got = variable_pool.get(['conversation', conversation_variable.name]) + assert got is not None + assert got.value == 'the second value' + assert got.to_object() == 'the second value' + + +def test_append_variable_to_array(): + conversation_variable = ArrayStringVariable( + id=str(uuid4()), + name='test_conversation_variable', + value=['the first value'], + ) + + input_variable = StringVariable( + id=str(uuid4()), + name='test_string_variable', + value='the second value', + ) + + node = VariableAssignerNode( + tenant_id='tenant_id', + app_id='app_id', + workflow_id='workflow_id', + user_id='user_id', + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + config={ + 'id': 'node_id', + 'data': { + 'assigned_variable_selector': ['conversation', conversation_variable.name], + 'write_mode': WriteMode.APPEND.value, + 'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name], + }, + }, + ) + + variable_pool = VariablePool( + system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'}, + user_inputs={}, + environment_variables=[], + conversation_variables=[conversation_variable], + ) + variable_pool.add( + [DEFAULT_NODE_ID, input_variable.name], + input_variable, + ) + + with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run: + node.run(variable_pool) + mock_run.assert_called_once() + + got = variable_pool.get(['conversation', conversation_variable.name]) + assert got is not None + assert got.to_object() == ['the first value', 'the second value'] + + +def test_clear_array(): + conversation_variable = ArrayStringVariable( + id=str(uuid4()), + name='test_conversation_variable', + value=['the first value'], + ) + + node = VariableAssignerNode( + tenant_id='tenant_id', + app_id='app_id', + workflow_id='workflow_id', + user_id='user_id', + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + config={ + 'id': 'node_id', + 'data': { + 'assigned_variable_selector': ['conversation', conversation_variable.name], + 'write_mode': WriteMode.CLEAR.value, + 'input_variable_selector': [], + }, + }, + ) + + variable_pool = VariablePool( + system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'}, + user_inputs={}, + environment_variables=[], + conversation_variables=[conversation_variable], + ) + + node.run(variable_pool) + + got = variable_pool.get(['conversation', conversation_variable.name]) + assert got is not None + assert got.to_object() == [] diff --git a/api/tests/unit_tests/models/test_conversation_variable.py b/api/tests/unit_tests/models/test_conversation_variable.py new file mode 100644 index 000000000..9e16010d7 --- /dev/null +++ b/api/tests/unit_tests/models/test_conversation_variable.py @@ -0,0 +1,25 @@ +from uuid import uuid4 + +from core.app.segments import SegmentType, factory +from models import ConversationVariable + + +def test_from_variable_and_to_variable(): + variable = factory.build_variable_from_mapping( + { + 'id': str(uuid4()), + 'name': 'name', + 'value_type': SegmentType.OBJECT, + 'value': { + 'key': { + 'key': 'value', + } + }, + } + ) + + conversation_variable = ConversationVariable.from_variable( + app_id='app_id', conversation_id='conversation_id', variable=variable + ) + + assert conversation_variable.to_variable() == variable diff --git a/web/app/components/base/badge.tsx b/web/app/components/base/badge.tsx index 3e5414fa2..c3300a1e6 100644 --- a/web/app/components/base/badge.tsx +++ b/web/app/components/base/badge.tsx @@ -4,16 +4,19 @@ import cn from '@/utils/classnames' type BadgeProps = { className?: string text: string + uppercase?: boolean } const Badge = ({ className, text, + uppercase = true, }: BadgeProps) => { return (
diff --git a/web/app/components/base/icons/assets/vender/line/others/bubble-x.svg b/web/app/components/base/icons/assets/vender/line/others/bubble-x.svg new file mode 100644 index 000000000..6e4df5b9b --- /dev/null +++ b/web/app/components/base/icons/assets/vender/line/others/bubble-x.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/web/app/components/base/icons/assets/vender/line/others/long-arrow-left.svg b/web/app/components/base/icons/assets/vender/line/others/long-arrow-left.svg new file mode 100644 index 000000000..7320664db --- /dev/null +++ b/web/app/components/base/icons/assets/vender/line/others/long-arrow-left.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/components/base/icons/assets/vender/line/others/long-arrow-right.svg b/web/app/components/base/icons/assets/vender/line/others/long-arrow-right.svg new file mode 100644 index 000000000..733785a27 --- /dev/null +++ b/web/app/components/base/icons/assets/vender/line/others/long-arrow-right.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/components/base/icons/assets/vender/workflow/assigner.svg b/web/app/components/base/icons/assets/vender/workflow/assigner.svg new file mode 100644 index 000000000..b37fbce52 --- /dev/null +++ b/web/app/components/base/icons/assets/vender/workflow/assigner.svg @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/web/app/components/base/icons/src/vender/line/others/BubbleX.json b/web/app/components/base/icons/src/vender/line/others/BubbleX.json new file mode 100644 index 000000000..0cb5702c1 --- /dev/null +++ b/web/app/components/base/icons/src/vender/line/others/BubbleX.json @@ -0,0 +1,57 @@ +{ + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "width": "16", + "height": "16", + "viewBox": "0 0 16 16", + "fill": "none", + "xmlns": "http://www.w3.org/2000/svg" + }, + "children": [ + { + "type": "element", + "name": "g", + "attributes": { + "id": "Icon L" + }, + "children": [ + { + "type": "element", + "name": "g", + "attributes": { + "id": "Vector" + }, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "fill-rule": "evenodd", + "clip-rule": "evenodd", + "d": "M3.33463 3.33333C2.96643 3.33333 2.66796 3.63181 2.66796 4V10.6667C2.66796 11.0349 2.96643 11.3333 3.33463 11.3333H4.66796C5.03615 11.3333 5.33463 11.6318 5.33463 12V12.8225L7.65833 11.4283C7.76194 11.3662 7.8805 11.3333 8.00132 11.3333H12.0013C12.3695 11.3333 12.668 11.0349 12.668 10.6667C12.668 10.2985 12.9665 10 13.3347 10C13.7028 10 14.0013 10.2985 14.0013 10.6667C14.0013 11.7713 13.1058 12.6667 12.0013 12.6667H8.18598L5.01095 14.5717C4.805 14.6952 4.5485 14.6985 4.33949 14.5801C4.13049 14.4618 4.00129 14.2402 4.00129 14V12.6667H3.33463C2.23006 12.6667 1.33463 11.7713 1.33463 10.6667V4C1.33463 2.89543 2.23006 2 3.33463 2H6.66798C7.03617 2 7.33464 2.29848 7.33464 2.66667C7.33464 3.03486 7.03617 3.33333 6.66798 3.33333H3.33463Z", + "fill": "currentColor" + }, + "children": [] + }, + { + "type": "element", + "name": "path", + "attributes": { + "fill-rule": "evenodd", + "clip-rule": "evenodd", + "d": "M8.74113 2.66667C8.74113 2.29848 9.03961 2 9.4078 2H10.331C10.9721 2 11.5177 2.43571 11.6859 3.04075L11.933 3.93004L12.8986 2.77189C13.3045 2.28508 13.9018 2 14.536 2H14.5954C14.9636 2 15.2621 2.29848 15.2621 2.66667C15.2621 3.03486 14.9636 3.33333 14.5954 3.33333H14.536C14.3048 3.33333 14.08 3.43702 13.9227 3.6257L12.367 5.49165L12.8609 7.2689C12.8746 7.31803 12.9105 7.33333 12.9312 7.33333H13.8543C14.2225 7.33333 14.521 7.63181 14.521 8C14.521 8.36819 14.2225 8.66667 13.8543 8.66667H12.9312C12.29 8.66667 11.7444 8.23095 11.5763 7.62591L11.3291 6.73654L10.3634 7.89478C9.95758 8.38159 9.36022 8.66667 8.72604 8.66667H8.66666C8.29847 8.66667 7.99999 8.36819 7.99999 8C7.99999 7.63181 8.29847 7.33333 8.66666 7.33333H8.72604C8.95723 7.33333 9.18204 7.22965 9.33935 7.04096L10.8951 5.17493L10.4012 3.39777C10.3876 3.34863 10.3516 3.33333 10.331 3.33333H9.4078C9.03961 3.33333 8.74113 3.03486 8.74113 2.66667Z", + "fill": "currentColor" + }, + "children": [] + } + ] + } + ] + } + ] + }, + "name": "BubbleX" +} \ No newline at end of file diff --git a/web/app/components/base/icons/src/vender/line/others/BubbleX.tsx b/web/app/components/base/icons/src/vender/line/others/BubbleX.tsx new file mode 100644 index 000000000..7d78bd33c --- /dev/null +++ b/web/app/components/base/icons/src/vender/line/others/BubbleX.tsx @@ -0,0 +1,16 @@ +// GENERATE BY script +// DON NOT EDIT IT MANUALLY + +import * as React from 'react' +import data from './BubbleX.json' +import IconBase from '@/app/components/base/icons/IconBase' +import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase' + +const Icon = React.forwardRef, Omit>(( + props, + ref, +) => ) + +Icon.displayName = 'BubbleX' + +export default Icon diff --git a/web/app/components/base/icons/src/vender/line/others/LongArrowLeft.json b/web/app/components/base/icons/src/vender/line/others/LongArrowLeft.json new file mode 100644 index 000000000..d2646b109 --- /dev/null +++ b/web/app/components/base/icons/src/vender/line/others/LongArrowLeft.json @@ -0,0 +1,27 @@ +{ + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "width": "21", + "height": "8", + "viewBox": "0 0 21 8", + "fill": "none", + "xmlns": "http://www.w3.org/2000/svg" + }, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": "M0.646446 3.64645C0.451185 3.84171 0.451185 4.15829 0.646446 4.35355L3.82843 7.53553C4.02369 7.7308 4.34027 7.7308 4.53553 7.53553C4.7308 7.34027 4.7308 7.02369 4.53553 6.82843L1.70711 4L4.53553 1.17157C4.7308 0.976311 4.7308 0.659728 4.53553 0.464466C4.34027 0.269204 4.02369 0.269204 3.82843 0.464466L0.646446 3.64645ZM21 3.5L1 3.5V4.5L21 4.5V3.5Z", + "fill": "currentColor", + "fill-opacity": "0.3" + }, + "children": [] + } + ] + }, + "name": "LongArrowLeft" +} \ No newline at end of file diff --git a/web/app/components/base/icons/src/vender/line/others/LongArrowLeft.tsx b/web/app/components/base/icons/src/vender/line/others/LongArrowLeft.tsx new file mode 100644 index 000000000..930ced536 --- /dev/null +++ b/web/app/components/base/icons/src/vender/line/others/LongArrowLeft.tsx @@ -0,0 +1,16 @@ +// GENERATE BY script +// DON NOT EDIT IT MANUALLY + +import * as React from 'react' +import data from './LongArrowLeft.json' +import IconBase from '@/app/components/base/icons/IconBase' +import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase' + +const Icon = React.forwardRef, Omit>(( + props, + ref, +) => ) + +Icon.displayName = 'LongArrowLeft' + +export default Icon diff --git a/web/app/components/base/icons/src/vender/line/others/LongArrowRight.json b/web/app/components/base/icons/src/vender/line/others/LongArrowRight.json new file mode 100644 index 000000000..7582b8156 --- /dev/null +++ b/web/app/components/base/icons/src/vender/line/others/LongArrowRight.json @@ -0,0 +1,27 @@ +{ + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "width": "26", + "height": "8", + "viewBox": "0 0 26 8", + "fill": "none", + "xmlns": "http://www.w3.org/2000/svg" + }, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": "M25.3536 4.35355C25.5488 4.15829 25.5488 3.84171 25.3536 3.64644L22.1716 0.464465C21.9763 0.269202 21.6597 0.269202 21.4645 0.464465C21.2692 0.659727 21.2692 0.976309 21.4645 1.17157L24.2929 4L21.4645 6.82843C21.2692 7.02369 21.2692 7.34027 21.4645 7.53553C21.6597 7.73079 21.9763 7.73079 22.1716 7.53553L25.3536 4.35355ZM3.59058e-08 4.5L25 4.5L25 3.5L-3.59058e-08 3.5L3.59058e-08 4.5Z", + "fill": "currentColor", + "fill-opacity": "0.3" + }, + "children": [] + } + ] + }, + "name": "LongArrowRight" +} \ No newline at end of file diff --git a/web/app/components/base/icons/src/vender/line/others/LongArrowRight.tsx b/web/app/components/base/icons/src/vender/line/others/LongArrowRight.tsx new file mode 100644 index 000000000..3c9084cad --- /dev/null +++ b/web/app/components/base/icons/src/vender/line/others/LongArrowRight.tsx @@ -0,0 +1,16 @@ +// GENERATE BY script +// DON NOT EDIT IT MANUALLY + +import * as React from 'react' +import data from './LongArrowRight.json' +import IconBase from '@/app/components/base/icons/IconBase' +import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase' + +const Icon = React.forwardRef, Omit>(( + props, + ref, +) => ) + +Icon.displayName = 'LongArrowRight' + +export default Icon diff --git a/web/app/components/base/icons/src/vender/line/others/index.ts b/web/app/components/base/icons/src/vender/line/others/index.ts index 282a39499..d54f31e4a 100644 --- a/web/app/components/base/icons/src/vender/line/others/index.ts +++ b/web/app/components/base/icons/src/vender/line/others/index.ts @@ -1,8 +1,11 @@ export { default as Apps02 } from './Apps02' +export { default as BubbleX } from './BubbleX' export { default as Colors } from './Colors' export { default as DragHandle } from './DragHandle' export { default as Env } from './Env' export { default as Exchange02 } from './Exchange02' export { default as FileCode } from './FileCode' export { default as Icon3Dots } from './Icon3Dots' +export { default as LongArrowLeft } from './LongArrowLeft' +export { default as LongArrowRight } from './LongArrowRight' export { default as Tools } from './Tools' diff --git a/web/app/components/base/icons/src/vender/workflow/Assigner.json b/web/app/components/base/icons/src/vender/workflow/Assigner.json new file mode 100644 index 000000000..7106e5ad4 --- /dev/null +++ b/web/app/components/base/icons/src/vender/workflow/Assigner.json @@ -0,0 +1,68 @@ +{ + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "width": "16", + "height": "16", + "viewBox": "0 0 16 16", + "fill": "none", + "xmlns": "http://www.w3.org/2000/svg" + }, + "children": [ + { + "type": "element", + "name": "g", + "attributes": { + "id": "variable assigner" + }, + "children": [ + { + "type": "element", + "name": "g", + "attributes": { + "id": "Vector" + }, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "fill-rule": "evenodd", + "clip-rule": "evenodd", + "d": "M1.71438 4.42875C1.71438 3.22516 2.68954 2.25 3.89313 2.25C4.30734 2.25 4.64313 2.58579 4.64313 3C4.64313 3.41421 4.30734 3.75 3.89313 3.75C3.51796 3.75 3.21438 4.05359 3.21438 4.42875V7.28563C3.21438 7.48454 3.13536 7.6753 2.9947 7.81596L2.81066 8L2.9947 8.18404C3.13536 8.3247 3.21438 8.51546 3.21438 8.71437V11.5713C3.21438 11.9464 3.51796 12.25 3.89313 12.25C4.30734 12.25 4.64313 12.5858 4.64313 13C4.64313 13.4142 4.30734 13.75 3.89313 13.75C2.68954 13.75 1.71438 12.7748 1.71438 11.5713V9.02503L1.21967 8.53033C1.07902 8.38968 1 8.19891 1 8C1 7.80109 1.07902 7.61032 1.21967 7.46967L1.71438 6.97497V4.42875ZM11.3568 3C11.3568 2.58579 11.6925 2.25 12.1068 2.25C13.3103 2.25 14.2855 3.22516 14.2855 4.42875V6.97497L14.7802 7.46967C14.9209 7.61032 14.9999 7.80109 14.9999 8C14.9999 8.19891 14.9209 8.38968 14.7802 8.53033L14.2855 9.02503V11.5713C14.2855 12.7751 13.3095 13.75 12.1068 13.75C11.6925 13.75 11.3568 13.4142 11.3568 13C11.3568 12.5858 11.6925 12.25 12.1068 12.25C12.4815 12.25 12.7855 11.9462 12.7855 11.5713V8.71437C12.7855 8.51546 12.8645 8.3247 13.0052 8.18404L13.1892 8L13.0052 7.81596C12.8645 7.6753 12.7855 7.48454 12.7855 7.28563V4.42875C12.7855 4.05359 12.4819 3.75 12.1068 3.75C11.6925 3.75 11.3568 3.41421 11.3568 3Z", + "fill": "currentColor" + }, + "children": [] + }, + { + "type": "element", + "name": "path", + "attributes": { + "fill-rule": "evenodd", + "clip-rule": "evenodd", + "d": "M5.25 6C5.25 5.58579 5.58579 5.25 6 5.25H10C10.4142 5.25 10.75 5.58579 10.75 6C10.75 6.41421 10.4142 6.75 10 6.75H6C5.58579 6.75 5.25 6.41421 5.25 6Z", + "fill": "currentColor" + }, + "children": [] + }, + { + "type": "element", + "name": "path", + "attributes": { + "fill-rule": "evenodd", + "clip-rule": "evenodd", + "d": "M5.25 10C5.25 9.58579 5.58579 9.25 6 9.25H10C10.4142 9.25 10.75 9.58579 10.75 10C10.75 10.4142 10.4142 10.75 10 10.75H6C5.58579 10.75 5.25 10.4142 5.25 10Z", + "fill": "currentColor" + }, + "children": [] + } + ] + } + ] + } + ] + }, + "name": "Assigner" +} \ No newline at end of file diff --git a/web/app/components/base/icons/src/vender/workflow/Assigner.tsx b/web/app/components/base/icons/src/vender/workflow/Assigner.tsx new file mode 100644 index 000000000..1cb7d692d --- /dev/null +++ b/web/app/components/base/icons/src/vender/workflow/Assigner.tsx @@ -0,0 +1,16 @@ +// GENERATE BY script +// DON NOT EDIT IT MANUALLY + +import * as React from 'react' +import data from './Assigner.json' +import IconBase from '@/app/components/base/icons/IconBase' +import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase' + +const Icon = React.forwardRef, Omit>(( + props, + ref, +) => ) + +Icon.displayName = 'Assigner' + +export default Icon diff --git a/web/app/components/base/icons/src/vender/workflow/index.ts b/web/app/components/base/icons/src/vender/workflow/index.ts index 94e20ae6a..a2563a6a3 100644 --- a/web/app/components/base/icons/src/vender/workflow/index.ts +++ b/web/app/components/base/icons/src/vender/workflow/index.ts @@ -1,4 +1,5 @@ export { default as Answer } from './Answer' +export { default as Assigner } from './Assigner' export { default as Code } from './Code' export { default as End } from './End' export { default as Home } from './Home' diff --git a/web/app/components/base/input/index.tsx b/web/app/components/base/input/index.tsx index 0fb34de2e..5ab824944 100644 --- a/web/app/components/base/input/index.tsx +++ b/web/app/components/base/input/index.tsx @@ -2,7 +2,7 @@ import type { SVGProps } from 'react' import React, { useState } from 'react' import { useTranslation } from 'react-i18next' -import s from './style.module.css' +import cn from 'classnames' type InputProps = { placeholder?: string @@ -27,10 +27,10 @@ const Input = ({ value, defaultValue, onChange, className = '', wrapperClassName const { t } = useTranslation() return (
- {showPrefix && {prefixIcon ?? }} + {showPrefix && {prefixIcon ?? }} { diff --git a/web/app/components/base/input/style.module.css b/web/app/components/base/input/style.module.css deleted file mode 100644 index 5f2782777..000000000 --- a/web/app/components/base/input/style.module.css +++ /dev/null @@ -1,7 +0,0 @@ -.input { - @apply inline-flex h-7 w-full py-1 px-2 rounded-lg text-xs leading-normal; - @apply bg-gray-100 caret-primary-600 hover:bg-gray-100 focus:ring-1 focus:ring-inset focus:ring-gray-200 focus-visible:outline-none focus:bg-white placeholder:text-gray-400; -} -.prefix { - @apply whitespace-nowrap absolute left-2 self-center -} diff --git a/web/app/components/base/prompt-editor/index.tsx b/web/app/components/base/prompt-editor/index.tsx index da70d04ac..deae6833c 100644 --- a/web/app/components/base/prompt-editor/index.tsx +++ b/web/app/components/base/prompt-editor/index.tsx @@ -144,7 +144,7 @@ const PromptEditor: FC = ({ return ( -
+
} placeholder={} diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx index e149f5b75..39193fc31 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx @@ -21,10 +21,10 @@ import { } from './index' import cn from '@/utils/classnames' import { Variable02 } from '@/app/components/base/icons/src/vender/solid/development' -import { Env } from '@/app/components/base/icons/src/vender/line/others' +import { BubbleX, Env } from '@/app/components/base/icons/src/vender/line/others' import { VarBlockIcon } from '@/app/components/workflow/block-icon' import { Line3 } from '@/app/components/base/icons/src/public/common' -import { isENV, isSystemVar } from '@/app/components/workflow/nodes/_base/components/variable/utils' +import { isConversationVar, isENV, isSystemVar } from '@/app/components/workflow/nodes/_base/components/variable/utils' import TooltipPlus from '@/app/components/base/tooltip-plus' type WorkflowVariableBlockComponentProps = { @@ -52,6 +52,7 @@ const WorkflowVariableBlockComponent = ({ const [localWorkflowNodesMap, setLocalWorkflowNodesMap] = useState(workflowNodesMap) const node = localWorkflowNodesMap![variables[0]] const isEnv = isENV(variables) + const isChatVar = isConversationVar(variables) useEffect(() => { if (!editor.hasNodes([WorkflowVariableBlockNode])) @@ -75,11 +76,11 @@ const WorkflowVariableBlockComponent = ({ className={cn( 'mx-0.5 relative group/wrap flex items-center h-[18px] pl-0.5 pr-[3px] rounded-[5px] border select-none', isSelected ? ' border-[#84ADFF] bg-[#F5F8FF]' : ' border-black/5 bg-white', - !node && !isEnv && '!border-[#F04438] !bg-[#FEF3F2]', + !node && !isEnv && !isChatVar && '!border-[#F04438] !bg-[#FEF3F2]', )} ref={ref} > - {!isEnv && ( + {!isEnv && !isChatVar && (
{ node?.type && ( @@ -97,11 +98,12 @@ const WorkflowVariableBlockComponent = ({
)}
- {!isEnv && } + {!isEnv && !isChatVar && } {isEnv && } -
{varName}
+ {isChatVar && } +
{varName}
{ - !node && !isEnv && ( + !node && !isEnv && !isChatVar && ( ) } @@ -109,7 +111,7 @@ const WorkflowVariableBlockComponent = ({
) - if (!node && !isEnv) { + if (!node && !isEnv && !isChatVar) { return ( {Item} diff --git a/web/app/components/workflow/block-icon.tsx b/web/app/components/workflow/block-icon.tsx index 6bec70449..a7e89ad6c 100644 --- a/web/app/components/workflow/block-icon.tsx +++ b/web/app/components/workflow/block-icon.tsx @@ -3,6 +3,7 @@ import { memo } from 'react' import { BlockEnum } from './types' import { Answer, + Assigner, Code, End, Home, @@ -43,6 +44,7 @@ const getIcon = (type: BlockEnum, className: string) => { [BlockEnum.TemplateTransform]: , [BlockEnum.VariableAssigner]: , [BlockEnum.VariableAggregator]: , + [BlockEnum.Assigner]: , [BlockEnum.Tool]: , [BlockEnum.Iteration]: , [BlockEnum.ParameterExtractor]: , @@ -62,6 +64,7 @@ const ICON_CONTAINER_BG_COLOR_MAP: Record = { [BlockEnum.TemplateTransform]: 'bg-[#2E90FA]', [BlockEnum.VariableAssigner]: 'bg-[#2E90FA]', [BlockEnum.VariableAggregator]: 'bg-[#2E90FA]', + [BlockEnum.Assigner]: 'bg-[#2E90FA]', [BlockEnum.ParameterExtractor]: 'bg-[#2E90FA]', } const BlockIcon: FC = ({ diff --git a/web/app/components/workflow/block-selector/constants.tsx b/web/app/components/workflow/block-selector/constants.tsx index 517d9356f..fbe0a9a8a 100644 --- a/web/app/components/workflow/block-selector/constants.tsx +++ b/web/app/components/workflow/block-selector/constants.tsx @@ -59,6 +59,11 @@ export const BLOCKS: Block[] = [ type: BlockEnum.VariableAggregator, title: 'Variable Aggregator', }, + { + classification: BlockClassificationEnum.Transform, + type: BlockEnum.Assigner, + title: 'Variable Assigner', + }, { classification: BlockClassificationEnum.Transform, type: BlockEnum.ParameterExtractor, diff --git a/web/app/components/workflow/constants.ts b/web/app/components/workflow/constants.ts index a77e09296..070748bab 100644 --- a/web/app/components/workflow/constants.ts +++ b/web/app/components/workflow/constants.ts @@ -12,6 +12,7 @@ import HttpRequestDefault from './nodes/http/default' import ParameterExtractorDefault from './nodes/parameter-extractor/default' import ToolDefault from './nodes/tool/default' import VariableAssignerDefault from './nodes/variable-assigner/default' +import AssignerDefault from './nodes/assigner/default' import EndNodeDefault from './nodes/end/default' import IterationDefault from './nodes/iteration/default' @@ -133,6 +134,15 @@ export const NODES_EXTRA_DATA: Record = { getAvailableNextNodes: VariableAssignerDefault.getAvailableNextNodes, checkValid: VariableAssignerDefault.checkValid, }, + [BlockEnum.Assigner]: { + author: 'Dify', + about: '', + availablePrevNodes: [], + availableNextNodes: [], + getAvailablePrevNodes: AssignerDefault.getAvailablePrevNodes, + getAvailableNextNodes: AssignerDefault.getAvailableNextNodes, + checkValid: AssignerDefault.checkValid, + }, [BlockEnum.VariableAggregator]: { author: 'Dify', about: '', @@ -268,6 +278,12 @@ export const NODES_INITIAL_DATA = { output_type: '', ...VariableAssignerDefault.defaultValue, }, + [BlockEnum.Assigner]: { + type: BlockEnum.Assigner, + title: '', + desc: '', + ...AssignerDefault.defaultValue, + }, [BlockEnum.Tool]: { type: BlockEnum.Tool, title: '', diff --git a/web/app/components/workflow/header/chat-variable-button.tsx b/web/app/components/workflow/header/chat-variable-button.tsx new file mode 100644 index 000000000..39745d4fb --- /dev/null +++ b/web/app/components/workflow/header/chat-variable-button.tsx @@ -0,0 +1,24 @@ +import { memo } from 'react' +import Button from '@/app/components/base/button' +import { BubbleX } from '@/app/components/base/icons/src/vender/line/others' +import { useStore } from '@/app/components/workflow/store' + +const ChatVariableButton = ({ disabled }: { disabled: boolean }) => { + const setShowChatVariablePanel = useStore(s => s.setShowChatVariablePanel) + const setShowEnvPanel = useStore(s => s.setShowEnvPanel) + const setShowDebugAndPreviewPanel = useStore(s => s.setShowDebugAndPreviewPanel) + + const handleClick = () => { + setShowChatVariablePanel(true) + setShowEnvPanel(false) + setShowDebugAndPreviewPanel(false) + } + + return ( + + ) +} + +export default memo(ChatVariableButton) diff --git a/web/app/components/workflow/header/env-button.tsx b/web/app/components/workflow/header/env-button.tsx index f93273971..71598776d 100644 --- a/web/app/components/workflow/header/env-button.tsx +++ b/web/app/components/workflow/header/env-button.tsx @@ -1,21 +1,23 @@ import { memo } from 'react' +import Button from '@/app/components/base/button' import { Env } from '@/app/components/base/icons/src/vender/line/others' import { useStore } from '@/app/components/workflow/store' -import cn from '@/utils/classnames' -const EnvButton = () => { +const EnvButton = ({ disabled }: { disabled: boolean }) => { + const setShowChatVariablePanel = useStore(s => s.setShowChatVariablePanel) const setShowEnvPanel = useStore(s => s.setShowEnvPanel) const setShowDebugAndPreviewPanel = useStore(s => s.setShowDebugAndPreviewPanel) const handleClick = () => { setShowEnvPanel(true) + setShowChatVariablePanel(false) setShowDebugAndPreviewPanel(false) } return ( -
+
+ ) } diff --git a/web/app/components/workflow/header/index.tsx b/web/app/components/workflow/header/index.tsx index 75d5b29a8..58624d816 100644 --- a/web/app/components/workflow/header/index.tsx +++ b/web/app/components/workflow/header/index.tsx @@ -19,6 +19,7 @@ import { import type { StartNodeType } from '../nodes/start/types' import { useChecklistBeforePublish, + useIsChatMode, useNodesReadOnly, useNodesSyncDraft, useWorkflowMode, @@ -31,6 +32,7 @@ import EditingTitle from './editing-title' import RunningTitle from './running-title' import RestoringTitle from './restoring-title' import ViewHistory from './view-history' +import ChatVariableButton from './chat-variable-button' import EnvButton from './env-button' import Button from '@/app/components/base/button' import { useStore as useAppStore } from '@/app/components/app/store' @@ -44,7 +46,8 @@ const Header: FC = () => { const appDetail = useAppStore(s => s.appDetail) const appSidebarExpand = useAppStore(s => s.appSidebarExpand) const appID = appDetail?.id - const { getNodesReadOnly } = useNodesReadOnly() + const isChatMode = useIsChatMode() + const { nodesReadOnly, getNodesReadOnly } = useNodesReadOnly() const publishedAt = useStore(s => s.publishedAt) const draftUpdatedAt = useStore(s => s.draftUpdatedAt) const toolPublished = useStore(s => s.toolPublished) @@ -165,7 +168,8 @@ const Header: FC = () => { { normal && (
- + {isChatMode && } +
+
+ ) +} +export default React.memo(ArrayValueList) diff --git a/web/app/components/workflow/panel/chat-variable-panel/components/object-value-item.tsx b/web/app/components/workflow/panel/chat-variable-panel/components/object-value-item.tsx new file mode 100644 index 000000000..6bbdeae08 --- /dev/null +++ b/web/app/components/workflow/panel/chat-variable-panel/components/object-value-item.tsx @@ -0,0 +1,135 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback, useState } from 'react' +import { useTranslation } from 'react-i18next' +import produce from 'immer' +import { useContext } from 'use-context-selector' +import { ToastContext } from '@/app/components/base/toast' +import VariableTypeSelector from '@/app/components/workflow/panel/chat-variable-panel/components/variable-type-select' +import RemoveButton from '@/app/components/workflow/nodes/_base/components/remove-button' +import { ChatVarType } from '@/app/components/workflow/panel/chat-variable-panel/type' + +type Props = { + index: number + list: any[] + onChange: (list: any[]) => void +} + +const typeList = [ + ChatVarType.String, + ChatVarType.Number, +] + +export const DEFAULT_OBJECT_VALUE = { + key: '', + type: ChatVarType.String, + value: undefined, +} + +const ObjectValueItem: FC = ({ + index, + list, + onChange, +}) => { + const { t } = useTranslation() + const { notify } = useContext(ToastContext) + const [isFocus, setIsFocus] = useState(false) + + const handleKeyChange = useCallback((index: number) => { + return (e: React.ChangeEvent) => { + const newList = produce(list, (draft: any[]) => { + if (!/^[a-zA-Z0-9_]+$/.test(e.target.value)) + return notify({ type: 'error', message: 'key is can only contain letters, numbers and underscores' }) + draft[index].key = e.target.value + }) + onChange(newList) + } + }, [list, notify, onChange]) + + const handleTypeChange = useCallback((index: number) => { + return (type: ChatVarType) => { + const newList = produce(list, (draft) => { + draft[index].type = type + if (type === ChatVarType.Number) + draft[index].value = isNaN(Number(draft[index].value)) ? undefined : Number(draft[index].value) + else + draft[index].value = draft[index].value ? String(draft[index].value) : undefined + }) + onChange(newList) + } + }, [list, onChange]) + + const handleValueChange = useCallback((index: number) => { + return (e: React.ChangeEvent) => { + const newList = produce(list, (draft: any[]) => { + draft[index].value = draft[index].type === ChatVarType.String ? e.target.value : isNaN(Number(e.target.value)) ? undefined : Number(e.target.value) + }) + onChange(newList) + } + }, [list, onChange]) + + const handleItemRemove = useCallback((index: number) => { + return () => { + const newList = produce(list, (draft) => { + draft.splice(index, 1) + }) + onChange(newList) + } + }, [list, onChange]) + + const handleItemAdd = useCallback(() => { + const newList = produce(list, (draft: any[]) => { + draft.push(DEFAULT_OBJECT_VALUE) + }) + onChange(newList) + }, [list, onChange]) + + const handleFocusChange = useCallback(() => { + setIsFocus(true) + if (index === list.length - 1) + handleItemAdd() + }, [handleItemAdd, index, list.length]) + + return ( +
+ {/* Key */} +
+ +
+ {/* Type */} +
+ +
+ {/* Value */} +
+ handleFocusChange()} + onBlur={() => setIsFocus(false)} + type={list[index].type === ChatVarType.Number ? 'number' : 'text'} + /> + {list.length > 1 && !isFocus && ( + + )} +
+
+ ) +} +export default React.memo(ObjectValueItem) diff --git a/web/app/components/workflow/panel/chat-variable-panel/components/object-value-list.tsx b/web/app/components/workflow/panel/chat-variable-panel/components/object-value-list.tsx new file mode 100644 index 000000000..ec287accb --- /dev/null +++ b/web/app/components/workflow/panel/chat-variable-panel/components/object-value-list.tsx @@ -0,0 +1,36 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { useTranslation } from 'react-i18next' +import ObjectValueItem from '@/app/components/workflow/panel/chat-variable-panel/components/object-value-item' + +type Props = { + list: any[] + onChange: (list: any[]) => void +} + +const ObjectValueList: FC = ({ + list, + onChange, +}) => { + const { t } = useTranslation() + + return ( +
+
+
{t('workflow.chatVariable.modal.objectKey')}
+
{t('workflow.chatVariable.modal.objectType')}
+
{t('workflow.chatVariable.modal.objectValue')}
+
+ {list.map((item, index) => ( + + ))} +
+ ) +} +export default React.memo(ObjectValueList) diff --git a/web/app/components/workflow/panel/chat-variable-panel/components/variable-item.tsx b/web/app/components/workflow/panel/chat-variable-panel/components/variable-item.tsx new file mode 100644 index 000000000..a1a7c9dc3 --- /dev/null +++ b/web/app/components/workflow/panel/chat-variable-panel/components/variable-item.tsx @@ -0,0 +1,49 @@ +import { memo, useState } from 'react' +import { capitalize } from 'lodash-es' +import { RiDeleteBinLine, RiEditLine } from '@remixicon/react' +import { BubbleX } from '@/app/components/base/icons/src/vender/line/others' +import type { ConversationVariable } from '@/app/components/workflow/types' +import cn from '@/utils/classnames' + +type VariableItemProps = { + item: ConversationVariable + onEdit: (item: ConversationVariable) => void + onDelete: (item: ConversationVariable) => void +} + +const VariableItem = ({ + item, + onEdit, + onDelete, +}: VariableItemProps) => { + const [destructive, setDestructive] = useState(false) + return ( +
+
+
+ +
{item.name}
+
{capitalize(item.value_type)}
+
+
+
+ onEdit(item)}/> +
+
setDestructive(true)} + onMouseOut={() => setDestructive(false)} + > + onDelete(item)}/> +
+
+
+
{item.description}
+
+ ) +} + +export default memo(VariableItem) diff --git a/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal-trigger.tsx b/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal-trigger.tsx new file mode 100644 index 000000000..35d525432 --- /dev/null +++ b/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal-trigger.tsx @@ -0,0 +1,69 @@ +'use client' +import React from 'react' +import { useTranslation } from 'react-i18next' +import { RiAddLine } from '@remixicon/react' +import Button from '@/app/components/base/button' +import VariableModal from '@/app/components/workflow/panel/chat-variable-panel/components/variable-modal' +import { + PortalToFollowElem, + PortalToFollowElemContent, + PortalToFollowElemTrigger, +} from '@/app/components/base/portal-to-follow-elem' +import type { ConversationVariable } from '@/app/components/workflow/types' + +type Props = { + open: boolean + setOpen: (value: React.SetStateAction) => void + showTip: boolean + chatVar?: ConversationVariable + onClose: () => void + onSave: (env: ConversationVariable) => void +} + +const VariableModalTrigger = ({ + open, + setOpen, + showTip, + chatVar, + onClose, + onSave, +}: Props) => { + const { t } = useTranslation() + + return ( + { + setOpen(v => !v) + open && onClose() + }} + placement='left-start' + offset={{ + mainAxis: 8, + alignmentAxis: showTip ? -278 : -48, + }} + > + { + setOpen(v => !v) + open && onClose() + }}> + + + + { + onClose() + setOpen(false) + }} + /> + + + ) +} + +export default VariableModalTrigger diff --git a/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx b/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx new file mode 100644 index 000000000..135ee4349 --- /dev/null +++ b/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx @@ -0,0 +1,388 @@ +import React, { useCallback, useEffect, useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import { useContext } from 'use-context-selector' +import { v4 as uuid4 } from 'uuid' +import { RiCloseLine, RiDraftLine, RiInputField } from '@remixicon/react' +import VariableTypeSelector from '@/app/components/workflow/panel/chat-variable-panel/components/variable-type-select' +import ObjectValueList from '@/app/components/workflow/panel/chat-variable-panel/components/object-value-list' +import { DEFAULT_OBJECT_VALUE } from '@/app/components/workflow/panel/chat-variable-panel/components/object-value-item' +import ArrayValueList from '@/app/components/workflow/panel/chat-variable-panel/components/array-value-list' +import Button from '@/app/components/base/button' +import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' +import { ToastContext } from '@/app/components/base/toast' +import { useStore } from '@/app/components/workflow/store' +import type { ConversationVariable } from '@/app/components/workflow/types' +import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' +import { ChatVarType } from '@/app/components/workflow/panel/chat-variable-panel/type' +import cn from '@/utils/classnames' + +export type ModalPropsType = { + chatVar?: ConversationVariable + onClose: () => void + onSave: (chatVar: ConversationVariable) => void +} + +type ObjectValueItem = { + key: string + type: ChatVarType + value: string | number | undefined +} + +const typeList = [ + ChatVarType.String, + ChatVarType.Number, + ChatVarType.Object, + ChatVarType.ArrayString, + ChatVarType.ArrayNumber, + ChatVarType.ArrayObject, +] + +const objectPlaceholder = `# example +# { +# "name": "ray", +# "age": 20 +# }` +const arrayStringPlaceholder = `# example +# [ +# "value1", +# "value2" +# ]` +const arrayNumberPlaceholder = `# example +# [ +# 100, +# 200 +# ]` +const arrayObjectPlaceholder = `# example +# [ +# { +# "name": "ray", +# "age": 20 +# }, +# { +# "name": "lily", +# "age": 18 +# } +# ]` + +const ChatVariableModal = ({ + chatVar, + onClose, + onSave, +}: ModalPropsType) => { + const { t } = useTranslation() + const { notify } = useContext(ToastContext) + const varList = useStore(s => s.conversationVariables) + const [name, setName] = React.useState('') + const [type, setType] = React.useState(ChatVarType.String) + const [value, setValue] = React.useState() + const [objectValue, setObjectValue] = React.useState([DEFAULT_OBJECT_VALUE]) + const [editorContent, setEditorContent] = React.useState() + const [editInJSON, setEditInJSON] = React.useState(false) + const [des, setDes] = React.useState('') + + const editorMinHeight = useMemo(() => { + if (type === ChatVarType.ArrayObject) + return '240px' + return '120px' + }, [type]) + const placeholder = useMemo(() => { + if (type === ChatVarType.ArrayString) + return arrayStringPlaceholder + if (type === ChatVarType.ArrayNumber) + return arrayNumberPlaceholder + if (type === ChatVarType.ArrayObject) + return arrayObjectPlaceholder + return objectPlaceholder + }, [type]) + const getObjectValue = useCallback(() => { + if (!chatVar) + return [DEFAULT_OBJECT_VALUE] + return Object.keys(chatVar.value).map((key) => { + return { + key, + type: typeof chatVar.value[key] === 'string' ? ChatVarType.String : ChatVarType.Number, + value: chatVar.value[key], + } + }) + }, [chatVar]) + const formatValueFromObject = useCallback((list: ObjectValueItem[]) => { + return list.reduce((acc: any, curr) => { + if (curr.key) + acc[curr.key] = curr.value || null + return acc + }, {}) + }, []) + + const formatValue = (value: any) => { + switch (type) { + case ChatVarType.String: + return value || '' + case ChatVarType.Number: + return value || 0 + case ChatVarType.Object: + return formatValueFromObject(objectValue) + case ChatVarType.ArrayString: + case ChatVarType.ArrayNumber: + case ChatVarType.ArrayObject: + return value?.filter(Boolean) || [] + } + } + + const handleNameChange = (v: string) => { + if (!v) + return setName('') + if (!/^[a-zA-Z0-9_]+$/.test(v)) + return notify({ type: 'error', message: 'name is can only contain letters, numbers and underscores' }) + if (/^[0-9]/.test(v)) + return notify({ type: 'error', message: 'name can not start with a number' }) + setName(v) + } + + const handleTypeChange = (v: ChatVarType) => { + setValue(undefined) + setEditorContent(undefined) + if (v === ChatVarType.ArrayObject) + setEditInJSON(true) + if (v === ChatVarType.String || v === ChatVarType.Number || v === ChatVarType.Object) + setEditInJSON(false) + setType(v) + } + + const handleEditorChange = (editInJSON: boolean) => { + if (type === ChatVarType.Object) { + if (editInJSON) { + const newValue = !objectValue[0].key ? undefined : formatValueFromObject(objectValue) + setValue(newValue) + setEditorContent(JSON.stringify(newValue)) + } + else { + if (!editorContent) { + setValue(undefined) + setObjectValue([DEFAULT_OBJECT_VALUE]) + } + else { + try { + const newValue = JSON.parse(editorContent) + setValue(newValue) + const newObjectValue = Object.keys(newValue).map((key) => { + return { + key, + type: typeof newValue[key] === 'string' ? ChatVarType.String : ChatVarType.Number, + value: newValue[key], + } + }) + setObjectValue(newObjectValue) + } + catch (e) { + // ignore JSON.parse errors + } + } + } + } + if (type === ChatVarType.ArrayString || type === ChatVarType.ArrayNumber) { + if (editInJSON) { + const newValue = (value?.length && value.filter(Boolean).length) ? value.filter(Boolean) : undefined + setValue(newValue) + if (!editorContent) + setEditorContent(JSON.stringify(newValue)) + } + else { + setValue(value?.length ? value : [undefined]) + } + } + setEditInJSON(editInJSON) + } + + const handleEditorValueChange = (content: string) => { + if (!content) { + setEditorContent(content) + return setValue(undefined) + } + else { + setEditorContent(content) + try { + const newValue = JSON.parse(content) + setValue(newValue) + } + catch (e) { + // ignore JSON.parse errors + } + } + } + + const handleSave = () => { + if (!name) + return notify({ type: 'error', message: 'name can not be empty' }) + if (!chatVar && varList.some(chatVar => chatVar.name === name)) + return notify({ type: 'error', message: 'name is existed' }) + // if (type !== ChatVarType.Object && !value) + // return notify({ type: 'error', message: 'value can not be empty' }) + if (type === ChatVarType.Object && objectValue.some(item => !item.key && !!item.value)) + return notify({ type: 'error', message: 'object key can not be empty' }) + + onSave({ + id: chatVar ? chatVar.id : uuid4(), + name, + value_type: type, + value: formatValue(value), + description: des, + }) + onClose() + } + + useEffect(() => { + if (chatVar) { + setName(chatVar.name) + setType(chatVar.value_type) + setValue(chatVar.value) + setDes(chatVar.description) + setEditInJSON(false) + setObjectValue(getObjectValue()) + } + }, [chatVar, getObjectValue]) + + return ( +
+
+ {!chatVar ? t('workflow.chatVariable.modal.title') : t('workflow.chatVariable.modal.editTitle')} +
+
+ +
+
+
+
+ {/* name */} +
+
{t('workflow.chatVariable.modal.name')}
+
+ handleNameChange(e.target.value)} + type='text' + /> +
+
+ {/* type */} +
+
{t('workflow.chatVariable.modal.type')}
+
+ +
+
+ {/* default value */} +
+
+
{t('workflow.chatVariable.modal.value')}
+ {(type === ChatVarType.ArrayString || type === ChatVarType.ArrayNumber) && ( + + )} + {type === ChatVarType.Object && ( + + )} +
+
+ {type === ChatVarType.String && ( + setValue(e.target.value)} + /> + )} + {type === ChatVarType.Number && ( + setValue(Number(e.target.value))} + type='number' + /> + )} + {type === ChatVarType.Object && !editInJSON && ( + + )} + {type === ChatVarType.ArrayString && !editInJSON && ( + + )} + {type === ChatVarType.ArrayNumber && !editInJSON && ( + + )} + {editInJSON && ( +
+ {placeholder}
} + onChange={handleEditorValueChange} + /> +
+ )} +
+
+ {/* description */} +
+
{t('workflow.chatVariable.modal.description')}
+
+