From e340fccafb84bfe15b4bc2b905aca70c79386ad5 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Aug 2025 19:50:59 +0800 Subject: [PATCH] feat: integrate flask-orjson for improved JSON serialization performance (#23935) --- api/app_factory.py | 2 ++ .../code_executor/template_transformer.py | 4 +-- .../rag/datasource/keyword/jieba/jieba.py | 26 ++++++++------- api/core/variables/utils.py | 33 +++++++++++-------- api/extensions/ext_orjson.py | 8 +++++ api/models/workflow.py | 2 +- api/pyproject.toml | 1 + api/uv.lock | 15 +++++++++ 8 files changed, 64 insertions(+), 27 deletions(-) create mode 100644 api/extensions/ext_orjson.py diff --git a/api/app_factory.py b/api/app_factory.py index 81155cbac..032d6b17f 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -51,6 +51,7 @@ def initialize_extensions(app: DifyApp): ext_login, ext_mail, ext_migrate, + ext_orjson, ext_otel, ext_proxy_fix, ext_redis, @@ -67,6 +68,7 @@ def initialize_extensions(app: DifyApp): ext_logging, ext_warnings, ext_import_modules, + ext_orjson, ext_set_secretkey, ext_compress, ext_code_based_extension, diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index b416e48ce..3965f8cb3 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -5,7 +5,7 @@ from base64 import b64encode from collections.abc import Mapping from typing import Any -from core.variables.utils import SegmentJSONEncoder +from core.variables.utils import dumps_with_segments class TemplateTransformer(ABC): @@ -93,7 +93,7 @@ class TemplateTransformer(ABC): @classmethod def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str: - inputs_json_str = json.dumps(inputs, ensure_ascii=False, cls=SegmentJSONEncoder).encode() + inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode() input_base64_encoded = b64encode(inputs_json_str).decode("utf-8") return input_base64_encoded diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index 7c5f47006..c98306ea4 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -1,7 +1,7 @@ -import json from collections import defaultdict from typing import Any, Optional +import orjson from pydantic import BaseModel from configs import dify_config @@ -134,13 +134,13 @@ class Jieba(BaseKeyword): dataset_keyword_table = self.dataset.dataset_keyword_table keyword_data_source_type = dataset_keyword_table.data_source_type if keyword_data_source_type == "database": - dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder) + dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict) db.session.commit() else: file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt" if storage.exists(file_key): storage.delete(file_key) - storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode("utf-8")) + storage.save(file_key, dumps_with_sets(keyword_table_dict).encode("utf-8")) def _get_dataset_keyword_table(self) -> Optional[dict]: dataset_keyword_table = self.dataset.dataset_keyword_table @@ -156,12 +156,11 @@ class Jieba(BaseKeyword): data_source_type=keyword_data_source_type, ) if keyword_data_source_type == "database": - dataset_keyword_table.keyword_table = json.dumps( + dataset_keyword_table.keyword_table = dumps_with_sets( { "__type__": "keyword_table", "__data__": {"index_id": self.dataset.id, "summary": None, "table": {}}, - }, - cls=SetEncoder, + } ) db.session.add(dataset_keyword_table) db.session.commit() @@ -252,8 +251,13 @@ class Jieba(BaseKeyword): self._save_dataset_keyword_table(keyword_table) -class SetEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, set): - return list(obj) - return super().default(obj) +def set_orjson_default(obj: Any) -> Any: + """Default function for orjson serialization of set types""" + if isinstance(obj, set): + return list(obj) + raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") + + +def dumps_with_sets(obj: Any) -> str: + """JSON dumps with set support using orjson""" + return orjson.dumps(obj, default=set_orjson_default).decode("utf-8") diff --git a/api/core/variables/utils.py b/api/core/variables/utils.py index 692db3502..7ebd29f86 100644 --- a/api/core/variables/utils.py +++ b/api/core/variables/utils.py @@ -1,5 +1,7 @@ -import json from collections.abc import Iterable, Sequence +from typing import Any + +import orjson from .segment_group import SegmentGroup from .segments import ArrayFileSegment, FileSegment, Segment @@ -12,15 +14,20 @@ def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[ return selectors -class SegmentJSONEncoder(json.JSONEncoder): - def default(self, o): - if isinstance(o, ArrayFileSegment): - return [v.model_dump() for v in o.value] - elif isinstance(o, FileSegment): - return o.value.model_dump() - elif isinstance(o, SegmentGroup): - return [self.default(seg) for seg in o.value] - elif isinstance(o, Segment): - return o.value - else: - super().default(o) +def segment_orjson_default(o: Any) -> Any: + """Default function for orjson serialization of Segment types""" + if isinstance(o, ArrayFileSegment): + return [v.model_dump() for v in o.value] + elif isinstance(o, FileSegment): + return o.value.model_dump() + elif isinstance(o, SegmentGroup): + return [segment_orjson_default(seg) for seg in o.value] + elif isinstance(o, Segment): + return o.value + raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable") + + +def dumps_with_segments(obj: Any, ensure_ascii: bool = False) -> str: + """JSON dumps with segment support using orjson""" + option = orjson.OPT_NON_STR_KEYS + return orjson.dumps(obj, default=segment_orjson_default, option=option).decode("utf-8") diff --git a/api/extensions/ext_orjson.py b/api/extensions/ext_orjson.py new file mode 100644 index 000000000..659784a58 --- /dev/null +++ b/api/extensions/ext_orjson.py @@ -0,0 +1,8 @@ +from flask_orjson import OrjsonProvider + +from dify_app import DifyApp + + +def init_app(app: DifyApp) -> None: + """Initialize Flask-Orjson extension for faster JSON serialization""" + app.json = OrjsonProvider(app) diff --git a/api/models/workflow.py b/api/models/workflow.py index 453a650f8..7ff463e08 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1153,7 +1153,7 @@ class WorkflowDraftVariable(Base): value: The Segment object to store as the variable's value. """ self.__value = value - self.value = json.dumps(value, cls=variable_utils.SegmentJSONEncoder) + self.value = variable_utils.dumps_with_segments(value) self.value_type = value.value_type def get_node_id(self) -> str | None: diff --git a/api/pyproject.toml b/api/pyproject.toml index de472c870..61a725a83 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "flask-cors~=6.0.0", "flask-login~=0.6.3", "flask-migrate~=4.0.7", + "flask-orjson~=2.0.0", "flask-restful~=0.3.10", "flask-sqlalchemy~=3.1.1", "gevent~=24.11.1", diff --git a/api/uv.lock b/api/uv.lock index 870975418..cecce2bc4 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1253,6 +1253,7 @@ dependencies = [ { name = "flask-cors" }, { name = "flask-login" }, { name = "flask-migrate" }, + { name = "flask-orjson" }, { name = "flask-restful" }, { name = "flask-sqlalchemy" }, { name = "gevent" }, @@ -1440,6 +1441,7 @@ requires-dist = [ { name = "flask-cors", specifier = "~=6.0.0" }, { name = "flask-login", specifier = "~=0.6.3" }, { name = "flask-migrate", specifier = "~=4.0.7" }, + { name = "flask-orjson", specifier = "~=2.0.0" }, { name = "flask-restful", specifier = "~=0.3.10" }, { name = "flask-sqlalchemy", specifier = "~=3.1.1" }, { name = "gevent", specifier = "~=24.11.1" }, @@ -1859,6 +1861,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/93/01/587023575286236f95d2ab8a826c320375ed5ea2102bb103ed89704ffa6b/Flask_Migrate-4.0.7-py3-none-any.whl", hash = "sha256:5c532be17e7b43a223b7500d620edae33795df27c75811ddf32560f7d48ec617", size = 21127, upload-time = "2024-03-11T18:42:59.462Z" }, ] +[[package]] +name = "flask-orjson" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "flask" }, + { name = "orjson" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/49/575796f6ddca171d82dbb12762e33166c8b8f8616c946f0a6dfbb9bc3cd6/flask_orjson-2.0.0.tar.gz", hash = "sha256:6df6631437f9bc52cf9821735f896efa5583b5f80712f7d29d9ef69a79986a9c", size = 2974, upload-time = "2024-01-15T00:03:22.236Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f3/ca/53e14be018a2284acf799830e8cd8e0b263c0fd3dff1ad7b35f8417e7067/flask_orjson-2.0.0-py3-none-any.whl", hash = "sha256:5d15f2ba94b8d6c02aee88fc156045016e83db9eda2c30545fabd640aebaec9d", size = 3622, upload-time = "2024-01-15T00:03:17.511Z" }, +] + [[package]] name = "flask-restful" version = "0.3.10"