feat: Allow using file variables directly in the LLM node and support more file types. (#10679)
Co-authored-by: Joel <iamjoel007@gmail.com>
This commit is contained in:
@@ -3,7 +3,12 @@ import base64
|
||||
from configs import dify_config
|
||||
from core.file import file_repository
|
||||
from core.helper import ssrf_proxy
|
||||
from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent
|
||||
from core.model_runtime.entities import (
|
||||
AudioPromptMessageContent,
|
||||
DocumentPromptMessageContent,
|
||||
ImagePromptMessageContent,
|
||||
VideoPromptMessageContent,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
@@ -29,35 +34,17 @@ def get_attr(*, file: File, attr: FileAttribute):
|
||||
return file.remote_url
|
||||
case FileAttribute.EXTENSION:
|
||||
return file.extension
|
||||
case _:
|
||||
raise ValueError(f"Invalid file attribute: {attr}")
|
||||
|
||||
|
||||
def to_prompt_message_content(
|
||||
f: File,
|
||||
/,
|
||||
*,
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||
):
|
||||
"""
|
||||
Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object.
|
||||
|
||||
This function takes a File object and converts it to an appropriate PromptMessageContent
|
||||
object, which can be used as a prompt for image or audio-based AI models.
|
||||
|
||||
Args:
|
||||
f (File): The File object to convert.
|
||||
detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts.
|
||||
If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW.
|
||||
|
||||
Returns:
|
||||
Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level
|
||||
|
||||
Raises:
|
||||
ValueError: If the file type is not supported or if required data is missing.
|
||||
"""
|
||||
match f.type:
|
||||
case FileType.IMAGE:
|
||||
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
|
||||
data = _to_url(f)
|
||||
else:
|
||||
@@ -65,7 +52,7 @@ def to_prompt_message_content(
|
||||
|
||||
return ImagePromptMessageContent(data=data, detail=image_detail_config)
|
||||
case FileType.AUDIO:
|
||||
encoded_string = _file_to_encoded_string(f)
|
||||
encoded_string = _get_encoded_string(f)
|
||||
if f.extension is None:
|
||||
raise ValueError("Missing file extension")
|
||||
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
|
||||
@@ -74,9 +61,20 @@ def to_prompt_message_content(
|
||||
data = _to_url(f)
|
||||
else:
|
||||
data = _to_base64_data_string(f)
|
||||
if f.extension is None:
|
||||
raise ValueError("Missing file extension")
|
||||
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
|
||||
case FileType.DOCUMENT:
|
||||
data = _get_encoded_string(f)
|
||||
if f.mime_type is None:
|
||||
raise ValueError("Missing file mime_type")
|
||||
return DocumentPromptMessageContent(
|
||||
encode_format="base64",
|
||||
mime_type=f.mime_type,
|
||||
data=data,
|
||||
)
|
||||
case _:
|
||||
raise ValueError("file type f.type is not supported")
|
||||
raise ValueError(f"file type {f.type} is not supported")
|
||||
|
||||
|
||||
def download(f: File, /):
|
||||
@@ -118,21 +116,16 @@ def _get_encoded_string(f: File, /):
|
||||
case FileTransferMethod.REMOTE_URL:
|
||||
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
content = response.content
|
||||
encoded_string = base64.b64encode(content).decode("utf-8")
|
||||
return encoded_string
|
||||
data = response.content
|
||||
case FileTransferMethod.LOCAL_FILE:
|
||||
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
|
||||
data = _download_file_content(upload_file.key)
|
||||
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||
return encoded_string
|
||||
case FileTransferMethod.TOOL_FILE:
|
||||
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
|
||||
data = _download_file_content(tool_file.file_key)
|
||||
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||
return encoded_string
|
||||
case _:
|
||||
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
|
||||
|
||||
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||
return encoded_string
|
||||
|
||||
|
||||
def _to_base64_data_string(f: File, /):
|
||||
@@ -140,18 +133,6 @@ def _to_base64_data_string(f: File, /):
|
||||
return f"data:{f.mime_type};base64,{encoded_string}"
|
||||
|
||||
|
||||
def _file_to_encoded_string(f: File, /):
|
||||
match f.type:
|
||||
case FileType.IMAGE:
|
||||
return _to_base64_data_string(f)
|
||||
case FileType.VIDEO:
|
||||
return _to_base64_data_string(f)
|
||||
case FileType.AUDIO:
|
||||
return _get_encoded_string(f)
|
||||
case _:
|
||||
raise ValueError(f"file type {f.type} is not supported")
|
||||
|
||||
|
||||
def _to_url(f: File, /):
|
||||
if f.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
if f.remote_url is None:
|
||||
|
Reference in New Issue
Block a user