feat(llm_node): support order in text and files (#11837)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -1,15 +1,14 @@
|
||||
import base64
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import file_repository
|
||||
from core.helper import ssrf_proxy
|
||||
from core.model_runtime.entities import (
|
||||
AudioPromptMessageContent,
|
||||
DocumentPromptMessageContent,
|
||||
ImagePromptMessageContent,
|
||||
MultiModalPromptMessageContent,
|
||||
VideoPromptMessageContent,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
from . import helpers
|
||||
@@ -41,7 +40,7 @@ def to_prompt_message_content(
|
||||
/,
|
||||
*,
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||
):
|
||||
) -> MultiModalPromptMessageContent:
|
||||
if f.extension is None:
|
||||
raise ValueError("Missing file extension")
|
||||
if f.mime_type is None:
|
||||
@@ -70,16 +69,13 @@ def to_prompt_message_content(
|
||||
|
||||
|
||||
def download(f: File, /):
|
||||
if f.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
|
||||
return _download_file_content(tool_file.file_key)
|
||||
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
|
||||
return _download_file_content(upload_file.key)
|
||||
# remote file
|
||||
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE):
|
||||
return _download_file_content(f._storage_key)
|
||||
elif f.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
raise ValueError(f"unsupported transfer method: {f.transfer_method}")
|
||||
|
||||
|
||||
def _download_file_content(path: str, /):
|
||||
@@ -110,11 +106,9 @@ def _get_encoded_string(f: File, /):
|
||||
response.raise_for_status()
|
||||
data = response.content
|
||||
case FileTransferMethod.LOCAL_FILE:
|
||||
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
|
||||
data = _download_file_content(upload_file.key)
|
||||
data = _download_file_content(f._storage_key)
|
||||
case FileTransferMethod.TOOL_FILE:
|
||||
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
|
||||
data = _download_file_content(tool_file.file_key)
|
||||
data = _download_file_content(f._storage_key)
|
||||
|
||||
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||
return encoded_string
|
||||
|
@@ -1,32 +0,0 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models import ToolFile, UploadFile
|
||||
|
||||
from .models import File
|
||||
|
||||
|
||||
def get_upload_file(*, session: Session, file: File):
|
||||
if file.related_id is None:
|
||||
raise ValueError("Missing file related_id")
|
||||
stmt = select(UploadFile).filter(
|
||||
UploadFile.id == file.related_id,
|
||||
UploadFile.tenant_id == file.tenant_id,
|
||||
)
|
||||
record = session.scalar(stmt)
|
||||
if not record:
|
||||
raise ValueError(f"upload file {file.related_id} not found")
|
||||
return record
|
||||
|
||||
|
||||
def get_tool_file(*, session: Session, file: File):
|
||||
if file.related_id is None:
|
||||
raise ValueError("Missing file related_id")
|
||||
stmt = select(ToolFile).filter(
|
||||
ToolFile.id == file.related_id,
|
||||
ToolFile.tenant_id == file.tenant_id,
|
||||
)
|
||||
record = session.scalar(stmt)
|
||||
if not record:
|
||||
raise ValueError(f"tool file {file.related_id} not found")
|
||||
return record
|
@@ -47,6 +47,38 @@ class File(BaseModel):
|
||||
mime_type: Optional[str] = None
|
||||
size: int = -1
|
||||
|
||||
# Those properties are private, should not be exposed to the outside.
|
||||
_storage_key: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
id: Optional[str] = None,
|
||||
tenant_id: str,
|
||||
type: FileType,
|
||||
transfer_method: FileTransferMethod,
|
||||
remote_url: Optional[str] = None,
|
||||
related_id: Optional[str] = None,
|
||||
filename: Optional[str] = None,
|
||||
extension: Optional[str] = None,
|
||||
mime_type: Optional[str] = None,
|
||||
size: int = -1,
|
||||
storage_key: str,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
tenant_id=tenant_id,
|
||||
type=type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=remote_url,
|
||||
related_id=related_id,
|
||||
filename=filename,
|
||||
extension=extension,
|
||||
mime_type=mime_type,
|
||||
size=size,
|
||||
)
|
||||
self._storage_key = storage_key
|
||||
|
||||
def to_dict(self) -> Mapping[str, str | int | None]:
|
||||
data = self.model_dump(mode="json")
|
||||
return {
|
||||
|
Reference in New Issue
Block a user