feat/enhance the multi-modal support (#8818)
This commit is contained in:
@@ -4,7 +4,6 @@ import hmac
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from mimetypes import guess_extension, guess_type
|
||||
from typing import Optional, Union
|
||||
from uuid import uuid4
|
||||
@@ -57,22 +56,32 @@ class ToolFileManager:
|
||||
|
||||
@staticmethod
|
||||
def create_file_by_raw(
|
||||
user_id: str, tenant_id: str, conversation_id: Optional[str], file_binary: bytes, mimetype: str
|
||||
*,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: Optional[str],
|
||||
file_binary: bytes,
|
||||
mimetype: str,
|
||||
) -> ToolFile:
|
||||
"""
|
||||
create file
|
||||
"""
|
||||
extension = guess_extension(mimetype) or ".bin"
|
||||
unique_name = uuid4().hex
|
||||
filename = f"tools/{tenant_id}/{unique_name}{extension}"
|
||||
storage.save(filename, file_binary)
|
||||
filename = f"{unique_name}{extension}"
|
||||
filepath = f"tools/{tenant_id}/{filename}"
|
||||
storage.save(filepath, file_binary)
|
||||
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=filename, mimetype=mimetype
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_key=filepath,
|
||||
mimetype=mimetype,
|
||||
name=filename,
|
||||
size=len(file_binary),
|
||||
)
|
||||
|
||||
db.session.add(tool_file)
|
||||
db.session.commit()
|
||||
db.session.refresh(tool_file)
|
||||
|
||||
return tool_file
|
||||
|
||||
@@ -80,29 +89,34 @@ class ToolFileManager:
|
||||
def create_file_by_url(
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: str,
|
||||
conversation_id: str | None,
|
||||
file_url: str,
|
||||
) -> ToolFile:
|
||||
"""
|
||||
create file
|
||||
"""
|
||||
# try to download image
|
||||
response = get(file_url)
|
||||
response.raise_for_status()
|
||||
blob = response.content
|
||||
try:
|
||||
response = get(file_url)
|
||||
response.raise_for_status()
|
||||
blob = response.content
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download file from {file_url}: {e}")
|
||||
raise
|
||||
|
||||
mimetype = guess_type(file_url)[0] or "octet/stream"
|
||||
extension = guess_extension(mimetype) or ".bin"
|
||||
unique_name = uuid4().hex
|
||||
filename = f"tools/{tenant_id}/{unique_name}{extension}"
|
||||
storage.save(filename, blob)
|
||||
filename = f"{unique_name}{extension}"
|
||||
filepath = f"tools/{tenant_id}/{filename}"
|
||||
storage.save(filepath, blob)
|
||||
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_key=filename,
|
||||
file_key=filepath,
|
||||
mimetype=mimetype,
|
||||
original_url=file_url,
|
||||
name=filename,
|
||||
size=len(blob),
|
||||
)
|
||||
|
||||
db.session.add(tool_file)
|
||||
@@ -110,18 +124,6 @@ class ToolFileManager:
|
||||
|
||||
return tool_file
|
||||
|
||||
@staticmethod
|
||||
def create_file_by_key(
|
||||
user_id: str, tenant_id: str, conversation_id: str, file_key: str, mimetype: str
|
||||
) -> ToolFile:
|
||||
"""
|
||||
create file
|
||||
"""
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=file_key, mimetype=mimetype
|
||||
)
|
||||
return tool_file
|
||||
|
||||
@staticmethod
|
||||
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
|
||||
"""
|
||||
@@ -131,7 +133,7 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
tool_file: ToolFile = (
|
||||
tool_file = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == id,
|
||||
@@ -155,7 +157,7 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
message_file: MessageFile = (
|
||||
message_file = (
|
||||
db.session.query(MessageFile)
|
||||
.filter(
|
||||
MessageFile.id == id,
|
||||
@@ -166,13 +168,16 @@ class ToolFileManager:
|
||||
# Check if message_file is not None
|
||||
if message_file is not None:
|
||||
# get tool file id
|
||||
tool_file_id = message_file.url.split("/")[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split(".")[0]
|
||||
if message_file.url is not None:
|
||||
tool_file_id = message_file.url.split("/")[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split(".")[0]
|
||||
else:
|
||||
tool_file_id = None
|
||||
else:
|
||||
tool_file_id = None
|
||||
|
||||
tool_file: ToolFile = (
|
||||
tool_file = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
@@ -188,7 +193,7 @@ class ToolFileManager:
|
||||
return blob, tool_file.mimetype
|
||||
|
||||
@staticmethod
|
||||
def get_file_generator_by_tool_file_id(tool_file_id: str) -> Union[tuple[Generator, str], None]:
|
||||
def get_file_generator_by_tool_file_id(tool_file_id: str):
|
||||
"""
|
||||
get file binary
|
||||
|
||||
@@ -196,7 +201,7 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
tool_file: ToolFile = (
|
||||
tool_file = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
@@ -205,11 +210,11 @@ class ToolFileManager:
|
||||
)
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
return None, None
|
||||
|
||||
generator = storage.load_stream(tool_file.file_key)
|
||||
stream = storage.load_stream(tool_file.file_key)
|
||||
|
||||
return generator, tool_file.mimetype
|
||||
return stream, tool_file
|
||||
|
||||
|
||||
# init tool_file_parser
|
||||
|
Reference in New Issue
Block a user