feat/enhance the multi-modal support (#8818)

This commit is contained in:
-LAN-
2024-10-21 10:43:49 +08:00
committed by GitHub
parent 7a1d6fe509
commit e61752bd3a
267 changed files with 6263 additions and 3523 deletions

View File

@@ -4,7 +4,6 @@ import hmac
import logging
import os
import time
from collections.abc import Generator
from mimetypes import guess_extension, guess_type
from typing import Optional, Union
from uuid import uuid4
@@ -57,22 +56,32 @@ class ToolFileManager:
@staticmethod
def create_file_by_raw(
user_id: str, tenant_id: str, conversation_id: Optional[str], file_binary: bytes, mimetype: str
*,
user_id: str,
tenant_id: str,
conversation_id: Optional[str],
file_binary: bytes,
mimetype: str,
) -> ToolFile:
"""
create file
"""
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
filename = f"tools/{tenant_id}/{unique_name}{extension}"
storage.save(filename, file_binary)
filename = f"{unique_name}{extension}"
filepath = f"tools/{tenant_id}/{filename}"
storage.save(filepath, file_binary)
tool_file = ToolFile(
user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=filename, mimetype=mimetype
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_key=filepath,
mimetype=mimetype,
name=filename,
size=len(file_binary),
)
db.session.add(tool_file)
db.session.commit()
db.session.refresh(tool_file)
return tool_file
@@ -80,29 +89,34 @@ class ToolFileManager:
def create_file_by_url(
user_id: str,
tenant_id: str,
conversation_id: str,
conversation_id: str | None,
file_url: str,
) -> ToolFile:
"""
create file
"""
# try to download image
response = get(file_url)
response.raise_for_status()
blob = response.content
try:
response = get(file_url)
response.raise_for_status()
blob = response.content
except Exception as e:
logger.error(f"Failed to download file from {file_url}: {e}")
raise
mimetype = guess_type(file_url)[0] or "octet/stream"
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
filename = f"tools/{tenant_id}/{unique_name}{extension}"
storage.save(filename, blob)
filename = f"{unique_name}{extension}"
filepath = f"tools/{tenant_id}/{filename}"
storage.save(filepath, blob)
tool_file = ToolFile(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_key=filename,
file_key=filepath,
mimetype=mimetype,
original_url=file_url,
name=filename,
size=len(blob),
)
db.session.add(tool_file)
@@ -110,18 +124,6 @@ class ToolFileManager:
return tool_file
@staticmethod
def create_file_by_key(
user_id: str, tenant_id: str, conversation_id: str, file_key: str, mimetype: str
) -> ToolFile:
"""
create file
"""
tool_file = ToolFile(
user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=file_key, mimetype=mimetype
)
return tool_file
@staticmethod
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
"""
@@ -131,7 +133,7 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
tool_file: ToolFile = (
tool_file = (
db.session.query(ToolFile)
.filter(
ToolFile.id == id,
@@ -155,7 +157,7 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
message_file: MessageFile = (
message_file = (
db.session.query(MessageFile)
.filter(
MessageFile.id == id,
@@ -166,13 +168,16 @@ class ToolFileManager:
# Check if message_file is not None
if message_file is not None:
# get tool file id
tool_file_id = message_file.url.split("/")[-1]
# trim extension
tool_file_id = tool_file_id.split(".")[0]
if message_file.url is not None:
tool_file_id = message_file.url.split("/")[-1]
# trim extension
tool_file_id = tool_file_id.split(".")[0]
else:
tool_file_id = None
else:
tool_file_id = None
tool_file: ToolFile = (
tool_file = (
db.session.query(ToolFile)
.filter(
ToolFile.id == tool_file_id,
@@ -188,7 +193,7 @@ class ToolFileManager:
return blob, tool_file.mimetype
@staticmethod
def get_file_generator_by_tool_file_id(tool_file_id: str) -> Union[tuple[Generator, str], None]:
def get_file_generator_by_tool_file_id(tool_file_id: str):
"""
get file binary
@@ -196,7 +201,7 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
tool_file: ToolFile = (
tool_file = (
db.session.query(ToolFile)
.filter(
ToolFile.id == tool_file_id,
@@ -205,11 +210,11 @@ class ToolFileManager:
)
if not tool_file:
return None
return None, None
generator = storage.load_stream(tool_file.file_key)
stream = storage.load_stream(tool_file.file_key)
return generator, tool_file.mimetype
return stream, tool_file
# init tool_file_parser