refactor(core): Remove extra_config from File. (#10203)

This commit is contained in:
-LAN-
2024-11-08 18:13:24 +08:00
committed by GitHub
parent 78a380bcc4
commit 25ca0278dd
28 changed files with 263 additions and 344 deletions

View File

@@ -1,23 +1,21 @@
import mimetypes
from collections.abc import Mapping, Sequence
from collections.abc import Callable, Mapping, Sequence
from typing import Any
import httpx
from sqlalchemy import select
from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
from core.file import File, FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType
from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig
from core.helper import ssrf_proxy
from extensions.ext_database import db
from models import MessageFile, ToolFile, UploadFile
from models.enums import CreatedByRole
def build_from_message_files(
*,
message_files: Sequence["MessageFile"],
tenant_id: str,
config: FileExtraConfig,
config: FileUploadConfig,
) -> Sequence[File]:
results = [
build_from_message_file(message_file=file, tenant_id=tenant_id, config=config)
@@ -31,7 +29,7 @@ def build_from_message_file(
*,
message_file: "MessageFile",
tenant_id: str,
config: FileExtraConfig,
config: FileUploadConfig,
):
mapping = {
"transfer_method": message_file.transfer_method,
@@ -43,8 +41,6 @@ def build_from_message_file(
return build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
user_id=message_file.created_by,
role=CreatedByRole(message_file.created_by_role),
config=config,
)
@@ -53,38 +49,30 @@ def build_from_mapping(
*,
mapping: Mapping[str, Any],
tenant_id: str,
user_id: str,
role: "CreatedByRole",
config: FileExtraConfig,
):
config: FileUploadConfig | None = None,
) -> File:
config = config or FileUploadConfig()
transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
match transfer_method:
case FileTransferMethod.REMOTE_URL:
file = _build_from_remote_url(
mapping=mapping,
tenant_id=tenant_id,
config=config,
transfer_method=transfer_method,
)
case FileTransferMethod.LOCAL_FILE:
file = _build_from_local_file(
mapping=mapping,
tenant_id=tenant_id,
user_id=user_id,
role=role,
config=config,
transfer_method=transfer_method,
)
case FileTransferMethod.TOOL_FILE:
file = _build_from_tool_file(
mapping=mapping,
tenant_id=tenant_id,
user_id=user_id,
config=config,
transfer_method=transfer_method,
)
case _:
raise ValueError(f"Invalid file transfer method: {transfer_method}")
build_functions: dict[FileTransferMethod, Callable] = {
FileTransferMethod.LOCAL_FILE: _build_from_local_file,
FileTransferMethod.REMOTE_URL: _build_from_remote_url,
FileTransferMethod.TOOL_FILE: _build_from_tool_file,
}
build_func = build_functions.get(transfer_method)
if not build_func:
raise ValueError(f"Invalid file transfer method: {transfer_method}")
file = build_func(
mapping=mapping,
tenant_id=tenant_id,
transfer_method=transfer_method,
)
if not _is_file_valid_with_config(file=file, config=config):
raise ValueError(f"File validation failed for file: {file.filename}")
return file
@@ -92,10 +80,8 @@ def build_from_mapping(
def build_from_mappings(
*,
mappings: Sequence[Mapping[str, Any]],
config: FileExtraConfig | None,
config: FileUploadConfig | None,
tenant_id: str,
user_id: str,
role: "CreatedByRole",
) -> Sequence[File]:
if not config:
return []
@@ -104,8 +90,6 @@ def build_from_mappings(
build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
user_id=user_id,
role=role,
config=config,
)
for mapping in mappings
@@ -128,31 +112,20 @@ def _build_from_local_file(
*,
mapping: Mapping[str, Any],
tenant_id: str,
user_id: str,
role: "CreatedByRole",
config: FileExtraConfig,
transfer_method: FileTransferMethod,
):
# check if the upload file exists.
) -> File:
file_type = FileType.value_of(mapping.get("type"))
stmt = select(UploadFile).where(
UploadFile.id == mapping.get("upload_file_id"),
UploadFile.tenant_id == tenant_id,
UploadFile.created_by == user_id,
UploadFile.created_by_role == role,
)
if file_type == FileType.IMAGE:
stmt = stmt.where(UploadFile.extension.in_(IMAGE_EXTENSIONS))
elif file_type == FileType.VIDEO:
stmt = stmt.where(UploadFile.extension.in_(VIDEO_EXTENSIONS))
elif file_type == FileType.AUDIO:
stmt = stmt.where(UploadFile.extension.in_(AUDIO_EXTENSIONS))
elif file_type == FileType.DOCUMENT:
stmt = stmt.where(UploadFile.extension.in_(DOCUMENT_EXTENSIONS))
row = db.session.scalar(stmt)
if row is None:
raise ValueError("Invalid upload file")
file = File(
return File(
id=mapping.get("id"),
filename=row.name,
extension="." + row.extension,
@@ -162,23 +135,37 @@ def _build_from_local_file(
transfer_method=transfer_method,
remote_url=row.source_url,
related_id=mapping.get("upload_file_id"),
_extra_config=config,
size=row.size,
)
return file
def _build_from_remote_url(
    *,
    mapping: Mapping[str, Any],
    tenant_id: str,
    transfer_method: FileTransferMethod,
) -> File:
    """Build a ``File`` for the remote resource named by ``mapping["url"]``.

    The MIME type, filename and size are obtained from
    ``_get_remote_file_info``; the file type is taken from ``mapping["type"]``
    and the id from ``mapping["id"]``.

    Args:
        mapping: Request mapping; must contain a non-empty ``"url"``.
        tenant_id: Tenant the resulting file is attributed to.
        transfer_method: Transfer method recorded on the resulting file.

    Returns:
        The constructed ``File``.

    Raises:
        ValueError: If ``mapping`` has no ``"url"`` or it is empty.
    """
    url = mapping.get("url")
    if not url:
        raise ValueError("Invalid file url")

    mime_type, filename, file_size = _get_remote_file_info(url)

    # Extension preference: MIME-derived, then the filename suffix, then ".bin".
    # This is deliberately split into two steps: written as a single
    # `a or b if cond else c` expression, the conditional binds loosest and a
    # dot-less filename would discard a perfectly good MIME-derived extension.
    extension = mimetypes.guess_extension(mime_type)
    if not extension:
        extension = "." + filename.split(".")[-1] if "." in filename else ".bin"

    return File(
        id=mapping.get("id"),
        filename=filename,
        tenant_id=tenant_id,
        type=FileType.value_of(mapping.get("type")),
        transfer_method=transfer_method,
        remote_url=url,
        mime_type=mime_type,
        extension=extension,
        size=file_size,
    )
def _get_remote_file_info(url: str):
mime_type = mimetypes.guess_type(url)[0] or ""
file_size = -1
filename = url.split("/")[-1].split("?")[0] or "unknown_file"
@@ -186,56 +173,34 @@ def _build_from_remote_url(
resp = ssrf_proxy.head(url, follow_redirects=True)
if resp.status_code == httpx.codes.OK:
if content_disposition := resp.headers.get("Content-Disposition"):
filename = content_disposition.split("filename=")[-1].strip('"')
filename = str(content_disposition.split("filename=")[-1].strip('"'))
file_size = int(resp.headers.get("Content-Length", file_size))
mime_type = mime_type or str(resp.headers.get("Content-Type", ""))
# Determine file extension
extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin"
if not mime_type:
mime_type, _ = mimetypes.guess_type(url)
file = File(
id=mapping.get("id"),
filename=filename,
tenant_id=tenant_id,
type=FileType.value_of(mapping.get("type")),
transfer_method=transfer_method,
remote_url=url,
_extra_config=config,
mime_type=mime_type,
extension=extension,
size=file_size,
)
return file
return mime_type, filename, file_size
def _build_from_tool_file(
*,
mapping: Mapping[str, Any],
tenant_id: str,
user_id: str,
config: FileExtraConfig,
transfer_method: FileTransferMethod,
):
) -> File:
tool_file = (
db.session.query(ToolFile)
.filter(
ToolFile.id == mapping.get("tool_file_id"),
ToolFile.tenant_id == tenant_id,
ToolFile.user_id == user_id,
)
.first()
)
if tool_file is None:
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
path = tool_file.file_key
if "." in path:
extension = "." + path.split("/")[-1].split(".")[-1]
else:
extension = ".bin"
file = File(
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
return File(
id=mapping.get("id"),
tenant_id=tenant_id,
filename=tool_file.name,
@@ -246,6 +211,21 @@ def _build_from_tool_file(
extension=extension,
mime_type=tool_file.mimetype,
size=tool_file.size,
_extra_config=config,
)
return file
def _is_file_valid_with_config(*, file: File, config: FileUploadConfig) -> bool:
    """Report whether ``file`` passes every restriction configured in ``config``.

    A restriction is only enforced when its config value is truthy:

    * ``allowed_file_types`` -- the file's type must be listed, except that
      ``FileType.CUSTOM`` files are exempt from this check.
    * ``allowed_extensions`` -- the file's extension must be listed.
    * ``allowed_upload_methods`` -- the file's transfer method must be listed.
    * ``image_config.transfer_methods`` -- applied to ``FileType.IMAGE`` files
      only; the file's transfer method must be listed.

    Returns:
        ``True`` when no enforced restriction rejects the file, else ``False``.
    """
    if config.allowed_file_types and not (
        file.type in config.allowed_file_types or file.type == FileType.CUSTOM
    ):
        return False

    if config.allowed_extensions and file.extension not in config.allowed_extensions:
        return False

    if config.allowed_upload_methods and file.transfer_method not in config.allowed_upload_methods:
        return False

    # The image-specific restriction is irrelevant for non-images or when no
    # image configuration is present.
    if file.type != FileType.IMAGE or not config.image_config:
        return True

    image_methods = config.image_config.transfer_methods
    return not image_methods or file.transfer_method in image_methods