feat/enhance the multi-modal support (#8818)
This commit is contained in:
254
api/factories/file_factory.py
Normal file
254
api/factories/file_factory.py
Normal file
@@ -0,0 +1,254 @@
|
||||
import mimetypes
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
|
||||
from core.file import File, FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from models import MessageFile, ToolFile, UploadFile
|
||||
from models.enums import CreatedByRole
|
||||
|
||||
|
||||
def build_from_message_files(
|
||||
*,
|
||||
message_files: Sequence["MessageFile"],
|
||||
tenant_id: str,
|
||||
config: FileExtraConfig,
|
||||
) -> Sequence[File]:
|
||||
results = [
|
||||
build_from_message_file(message_file=file, tenant_id=tenant_id, config=config)
|
||||
for file in message_files
|
||||
if file.belongs_to != FileBelongsTo.ASSISTANT
|
||||
]
|
||||
return results
|
||||
|
||||
|
||||
def build_from_message_file(
|
||||
*,
|
||||
message_file: "MessageFile",
|
||||
tenant_id: str,
|
||||
config: FileExtraConfig,
|
||||
):
|
||||
mapping = {
|
||||
"transfer_method": message_file.transfer_method,
|
||||
"url": message_file.url,
|
||||
"id": message_file.id,
|
||||
"type": message_file.type,
|
||||
"upload_file_id": message_file.upload_file_id,
|
||||
}
|
||||
return build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
user_id=message_file.created_by,
|
||||
role=CreatedByRole(message_file.created_by_role),
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
def build_from_mapping(
|
||||
*,
|
||||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
role: "CreatedByRole",
|
||||
config: FileExtraConfig,
|
||||
):
|
||||
transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
|
||||
match transfer_method:
|
||||
case FileTransferMethod.REMOTE_URL:
|
||||
file = _build_from_remote_url(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
config=config,
|
||||
transfer_method=transfer_method,
|
||||
)
|
||||
case FileTransferMethod.LOCAL_FILE:
|
||||
file = _build_from_local_file(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
role=role,
|
||||
config=config,
|
||||
transfer_method=transfer_method,
|
||||
)
|
||||
case FileTransferMethod.TOOL_FILE:
|
||||
file = _build_from_tool_file(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
config=config,
|
||||
transfer_method=transfer_method,
|
||||
)
|
||||
case _:
|
||||
raise ValueError(f"Invalid file transfer method: {transfer_method}")
|
||||
|
||||
return file
|
||||
|
||||
|
||||
def build_from_mappings(
|
||||
*,
|
||||
mappings: Sequence[Mapping[str, Any]],
|
||||
config: FileExtraConfig | None,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
role: "CreatedByRole",
|
||||
) -> Sequence[File]:
|
||||
if not config:
|
||||
return []
|
||||
|
||||
files = [
|
||||
build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
role=role,
|
||||
config=config,
|
||||
)
|
||||
for mapping in mappings
|
||||
]
|
||||
|
||||
if (
|
||||
# If image config is set.
|
||||
config.image_config
|
||||
# And the number of image files exceeds the maximum limit
|
||||
and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits
|
||||
):
|
||||
raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}")
|
||||
if config.number_limits and len(files) > config.number_limits:
|
||||
raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}")
|
||||
|
||||
return files
|
||||
|
||||
|
||||
def _build_from_local_file(
|
||||
*,
|
||||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
role: "CreatedByRole",
|
||||
config: FileExtraConfig,
|
||||
transfer_method: FileTransferMethod,
|
||||
):
|
||||
# check if the upload file exists.
|
||||
file_type = FileType.value_of(mapping.get("type"))
|
||||
stmt = select(UploadFile).where(
|
||||
UploadFile.id == mapping.get("upload_file_id"),
|
||||
UploadFile.tenant_id == tenant_id,
|
||||
UploadFile.created_by == user_id,
|
||||
UploadFile.created_by_role == role,
|
||||
)
|
||||
if file_type == FileType.IMAGE:
|
||||
stmt = stmt.where(UploadFile.extension.in_(IMAGE_EXTENSIONS))
|
||||
elif file_type == FileType.VIDEO:
|
||||
stmt = stmt.where(UploadFile.extension.in_(VIDEO_EXTENSIONS))
|
||||
elif file_type == FileType.AUDIO:
|
||||
stmt = stmt.where(UploadFile.extension.in_(AUDIO_EXTENSIONS))
|
||||
elif file_type == FileType.DOCUMENT:
|
||||
stmt = stmt.where(UploadFile.extension.in_(DOCUMENT_EXTENSIONS))
|
||||
row = db.session.scalar(stmt)
|
||||
if row is None:
|
||||
raise ValueError("Invalid upload file")
|
||||
file = File(
|
||||
id=mapping.get("id"),
|
||||
filename=row.name,
|
||||
extension=row.extension,
|
||||
mime_type=row.mime_type,
|
||||
tenant_id=tenant_id,
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=None,
|
||||
related_id=mapping.get("upload_file_id"),
|
||||
_extra_config=config,
|
||||
size=row.size,
|
||||
)
|
||||
return file
|
||||
|
||||
|
||||
def _build_from_remote_url(
|
||||
*,
|
||||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
config: FileExtraConfig,
|
||||
transfer_method: FileTransferMethod,
|
||||
):
|
||||
url = mapping.get("url")
|
||||
if not url:
|
||||
raise ValueError("Invalid file url")
|
||||
resp = ssrf_proxy.head(url)
|
||||
resp.raise_for_status()
|
||||
|
||||
# Try to extract filename from response headers or URL
|
||||
content_disposition = resp.headers.get("Content-Disposition")
|
||||
if content_disposition:
|
||||
filename = content_disposition.split("filename=")[-1].strip('"')
|
||||
else:
|
||||
filename = url.split("/")[-1].split("?")[0]
|
||||
# If filename is empty, set a default one
|
||||
if not filename:
|
||||
filename = "unknown_file"
|
||||
|
||||
# Determine file extension
|
||||
extension = "." + filename.split(".")[-1] if "." in filename else ".bin"
|
||||
|
||||
# Create the File object
|
||||
file_size = int(resp.headers.get("Content-Length", -1))
|
||||
mime_type = str(resp.headers.get("Content-Type", ""))
|
||||
if not mime_type:
|
||||
mime_type, _ = mimetypes.guess_type(url)
|
||||
file = File(
|
||||
id=mapping.get("id"),
|
||||
filename=filename,
|
||||
tenant_id=tenant_id,
|
||||
type=FileType.value_of(mapping.get("type")),
|
||||
transfer_method=transfer_method,
|
||||
remote_url=url,
|
||||
_extra_config=config,
|
||||
mime_type=mime_type,
|
||||
extension=extension,
|
||||
size=file_size,
|
||||
)
|
||||
return file
|
||||
|
||||
|
||||
def _build_from_tool_file(
|
||||
*,
|
||||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
config: FileExtraConfig,
|
||||
transfer_method: FileTransferMethod,
|
||||
):
|
||||
tool_file = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == mapping.get("tool_file_id"),
|
||||
ToolFile.tenant_id == tenant_id,
|
||||
ToolFile.user_id == user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if tool_file is None:
|
||||
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
|
||||
|
||||
path = tool_file.file_key
|
||||
if "." in path:
|
||||
extension = "." + path.split("/")[-1].split(".")[-1]
|
||||
else:
|
||||
extension = ".bin"
|
||||
file = File(
|
||||
id=mapping.get("id"),
|
||||
tenant_id=tenant_id,
|
||||
filename=tool_file.name,
|
||||
type=FileType.value_of(mapping.get("type")),
|
||||
transfer_method=transfer_method,
|
||||
remote_url=tool_file.original_url,
|
||||
related_id=tool_file.id,
|
||||
extension=extension,
|
||||
mime_type=tool_file.mimetype,
|
||||
size=tool_file.size,
|
||||
_extra_config=config,
|
||||
)
|
||||
return file
|
95
api/factories/variable_factory.py
Normal file
95
api/factories/variable_factory.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File
|
||||
from core.variables import (
|
||||
ArrayAnySegment,
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectSegment,
|
||||
ArrayObjectVariable,
|
||||
ArrayStringSegment,
|
||||
ArrayStringVariable,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
FloatVariable,
|
||||
IntegerSegment,
|
||||
IntegerVariable,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
ObjectVariable,
|
||||
SecretVariable,
|
||||
Segment,
|
||||
SegmentType,
|
||||
StringSegment,
|
||||
StringVariable,
|
||||
Variable,
|
||||
)
|
||||
from core.variables.exc import VariableError
|
||||
|
||||
|
||||
def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
||||
if (value_type := mapping.get("value_type")) is None:
|
||||
raise VariableError("missing value type")
|
||||
if not mapping.get("name"):
|
||||
raise VariableError("missing name")
|
||||
if (value := mapping.get("value")) is None:
|
||||
raise VariableError("missing value")
|
||||
match value_type:
|
||||
case SegmentType.STRING:
|
||||
result = StringVariable.model_validate(mapping)
|
||||
case SegmentType.SECRET:
|
||||
result = SecretVariable.model_validate(mapping)
|
||||
case SegmentType.NUMBER if isinstance(value, int):
|
||||
result = IntegerVariable.model_validate(mapping)
|
||||
case SegmentType.NUMBER if isinstance(value, float):
|
||||
result = FloatVariable.model_validate(mapping)
|
||||
case SegmentType.NUMBER if not isinstance(value, float | int):
|
||||
raise VariableError(f"invalid number value {value}")
|
||||
case SegmentType.OBJECT if isinstance(value, dict):
|
||||
result = ObjectVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_STRING if isinstance(value, list):
|
||||
result = ArrayStringVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_NUMBER if isinstance(value, list):
|
||||
result = ArrayNumberVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
|
||||
result = ArrayObjectVariable.model_validate(mapping)
|
||||
case _:
|
||||
raise VariableError(f"not supported value type {value_type}")
|
||||
if result.size > dify_config.MAX_VARIABLE_SIZE:
|
||||
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
|
||||
return result
|
||||
|
||||
|
||||
def build_segment(value: Any, /) -> Segment:
|
||||
if value is None:
|
||||
return NoneSegment()
|
||||
if isinstance(value, str):
|
||||
return StringSegment(value=value)
|
||||
if isinstance(value, int):
|
||||
return IntegerSegment(value=value)
|
||||
if isinstance(value, float):
|
||||
return FloatSegment(value=value)
|
||||
if isinstance(value, dict):
|
||||
return ObjectSegment(value=value)
|
||||
if isinstance(value, File):
|
||||
return FileSegment(value=value)
|
||||
if isinstance(value, list):
|
||||
items = [build_segment(item) for item in value]
|
||||
types = {item.value_type for item in items}
|
||||
if len(types) != 1:
|
||||
return ArrayAnySegment(value=value)
|
||||
match types.pop():
|
||||
case SegmentType.STRING:
|
||||
return ArrayStringSegment(value=value)
|
||||
case SegmentType.NUMBER:
|
||||
return ArrayNumberSegment(value=value)
|
||||
case SegmentType.OBJECT:
|
||||
return ArrayObjectSegment(value=value)
|
||||
case SegmentType.FILE:
|
||||
return ArrayFileSegment(value=value)
|
||||
case _:
|
||||
raise ValueError(f"not supported value {value}")
|
||||
raise ValueError(f"not supported value {value}")
|
Reference in New Issue
Block a user