feat: integrate flask-orjson for improved JSON serialization performance (#23935)

This commit is contained in:
-LAN-
2025-08-14 19:50:59 +08:00
committed by GitHub
parent 4a2e6af9b5
commit e340fccafb
8 changed files with 64 additions and 27 deletions

View File

@@ -51,6 +51,7 @@ def initialize_extensions(app: DifyApp):
ext_login, ext_login,
ext_mail, ext_mail,
ext_migrate, ext_migrate,
ext_orjson,
ext_otel, ext_otel,
ext_proxy_fix, ext_proxy_fix,
ext_redis, ext_redis,
@@ -67,6 +68,7 @@ def initialize_extensions(app: DifyApp):
ext_logging, ext_logging,
ext_warnings, ext_warnings,
ext_import_modules, ext_import_modules,
ext_orjson,
ext_set_secretkey, ext_set_secretkey,
ext_compress, ext_compress,
ext_code_based_extension, ext_code_based_extension,

View File

@@ -5,7 +5,7 @@ from base64 import b64encode
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any from typing import Any
from core.variables.utils import SegmentJSONEncoder from core.variables.utils import dumps_with_segments
class TemplateTransformer(ABC): class TemplateTransformer(ABC):
@@ -93,7 +93,7 @@ class TemplateTransformer(ABC):
@classmethod @classmethod
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str: def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
inputs_json_str = json.dumps(inputs, ensure_ascii=False, cls=SegmentJSONEncoder).encode() inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode()
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8") input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
return input_base64_encoded return input_base64_encoded

View File

@@ -1,7 +1,7 @@
import json
from collections import defaultdict from collections import defaultdict
from typing import Any, Optional from typing import Any, Optional
import orjson
from pydantic import BaseModel from pydantic import BaseModel
from configs import dify_config from configs import dify_config
@@ -134,13 +134,13 @@ class Jieba(BaseKeyword):
dataset_keyword_table = self.dataset.dataset_keyword_table dataset_keyword_table = self.dataset.dataset_keyword_table
keyword_data_source_type = dataset_keyword_table.data_source_type keyword_data_source_type = dataset_keyword_table.data_source_type
if keyword_data_source_type == "database": if keyword_data_source_type == "database":
dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder) dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict)
db.session.commit() db.session.commit()
else: else:
file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt" file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
if storage.exists(file_key): if storage.exists(file_key):
storage.delete(file_key) storage.delete(file_key)
storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode("utf-8")) storage.save(file_key, dumps_with_sets(keyword_table_dict).encode("utf-8"))
def _get_dataset_keyword_table(self) -> Optional[dict]: def _get_dataset_keyword_table(self) -> Optional[dict]:
dataset_keyword_table = self.dataset.dataset_keyword_table dataset_keyword_table = self.dataset.dataset_keyword_table
@@ -156,12 +156,11 @@ class Jieba(BaseKeyword):
data_source_type=keyword_data_source_type, data_source_type=keyword_data_source_type,
) )
if keyword_data_source_type == "database": if keyword_data_source_type == "database":
dataset_keyword_table.keyword_table = json.dumps( dataset_keyword_table.keyword_table = dumps_with_sets(
{ {
"__type__": "keyword_table", "__type__": "keyword_table",
"__data__": {"index_id": self.dataset.id, "summary": None, "table": {}}, "__data__": {"index_id": self.dataset.id, "summary": None, "table": {}},
}, }
cls=SetEncoder,
) )
db.session.add(dataset_keyword_table) db.session.add(dataset_keyword_table)
db.session.commit() db.session.commit()
@@ -252,8 +251,13 @@ class Jieba(BaseKeyword):
self._save_dataset_keyword_table(keyword_table) self._save_dataset_keyword_table(keyword_table)
class SetEncoder(json.JSONEncoder): def set_orjson_default(obj: Any) -> Any:
def default(self, obj): """Default function for orjson serialization of set types"""
if isinstance(obj, set): if isinstance(obj, set):
return list(obj) return list(obj)
return super().default(obj) raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
def dumps_with_sets(obj: Any) -> str:
"""JSON dumps with set support using orjson"""
return orjson.dumps(obj, default=set_orjson_default).decode("utf-8")

View File

