feat(llm_node): support order in text and files (#11837)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2024-12-20 14:12:50 +08:00
committed by GitHub
parent 3599751f93
commit 996a9135f6
18 changed files with 217 additions and 175 deletions

View File

@@ -1,15 +1,14 @@
import base64
from configs import dify_config
from core.file import file_repository
from core.helper import ssrf_proxy
from core.model_runtime.entities import (
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
MultiModalPromptMessageContent,
VideoPromptMessageContent,
)
from extensions.ext_database import db
from extensions.ext_storage import storage
from . import helpers
@@ -41,7 +40,7 @@ def to_prompt_message_content(
/,
*,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
):
) -> MultiModalPromptMessageContent:
if f.extension is None:
raise ValueError("Missing file extension")
if f.mime_type is None:
@@ -70,16 +69,13 @@ def to_prompt_message_content(
def download(f: File, /):
if f.transfer_method == FileTransferMethod.TOOL_FILE:
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
return _download_file_content(tool_file.file_key)
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
return _download_file_content(upload_file.key)
# remote file
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status()
return response.content
if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE):
return _download_file_content(f._storage_key)
elif f.transfer_method == FileTransferMethod.REMOTE_URL:
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status()
return response.content
raise ValueError(f"unsupported transfer method: {f.transfer_method}")
def _download_file_content(path: str, /):
@@ -110,11 +106,9 @@ def _get_encoded_string(f: File, /):
response.raise_for_status()
data = response.content
case FileTransferMethod.LOCAL_FILE:
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
data = _download_file_content(upload_file.key)
data = _download_file_content(f._storage_key)
case FileTransferMethod.TOOL_FILE:
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
data = _download_file_content(tool_file.file_key)
data = _download_file_content(f._storage_key)
encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string

View File

@@ -1,32 +0,0 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from models import ToolFile, UploadFile
from .models import File
def get_upload_file(*, session: Session, file: File):
if file.related_id is None:
raise ValueError("Missing file related_id")
stmt = select(UploadFile).filter(
UploadFile.id == file.related_id,
UploadFile.tenant_id == file.tenant_id,
)
record = session.scalar(stmt)
if not record:
raise ValueError(f"upload file {file.related_id} not found")
return record
def get_tool_file(*, session: Session, file: File):
if file.related_id is None:
raise ValueError("Missing file related_id")
stmt = select(ToolFile).filter(
ToolFile.id == file.related_id,
ToolFile.tenant_id == file.tenant_id,
)
record = session.scalar(stmt)
if not record:
raise ValueError(f"tool file {file.related_id} not found")
return record

View File

@@ -47,6 +47,38 @@ class File(BaseModel):
mime_type: Optional[str] = None
size: int = -1
# Those properties are private, should not be exposed to the outside.
_storage_key: str
def __init__(
self,
*,
id: Optional[str] = None,
tenant_id: str,
type: FileType,
transfer_method: FileTransferMethod,
remote_url: Optional[str] = None,
related_id: Optional[str] = None,
filename: Optional[str] = None,
extension: Optional[str] = None,
mime_type: Optional[str] = None,
size: int = -1,
storage_key: str,
):
super().__init__(
id=id,
tenant_id=tenant_id,
type=type,
transfer_method=transfer_method,
remote_url=remote_url,
related_id=related_id,
filename=filename,
extension=extension,
mime_type=mime_type,
size=size,
)
self._storage_key = storage_key
def to_dict(self) -> Mapping[str, str | int | None]:
data = self.model_dump(mode="json")
return {