add more dataclass (#25039)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Asuka Minato
2025-09-06 04:20:13 +09:00
committed by GitHub
parent 917d60a1cb
commit 2b0695bdde
5 changed files with 34 additions and 33 deletions

View File

@@ -98,6 +98,7 @@ class ToolFileManager:
mimetype=mimetype, mimetype=mimetype,
name=present_filename, name=present_filename,
size=len(file_binary), size=len(file_binary),
original_url=None,
) )
session.add(tool_file) session.add(tool_file)
@@ -131,7 +132,6 @@ class ToolFileManager:
filename = f"{unique_name}{extension}" filename = f"{unique_name}{extension}"
filepath = f"tools/{tenant_id}/{filename}" filepath = f"tools/{tenant_id}/{filename}"
storage.save(filepath, blob) storage.save(filepath, blob)
with Session(self._engine, expire_on_commit=False) as session: with Session(self._engine, expire_on_commit=False) as session:
tool_file = ToolFile( tool_file = ToolFile(
user_id=user_id, user_id=user_id,

View File

@@ -1,6 +1,6 @@
import json import json
from datetime import datetime from datetime import datetime
from typing import Any, cast from typing import Any, Optional, cast
from urllib.parse import urlparse from urllib.parse import urlparse
import sqlalchemy as sa import sqlalchemy as sa
@@ -22,15 +22,15 @@ from .types import StringUUID
# system level tool oauth client params (client_id, client_secret, etc.) # system level tool oauth client params (client_id, client_secret, etc.)
class ToolOAuthSystemClient(Base): class ToolOAuthSystemClient(TypeBase):
__tablename__ = "tool_oauth_system_clients" __tablename__ = "tool_oauth_system_clients"
__table_args__ = ( __table_args__ = (
sa.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"), sa.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
sa.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"), sa.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
) )
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
plugin_id = mapped_column(String(512), nullable=False) plugin_id: Mapped[str] = mapped_column(String(512), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False) provider: Mapped[str] = mapped_column(String(255), nullable=False)
# oauth params of the tool provider # oauth params of the tool provider
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False) encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
@@ -412,7 +412,7 @@ class ToolConversationVariables(Base):
return json.loads(self.variables_str) return json.loads(self.variables_str)
class ToolFile(Base): class ToolFile(TypeBase):
"""This table stores file metadata generated in workflows, """This table stores file metadata generated in workflows,
not only files created by agent. not only files created by agent.
""" """
@@ -423,19 +423,19 @@ class ToolFile(Base):
sa.Index("tool_file_conversation_id_idx", "conversation_id"), sa.Index("tool_file_conversation_id_idx", "conversation_id"),
) )
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
# conversation user id # conversation user id
user_id: Mapped[str] = mapped_column(StringUUID) user_id: Mapped[str] = mapped_column(StringUUID)
# tenant id # tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID) tenant_id: Mapped[str] = mapped_column(StringUUID)
# conversation id # conversation id
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=True) conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
# file key # file key
file_key: Mapped[str] = mapped_column(String(255), nullable=False) file_key: Mapped[str] = mapped_column(String(255), nullable=False)
# mime type # mime type
mimetype: Mapped[str] = mapped_column(String(255), nullable=False) mimetype: Mapped[str] = mapped_column(String(255), nullable=False)
# original url # original url
original_url: Mapped[str] = mapped_column(String(2048), nullable=True) original_url: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True, default=None)
# name # name
name: Mapped[str] = mapped_column(default="") name: Mapped[str] = mapped_column(default="")
# size # size

View File

@@ -84,17 +84,17 @@ class TestStorageKeyLoader(unittest.TestCase):
if tenant_id is None: if tenant_id is None:
tenant_id = self.tenant_id tenant_id = self.tenant_id
tool_file = ToolFile() tool_file = ToolFile(
user_id=self.user_id,
tenant_id=tenant_id,
conversation_id=self.conversation_id,
file_key=file_key,
mimetype="text/plain",
original_url="http://example.com/file.txt",
name="test_tool_file.txt",
size=2048,
)
tool_file.id = file_id tool_file.id = file_id
tool_file.user_id = self.user_id
tool_file.tenant_id = tenant_id
tool_file.conversation_id = self.conversation_id
tool_file.file_key = file_key
tool_file.mimetype = "text/plain"
tool_file.original_url = "http://example.com/file.txt"
tool_file.name = "test_tool_file.txt"
tool_file.size = 2048
self.session.add(tool_file) self.session.add(tool_file)
self.session.flush() self.session.flush()
self.test_tool_files.append(tool_file) self.test_tool_files.append(tool_file)

