fix: implement robust file type checks to align with existing logic (#17557)

Co-authored-by: Bowen Liang <liangbowen@gf.com.cn>
This commit is contained in:
Arcaner
2025-04-16 19:21:50 +08:00
committed by GitHub
parent 18f98f4fe1
commit cac0d3c33e
4 changed files with 243 additions and 6 deletions

View File

@@ -52,6 +52,7 @@ def build_from_mapping(
mapping: Mapping[str, Any],
tenant_id: str,
config: FileUploadConfig | None = None,
strict_type_validation: bool = False,
) -> File:
transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
@@ -69,6 +70,7 @@ def build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
transfer_method=transfer_method,
strict_type_validation=strict_type_validation,
)
if config and not _is_file_valid_with_config(
@@ -87,12 +89,14 @@ def build_from_mappings(
mappings: Sequence[Mapping[str, Any]],
config: FileUploadConfig | None = None,
tenant_id: str,
strict_type_validation: bool = False,
) -> Sequence[File]:
files = [
build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
config=config,
strict_type_validation=strict_type_validation,
)
for mapping in mappings
]
@@ -116,6 +120,7 @@ def _build_from_local_file(
mapping: Mapping[str, Any],
tenant_id: str,
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File:
upload_file_id = mapping.get("upload_file_id")
if not upload_file_id:
@@ -134,10 +139,16 @@ def _build_from_local_file(
if row is None:
raise ValueError("Invalid upload file")
file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
if file_type.value != mapping.get("type", "custom"):
detected_file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
specified_type = mapping.get("type", "custom")
if strict_type_validation and detected_file_type.value != specified_type:
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
file_type = (
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
)
return File(
id=mapping.get("id"),
filename=row.name,
@@ -158,6 +169,7 @@ def _build_from_remote_url(
mapping: Mapping[str, Any],
tenant_id: str,
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File:
upload_file_id = mapping.get("upload_file_id")
if upload_file_id:
@@ -174,10 +186,21 @@ def _build_from_remote_url(
if upload_file is None:
raise ValueError("Invalid upload file")
file_type = _standardize_file_type(extension="." + upload_file.extension, mime_type=upload_file.mime_type)
if file_type.value != mapping.get("type", "custom"):
detected_file_type = _standardize_file_type(
extension="." + upload_file.extension, mime_type=upload_file.mime_type
)
specified_type = mapping.get("type")
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
file_type = (
FileType(specified_type)
if specified_type and specified_type != FileType.CUSTOM.value
else detected_file_type
)
return File(
id=mapping.get("id"),
filename=upload_file.name,
@@ -237,6 +260,7 @@ def _build_from_tool_file(
mapping: Mapping[str, Any],
tenant_id: str,
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File:
tool_file = (
db.session.query(ToolFile)
@@ -252,7 +276,16 @@ def _build_from_tool_file(
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
detected_file_type = _standardize_file_type(extension="." + extension, mime_type=tool_file.mimetype)
specified_type = mapping.get("type")
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
file_type = (
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
)
return File(
id=mapping.get("id"),