feat/enhance the multi-modal support (#8818)
This commit is contained in:
@@ -1,24 +1,37 @@
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from flask import request
|
||||
from flask_login import UserMixin
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import Float, func, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileExtraConfig, FileTransferMethod, FileType
|
||||
from core.file import helpers as file_helpers
|
||||
from core.file.tool_file_parser import ToolFileParser
|
||||
from core.file.upload_file_parser import UploadFileParser
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import generate_string
|
||||
from models.enums import CreatedByRole
|
||||
|
||||
from .account import Account, Tenant
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class FileUploadConfig(BaseModel):
    """Pydantic model holding file-upload settings.

    Every field declares a default, so ``FileUploadConfig()`` can be built
    without arguments.
    """

    # Defaults to False.
    enabled: bool = Field(default=False)
    # The three "allowed_*" sequences default to empty lists.
    allowed_file_types: Sequence[FileType] = Field(default_factory=list)
    allowed_extensions: Sequence[str] = Field(default_factory=list)
    allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
    # Accepted range is 0..10 inclusive. The lower bound is inclusive (`ge`)
    # so that the default of 0 satisfies the field's own constraint; with an
    # exclusive bound (`gt=0`) a default instance could not be re-validated
    # from its own dump, and an explicit 0 was rejected.
    number_limits: int = Field(default=0, ge=0, le=10)
|
||||
|
||||
|
||||
class DifySetup(db.Model):
|
||||
__tablename__ = "dify_setups"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
|
||||
@@ -27,7 +40,7 @@ class DifySetup(db.Model):
|
||||
setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
|
||||
|
||||
class AppMode(Enum):
|
||||
class AppMode(str, Enum):
|
||||
COMPLETION = "completion"
|
||||
WORKFLOW = "workflow"
|
||||
CHAT = "chat"
|
||||
@@ -59,7 +72,7 @@ class App(db.Model):
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id"))
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying"))
|
||||
mode = db.Column(db.String(255), nullable=False)
|
||||
@@ -530,7 +543,7 @@ class Conversation(db.Model):
|
||||
mode = db.Column(db.String(255), nullable=False)
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
summary = db.Column(db.Text)
|
||||
inputs = db.Column(db.JSON)
|
||||
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
|
||||
introduction = db.Column(db.Text)
|
||||
system_instruction = db.Column(db.Text)
|
||||
system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
@@ -552,6 +565,28 @@ class Conversation(db.Model):
|
||||
|
||||
is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
|
||||
@property
def inputs(self):
    """Return the stored inputs with serialized files rebuilt as ``File`` objects.

    A value is treated as a serialized file when it is a dict whose
    ``"dify_model_identity"`` entry equals ``FILE_MODEL_IDENTITY``; such a
    value is passed through ``File.model_validate``. A list in which every
    item is such a dict is converted item by item. All other values are
    returned unchanged.

    Returns:
        A new dict; ``self._inputs`` itself is not modified. When
        ``self._inputs`` is ``None`` an empty dict is returned instead of
        raising ``AttributeError``.
    """
    stored = self._inputs
    if stored is None:
        # Nothing has been stored (e.g. attribute never assigned).
        return {}

    def _is_file_mapping(candidate):
        return isinstance(candidate, dict) and candidate.get("dify_model_identity") == FILE_MODEL_IDENTITY

    restored = {}
    for key, value in stored.items():
        if _is_file_mapping(value):
            restored[key] = File.model_validate(value)
        elif isinstance(value, list) and all(_is_file_mapping(item) for item in value):
            restored[key] = [File.model_validate(item) for item in value]
        else:
            restored[key] = value
    return restored


@inputs.setter
def inputs(self, value: Mapping[str, Any]):
    """Store ``value`` in ``_inputs`` with ``File`` objects dumped to plain data.

    ``File`` values are replaced by their ``model_dump()`` output, and lists
    made up solely of ``File`` instances are dumped element by element.
    Other values are stored as given. The mapping passed in is copied first,
    so the caller's object is not mutated.
    """
    serialized = dict(value)
    for key, item in serialized.items():
        if isinstance(item, File):
            serialized[key] = item.model_dump()
        elif isinstance(item, list) and all(isinstance(entry, File) for entry in item):
            serialized[key] = [entry.model_dump() for entry in item]
    self._inputs = serialized
|
||||
|
||||
@property
|
||||
def model_config(self):
|
||||
model_config = {}
|
||||
@@ -700,13 +735,13 @@ class Message(db.Model):
|
||||
model_id = db.Column(db.String(255), nullable=True)
|
||||
override_model_configs = db.Column(db.Text)
|
||||
conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=False)
|
||||
inputs = db.Column(db.JSON)
|
||||
query = db.Column(db.Text, nullable=False)
|
||||
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
|
||||
query: Mapped[str] = db.Column(db.Text, nullable=False)
|
||||
message = db.Column(db.JSON, nullable=False)
|
||||
message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
|
||||
message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
||||
answer = db.Column(db.Text, nullable=False)
|
||||
answer: Mapped[str] = db.Column(db.Text, nullable=False)
|
||||
answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
|
||||
answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
||||
@@ -717,15 +752,37 @@ class Message(db.Model):
|
||||
status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
|
||||
error = db.Column(db.Text)
|
||||
message_metadata = db.Column(db.Text)
|
||||
invoke_from = db.Column(db.String(255), nullable=True)
|
||||
invoke_from: Mapped[Optional[str]] = db.Column(db.String(255), nullable=True)
|
||||
from_source = db.Column(db.String(255), nullable=False)
|
||||
from_end_user_id = db.Column(StringUUID)
|
||||
from_account_id = db.Column(StringUUID)
|
||||
from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID)
|
||||
from_account_id: Mapped[Optional[str]] = db.Column(StringUUID)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
workflow_run_id = db.Column(StringUUID)
|
||||
|
||||
@property
def inputs(self):
    """Return a copy of ``_inputs`` in which serialized files become ``File`` objects.

    Dict values whose ``"dify_model_identity"`` entry equals
    ``FILE_MODEL_IDENTITY`` are rebuilt with ``File.model_validate``; lists
    consisting only of such dicts are rebuilt element-wise. Any other value
    is passed through untouched.

    Returns:
        A freshly built dict (``self._inputs`` is left as it was). If
        ``self._inputs`` is ``None`` the result is an empty dict rather than
        an ``AttributeError``.
    """
    raw_inputs = self._inputs
    if raw_inputs is None:
        # No inputs stored yet; there is nothing to convert.
        return {}

    def _looks_like_file(obj):
        return isinstance(obj, dict) and obj.get("dify_model_identity") == FILE_MODEL_IDENTITY

    converted = {}
    for name, payload in raw_inputs.items():
        if _looks_like_file(payload):
            converted[name] = File.model_validate(payload)
        elif isinstance(payload, list) and all(_looks_like_file(element) for element in payload):
            converted[name] = [File.model_validate(element) for element in payload]
        else:
            converted[name] = payload
    return converted