View File

@@ -84,16 +84,17 @@ class TestStorageKeyLoader(unittest.TestCase):
if tenant_id is None: if tenant_id is None:
tenant_id = self.tenant_id tenant_id = self.tenant_id
tool_file = ToolFile() tool_file = ToolFile(
user_id=self.user_id,
tenant_id=tenant_id,
conversation_id=self.conversation_id,
file_key=file_key,
mimetype="text/plain",
original_url="http://example.com/file.txt",
name="test_tool_file.txt",
size=2048,
)
tool_file.id = file_id tool_file.id = file_id
tool_file.user_id = self.user_id
tool_file.tenant_id = tenant_id
tool_file.conversation_id = self.conversation_id
tool_file.file_key = file_key
tool_file.mimetype = "text/plain"
tool_file.original_url = "http://example.com/file.txt"
tool_file.name = "test_tool_file.txt"
tool_file.size = 2048
self.session.add(tool_file) self.session.add(tool_file)
self.session.flush() self.session.flush()

View File

@@ -26,14 +26,13 @@ def _gen_id():
class TestFileSaverImpl: class TestFileSaverImpl:
def test_save_binary_string(self, monkeypatch): def test_save_binary_string(self, monkeypatch: pytest.MonkeyPatch):
user_id = _gen_id() user_id = _gen_id()
tenant_id = _gen_id() tenant_id = _gen_id()
file_type = FileType.IMAGE file_type = FileType.IMAGE
mime_type = "image/png" mime_type = "image/png"
mock_signed_url = "https://example.com/image.png" mock_signed_url = "https://example.com/image.png"
mock_tool_file = ToolFile( mock_tool_file = ToolFile(
id=_gen_id(),
user_id=user_id, user_id=user_id,
tenant_id=tenant_id, tenant_id=tenant_id,
conversation_id=None, conversation_id=None,
@@ -43,6 +42,7 @@ class TestFileSaverImpl:
name=f"{_gen_id()}.png", name=f"{_gen_id()}.png",
size=len(_PNG_DATA), size=len(_PNG_DATA),
) )
mock_tool_file.id = _gen_id()
mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager) mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)
mocked_engine = mock.MagicMock(spec=Engine) mocked_engine = mock.MagicMock(spec=Engine)
@@ -80,7 +80,7 @@ class TestFileSaverImpl:
) )
mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png") mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png")
def test_save_remote_url_request_failed(self, monkeypatch): def test_save_remote_url_request_failed(self, monkeypatch: pytest.MonkeyPatch):
_TEST_URL = "https://example.com/image.png" _TEST_URL = "https://example.com/image.png"
mock_request = httpx.Request("GET", _TEST_URL) mock_request = httpx.Request("GET", _TEST_URL)
mock_response = httpx.Response( mock_response = httpx.Response(
@@ -99,7 +99,7 @@ class TestFileSaverImpl:
mock_get.assert_called_once_with(_TEST_URL) mock_get.assert_called_once_with(_TEST_URL)
assert exc.value.response.status_code == 401 assert exc.value.response.status_code == 401
def test_save_remote_url_success(self, monkeypatch): def test_save_remote_url_success(self, monkeypatch: pytest.MonkeyPatch):
_TEST_URL = "https://example.com/image.png" _TEST_URL = "https://example.com/image.png"
mime_type = "image/png" mime_type = "image/png"
user_id = _gen_id() user_id = _gen_id()
@@ -115,7 +115,6 @@ class TestFileSaverImpl:
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id) file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id)
mock_tool_file = ToolFile( mock_tool_file = ToolFile(
id=_gen_id(),
user_id=user_id, user_id=user_id,
tenant_id=tenant_id, tenant_id=tenant_id,
conversation_id=None, conversation_id=None,
@@ -125,6 +124,7 @@ class TestFileSaverImpl:
name=f"{_gen_id()}.png", name=f"{_gen_id()}.png",
size=len(_PNG_DATA), size=len(_PNG_DATA),
) )
mock_tool_file.id = _gen_id()
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response) mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
monkeypatch.setattr(ssrf_proxy, "get", mock_get) monkeypatch.setattr(ssrf_proxy, "get", mock_get)
mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file) mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file)