diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 178f2b968..83444c02d 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -29,7 +29,7 @@ from core.tools.errors import ( ToolProviderCredentialValidationError, ToolProviderNotFoundError, ) -from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db from models.enums import CreatorUserRole @@ -247,7 +247,8 @@ class ToolEngine: ) elif response.type == ToolInvokeMessage.MessageType.JSON: result += json.dumps( - cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False + safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object), + ensure_ascii=False, ) else: result += str(response.message) diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 9998de046..ac12d83ef 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -1,7 +1,14 @@ import logging from collections.abc import Generator +from datetime import date, datetime +from decimal import Decimal from mimetypes import guess_extension -from typing import Optional +from typing import Optional, cast +from uuid import UUID + +import numpy as np +import pytz +from flask_login import current_user from core.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage @@ -10,6 +17,41 @@ from core.tools.tool_file_manager import ToolFileManager logger = logging.getLogger(__name__) +def safe_json_value(v): + if isinstance(v, datetime): + tz_name = getattr(current_user, "timezone", None) if current_user is not None else None + if not tz_name: + tz_name = "UTC" + return v.astimezone(pytz.timezone(tz_name)).isoformat() + elif isinstance(v, date): + return v.isoformat() + elif isinstance(v, UUID): + return str(v) + elif isinstance(v, Decimal): + return float(v) + elif isinstance(v, bytes): + try: + return v.decode("utf-8") + except UnicodeDecodeError: + return v.hex() + elif isinstance(v, memoryview): + return v.tobytes().hex() + elif isinstance(v, np.ndarray): + return v.tolist() + elif isinstance(v, dict): + return safe_json_dict(v) + elif isinstance(v, list | tuple | set): + return [safe_json_value(i) for i in v] + else: + return v + + +def safe_json_dict(d): + if not isinstance(d, dict): + raise TypeError("safe_json_dict() expects a dictionary (dict) as input") + return {k: safe_json_value(v) for k, v in d.items()} + + class ToolFileMessageTransformer: @classmethod def transform_tool_invoke_messages( @@ -113,6 +155,12 @@ class ToolFileMessageTransformer: ) else: yield message + + elif message.type == ToolInvokeMessage.MessageType.JSON: + if isinstance(message.message, ToolInvokeMessage.JsonMessage): + json_msg = cast(ToolInvokeMessage.JsonMessage, message.message) + json_msg.json_object = safe_json_value(json_msg.json_object) + yield message else: yield message