@inputs.setter
def inputs(self, value: Mapping[str, Any]):
    """Assign ``_inputs`` from ``value``, dumping ``File`` objects to plain data.

    A ``File`` value is stored as ``model_dump()`` output; a list whose items
    are all ``File`` instances is stored as a list of dumps. Everything else
    is stored unchanged. ``value`` is copied with ``dict(value)`` before any
    replacement, so the argument itself is never modified.
    """
    dumped = dict(value)
    for name, payload in dumped.items():
        if isinstance(payload, File):
            dumped[name] = payload.model_dump()
        elif isinstance(payload, list) and all(isinstance(element, File) for element in payload):
            dumped[name] = [element.model_dump() for element in payload]
    self._inputs = dumped
|
||||
|
||||
@property
|
||||
def re_sign_file_url_answer(self) -> str:
|
||||
if not self.answer:
|
||||
@@ -772,19 +829,29 @@ class Message(db.Model):
|
||||
sign_url = ToolFileParser.get_tool_file_manager().sign_file(
|
||||
tool_file_id=tool_file_id, extension=extension
|
||||
)
|
||||
else:
|
||||
elif "file-preview" in url:
|
||||
# get upload file id
|
||||
upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview?\?timestamp="
|
||||
upload_file_id_pattern = r"\/files\/([\w-]+)\/file-preview?\?timestamp="
|
||||
result = re.search(upload_file_id_pattern, url)
|
||||
if not result:
|
||||
continue
|
||||
|
||||
upload_file_id = result.group(1)
|
||||
|
||||
if not upload_file_id:
|
||||
continue
|
||||
|
||||
sign_url = UploadFileParser.get_signed_temp_image_url(upload_file_id)
|
||||
sign_url = file_helpers.get_signed_file_url(upload_file_id)
|
||||
elif "image-preview" in url:
|
||||
# image-preview is deprecated, use file-preview instead
|
||||
upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview?\?timestamp="
|
||||
result = re.search(upload_file_id_pattern, url)
|
||||
if not result:
|
||||
continue
|
||||
upload_file_id = result.group(1)
|
||||
if not upload_file_id:
|
||||
continue
|
||||
sign_url = file_helpers.get_signed_file_url(upload_file_id)
|
||||
else:
|
||||
continue
|
||||
|
||||
re_sign_file_url_answer = re_sign_file_url_answer.replace(url, sign_url)
|
||||
|
||||
@@ -870,50 +937,71 @@ class Message(db.Model):
|
||||
|
||||
@property
|
||||
def message_files(self):
|
||||
return db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all()
|
||||
from factories import file_factory
|
||||
|
||||
@property
|
||||
def files(self):
|
||||
message_files = self.message_files
|
||||
message_files = db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all()
|
||||
current_app = db.session.query(App).filter(App.id == self.app_id).first()
|
||||
if not current_app:
|
||||
raise ValueError(f"App {self.app_id} not found")
|
||||
|
||||
files = []
|
||||
files: list[File] = []
|
||||
for message_file in message_files:
|
||||
url = message_file.url
|
||||
if message_file.type == "image":
|
||||
if message_file.transfer_method == "local_file":
|
||||
upload_file = (
|
||||
db.session.query(UploadFile).filter(UploadFile.id == message_file.upload_file_id).first()
|
||||
)
|
||||
|
||||
url = UploadFileParser.get_image_data(upload_file=upload_file, force_url=True)
|
||||
if message_file.transfer_method == "tool_file":
|
||||
# get tool file id
|
||||
tool_file_id = message_file.url.split("/")[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split(".")[0]
|
||||
|
||||
# get extension
|
||||
if "." in message_file.url:
|
||||
extension = f'.{message_file.url.split(".")[-1]}'
|
||||
if len(extension) > 10:
|
||||
extension = ".bin"
|
||||
else:
|
||||
extension = ".bin"
|
||||
# add sign url
|
||||
url = ToolFileParser.get_tool_file_manager().sign_file(
|
||||
tool_file_id=tool_file_id, extension=extension
|
||||
)
|
||||
|
||||
files.append(
|
||||
{
|
||||
if message_file.transfer_method == "local_file":
|
||||
if message_file.upload_file_id is None:
|
||||
raise ValueError(f"MessageFile {message_file.id} is a local file but has no upload_file_id")
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping={
|
||||
"id": message_file.id,
|
||||
"upload_file_id": message_file.upload_file_id,
|
||||
"transfer_method": message_file.transfer_method,
|
||||
"type": message_file.type,
|
||||
},
|
||||
tenant_id=current_app.tenant_id,
|
||||
user_id=self.from_account_id or self.from_end_user_id or "",
|
||||
role=CreatedByRole(message_file.created_by_role),
|
||||
config=FileExtraConfig(),
|
||||
)
|
||||
elif message_file.transfer_method == "remote_url":
|
||||
if message_file.url is None:
|
||||
raise ValueError(f"MessageFile {message_file.id} is a remote url but has no url")
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping={
|
||||
"id": message_file.id,
|
||||
"type": message_file.type,
|
||||
"transfer_method": message_file.transfer_method,
|
||||
"url": message_file.url,
|
||||
},
|
||||
tenant_id=current_app.tenant_id,
|
||||
user_id=self.from_account_id or self.from_end_user_id or "",
|
||||
role=CreatedByRole(message_file.created_by_role),
|
||||
config=FileExtraConfig(),
|
||||
)
|
||||
elif message_file.transfer_method == "tool_file":
|
||||
mapping = {
|
||||
"id": message_file.id,
|
||||
"type": message_file.type,
|
||||
"url": url,
|
||||
"belongs_to": message_file.belongs_to or "user",
|
||||
"transfer_method": message_file.transfer_method,
|
||||
"tool_file_id": message_file.upload_file_id,
|
||||
}
|
||||
)
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=current_app.tenant_id,
|
||||
user_id=self.from_account_id or self.from_end_user_id or "",
|
||||
role=CreatedByRole(message_file.created_by_role),
|
||||
config=FileExtraConfig(),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"MessageFile {message_file.id} has an invalid transfer_method {message_file.transfer_method}"
|
||||
)
|
||||
files.append(file)
|
||||
|
||||
return files
|
||||
result = [
|
||||
{"belongs_to": message_file.belongs_to, **file.to_dict()}
|
||||
for (file, message_file) in zip(files, message_files)
|
||||
]
|
||||
|
||||
return result
|
||||
|
||||
@property
|
||||
def workflow_run(self):
|
||||
@@ -1003,16 +1091,39 @@ class MessageFile(db.Model):
|
||||
db.Index("message_file_created_by_idx", "created_by"),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
message_id = db.Column(StringUUID, nullable=False)
|
||||
type = db.Column(db.String(255), nullable=False)
|
||||
transfer_method = db.Column(db.String(255), nullable=False)
|
||||
url = db.Column(db.Text, nullable=True)
|
||||
belongs_to = db.Column(db.String(255), nullable=True)
|
||||
upload_file_id = db.Column(StringUUID, nullable=True)
|
||||
created_by_role = db.Column(db.String(255), nullable=False)
|
||||
created_by = db.Column(StringUUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
def __init__(
    self,
    *,
    message_id: str,
    type: FileType,
    transfer_method: FileTransferMethod,
    url: str | None = None,
    belongs_to: Literal["user", "assistant"] | None = None,
    upload_file_id: str | None = None,
    created_by_role: CreatedByRole,
    created_by: str,
):
    """Populate a new record from keyword-only field values.

    Each argument is stored unchanged on the attribute of the same name.
    ``url``, ``belongs_to`` and ``upload_file_id`` may be omitted and then
    default to ``None``; all other arguments are required.
    """
    # Required fields.
    self.message_id, self.type, self.transfer_method = message_id, type, transfer_method
    # Optional fields (None when not supplied).
    self.url, self.belongs_to, self.upload_file_id = url, belongs_to, upload_file_id
    # Creator information.
    self.created_by_role, self.created_by = created_by_role, created_by
|
||||
|
||||
id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
message_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
transfer_method: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
url: Mapped[Optional[str]] = db.Column(db.Text, nullable=True)
|
||||
belongs_to: Mapped[Optional[str]] = db.Column(db.String(255), nullable=True)
|
||||
upload_file_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True)
|
||||
created_by_role: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
created_by: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = db.Column(
|
||||
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
|
||||
|
||||
class MessageAnnotation(db.Model):
|
||||
@@ -1250,21 +1361,58 @@ class UploadFile(db.Model):
|
||||
db.Index("upload_file_tenant_idx", "tenant_id"),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
storage_type = db.Column(db.String(255), nullable=False)
|
||||
key = db.Column(db.String(255), nullable=False)
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
size = db.Column(db.Integer, nullable=False)
|
||||
extension = db.Column(db.String(255), nullable=False)
|
||||
mime_type = db.Column(db.String(255), nullable=True)
|
||||
created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying"))
|
||||
created_by = db.Column(StringUUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
used = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
used_by = db.Column(StringUUID, nullable=True)
|
||||
used_at = db.Column(db.DateTime, nullable=True)
|
||||
hash = db.Column(db.String(255), nullable=True)
|
||||
id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
storage_type: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
key: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
name: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
size: Mapped[int] = db.Column(db.Integer, nullable=False)
|
||||
extension: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
mime_type: Mapped[str] = db.Column(db.String(255), nullable=True)
|
||||
created_by_role: Mapped[str] = db.Column(
|
||||
db.String(255), nullable=False, server_default=db.text("'account'::character varying")
|
||||
)
|
||||
created_by: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = db.Column(
|
||||
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
used: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
used_by: Mapped[str | None] = db.Column(StringUUID, nullable=True)
|
||||
used_at: Mapped[datetime | None] = db.Column(db.DateTime, nullable=True)
|
||||
hash: Mapped[str | None] = db.Column(db.String(255), nullable=True)
|
||||
|
||||
def __init__(
    self,
    *,
    tenant_id: str,
    storage_type: str,
    key: str,
    name: str,
    size: int,
    extension: str,
    mime_type: str,
    created_by_role: str,
    created_by: str,
    created_at: datetime,
    used: bool,
    used_by: str | None = None,
    used_at: datetime | None = None,
    hash: str | None = None,
) -> None:
    """Populate a new record from keyword-only field values.

    Every argument is copied verbatim onto the attribute of the same name.
    ``used_by``, ``used_at`` and ``hash`` are optional and default to
    ``None``; the remaining arguments must be supplied by the caller.
    """
    # Ownership and storage location.
    self.tenant_id, self.storage_type, self.key = tenant_id, storage_type, key
    # File description.
    self.name, self.size = name, size
    self.extension, self.mime_type = extension, mime_type
    # Creation details.
    self.created_by_role, self.created_by = created_by_role, created_by
    self.created_at = created_at
    # Usage tracking; the last three are None unless provided.
    self.used, self.used_by, self.used_at = used, used_by, used_at
    self.hash = hash
|
||||
|
||||
|
||||
class ApiRequest(db.Model):
|
||||
|
Reference in New Issue
Block a user