@@ -1,5 +1,7 @@
import json
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
from typing import Any
import orjson
from .segment_group import SegmentGroup from .segment_group import SegmentGroup
from .segments import ArrayFileSegment, FileSegment, Segment from .segments import ArrayFileSegment, FileSegment, Segment
@@ -12,15 +14,20 @@ def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[
return selectors return selectors
class SegmentJSONEncoder(json.JSONEncoder): def segment_orjson_default(o: Any) -> Any:
def default(self, o): """Default function for orjson serialization of Segment types"""
if isinstance(o, ArrayFileSegment): if isinstance(o, ArrayFileSegment):
return [v.model_dump() for v in o.value] return [v.model_dump() for v in o.value]
elif isinstance(o, FileSegment): elif isinstance(o, FileSegment):
return o.value.model_dump() return o.value.model_dump()
elif isinstance(o, SegmentGroup): elif isinstance(o, SegmentGroup):
return [self.default(seg) for seg in o.value] return [segment_orjson_default(seg) for seg in o.value]
elif isinstance(o, Segment): elif isinstance(o, Segment):
return o.value return o.value
else: raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")
super().default(o)
def dumps_with_segments(obj: Any, ensure_ascii: bool = False) -> str:
"""JSON dumps with segment support using orjson"""
option = orjson.OPT_NON_STR_KEYS
return orjson.dumps(obj, default=segment_orjson_default, option=option).decode("utf-8")

View File

@@ -0,0 +1,8 @@
from flask_orjson import OrjsonProvider
from dify_app import DifyApp
def init_app(app: DifyApp) -> None:
"""Initialize Flask-Orjson extension for faster JSON serialization"""
app.json = OrjsonProvider(app)

View File

@@ -1153,7 +1153,7 @@ class WorkflowDraftVariable(Base):
value: The Segment object to store as the variable's value. value: The Segment object to store as the variable's value.
""" """
self.__value = value self.__value = value
self.value = json.dumps(value, cls=variable_utils.SegmentJSONEncoder) self.value = variable_utils.dumps_with_segments(value)
self.value_type = value.value_type self.value_type = value.value_type
def get_node_id(self) -> str | None: def get_node_id(self) -> str | None:

View File

@@ -18,6 +18,7 @@ dependencies = [
"flask-cors~=6.0.0", "flask-cors~=6.0.0",
"flask-login~=0.6.3", "flask-login~=0.6.3",
"flask-migrate~=4.0.7", "flask-migrate~=4.0.7",
"flask-orjson~=2.0.0",
"flask-restful~=0.3.10", "flask-restful~=0.3.10",
"flask-sqlalchemy~=3.1.1", "flask-sqlalchemy~=3.1.1",
"gevent~=24.11.1", "gevent~=24.11.1",

15
api/uv.lock generated
View File

@@ -1253,6 +1253,7 @@ dependencies = [
{ name = "flask-cors" }, { name = "flask-cors" },
{ name = "flask-login" }, { name = "flask-login" },
{ name = "flask-migrate" }, { name = "flask-migrate" },
{ name = "flask-orjson" },
{ name = "flask-restful" }, { name = "flask-restful" },
{ name = "flask-sqlalchemy" }, { name = "flask-sqlalchemy" },
{ name = "gevent" }, { name = "gevent" },
@@ -1440,6 +1441,7 @@ requires-dist = [
{ name = "flask-cors", specifier = "~=6.0.0" }, { name = "flask-cors", specifier = "~=6.0.0" },
{ name = "flask-login", specifier = "~=0.6.3" }, { name = "flask-login", specifier = "~=0.6.3" },
{ name = "flask-migrate", specifier = "~=4.0.7" }, { name = "flask-migrate", specifier = "~=4.0.7" },
{ name = "flask-orjson", specifier = "~=2.0.0" },
{ name = "flask-restful", specifier = "~=0.3.10" }, { name = "flask-restful", specifier = "~=0.3.10" },
{ name = "flask-sqlalchemy", specifier = "~=3.1.1" }, { name = "flask-sqlalchemy", specifier = "~=3.1.1" },
{ name = "gevent", specifier = "~=24.11.1" }, { name = "gevent", specifier = "~=24.11.1" },
@@ -1859,6 +1861,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/93/01/587023575286236f95d2ab8a826c320375ed5ea2102bb103ed89704ffa6b/Flask_Migrate-4.0.7-py3-none-any.whl", hash = "sha256:5c532be17e7b43a223b7500d620edae33795df27c75811ddf32560f7d48ec617", size = 21127, upload-time = "2024-03-11T18:42:59.462Z" }, { url = "https://files.pythonhosted.org/packages/93/01/587023575286236f95d2ab8a826c320375ed5ea2102bb103ed89704ffa6b/Flask_Migrate-4.0.7-py3-none-any.whl", hash = "sha256:5c532be17e7b43a223b7500d620edae33795df27c75811ddf32560f7d48ec617", size = 21127, upload-time = "2024-03-11T18:42:59.462Z" },
] ]
[[package]]
name = "flask-orjson"
version = "2.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "flask" },
{ name = "orjson" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a3/49/575796f6ddca171d82dbb12762e33166c8b8f8616c946f0a6dfbb9bc3cd6/flask_orjson-2.0.0.tar.gz", hash = "sha256:6df6631437f9bc52cf9821735f896efa5583b5f80712f7d29d9ef69a79986a9c", size = 2974, upload-time = "2024-01-15T00:03:22.236Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f3/ca/53e14be018a2284acf799830e8cd8e0b263c0fd3dff1ad7b35f8417e7067/flask_orjson-2.0.0-py3-none-any.whl", hash = "sha256:5d15f2ba94b8d6c02aee88fc156045016e83db9eda2c30545fabd640aebaec9d", size = 3622, upload-time = "2024-01-15T00:03:17.511Z" },
]
[[package]] [[package]]
name = "flask-restful" name = "flask-restful"
version = "0.3.10" version = "0.3.10"