feat(api/repo): Allow to config repository implementation (#21458)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
@@ -449,6 +449,19 @@ MAX_VARIABLE_SIZE=204800
 # hybrid: Save new data to object storage, read from both object storage and RDBMS
 WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
 
+# Repository configuration
+# Core workflow execution repository implementation
+CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository
+
+# Core workflow node execution repository implementation
+CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository
+
+# API workflow node execution repository implementation
+API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository
+
+# API workflow run repository implementation
+API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
+
 # App configuration
 APP_MAX_EXECUTION_TIME=1200
 APP_MAX_ACTIVE_REQUESTS=0
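Because these settings are plain module paths resolved at runtime, a deployment can point Dify at its own repository backend without touching Dify's code. A hypothetical override (the `my_company` package and class are illustrative, not part of this PR) would look like:

```bash
# .env — swap in a custom core execution repository (hypothetical path).
# The class must be importable and accept the same constructor arguments
# as the default (session_factory, user, app_id, triggered_from).
CORE_WORKFLOW_EXECUTION_REPOSITORY=my_company.repositories.CustomWorkflowExecutionRepository
```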
@@ -537,6 +537,33 @@ class WorkflowNodeExecutionConfig(BaseSettings):
     )
 
 
+class RepositoryConfig(BaseSettings):
+    """
+    Configuration for repository implementations
+    """
+
+    CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field(
+        description="Repository implementation for WorkflowExecution. Specify as a module path",
+        default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository",
+    )
+
+    CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
+        description="Repository implementation for WorkflowNodeExecution. Specify as a module path",
+        default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository",
+    )
+
+    API_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
+        description="Service-layer repository implementation for WorkflowNodeExecutionModel operations. "
+        "Specify as a module path",
+        default="repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository",
+    )
+
+    API_WORKFLOW_RUN_REPOSITORY: str = Field(
+        description="Service-layer repository implementation for WorkflowRun operations. Specify as a module path",
+        default="repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository",
+    )
+
+
 class AuthConfig(BaseSettings):
     """
     Configuration for authentication and OAuth
@@ -903,6 +930,7 @@ class FeatureConfig(
     MultiModalTransferConfig,
     PositionConfig,
     RagEtlConfig,
+    RepositoryConfig,
     SecurityConfig,
     ToolConfig,
     UpdateConfig,
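Since `RepositoryConfig` joins the `FeatureConfig` mix-in list above, each field is resolved from the environment by pydantic-settings, falling back to the declared default. A minimal sketch of that resolution behavior, assuming pydantic-settings is installed (`DemoRepositoryConfig` and the override path are illustrative, not Dify code):

```python
import os

from pydantic import Field
from pydantic_settings import BaseSettings


class DemoRepositoryConfig(BaseSettings):
    # Same shape as the RepositoryConfig field above: a plain string module path.
    CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field(
        default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository",
    )


print(DemoRepositoryConfig().CORE_WORKFLOW_EXECUTION_REPOSITORY)  # prints the default

# pydantic-settings matches environment variables to field names, so an override wins:
os.environ["CORE_WORKFLOW_EXECUTION_REPOSITORY"] = "my_pkg.repos.MyRepo"  # hypothetical path
print(DemoRepositoryConfig().CORE_WORKFLOW_EXECUTION_REPOSITORY)  # prints my_pkg.repos.MyRepo
```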
@@ -35,8 +35,6 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[
             raise AppNotFoundError()
 
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode == AppMode.CHANNEL:
-            raise AppNotFoundError()
 
         if mode is not None:
             if isinstance(mode, list):
@@ -3,7 +3,7 @@ import logging
 from dateutil.parser import isoparse
 from flask_restful import Resource, fields, marshal_with, reqparse
 from flask_restful.inputs import int_range
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import Session, sessionmaker
 from werkzeug.exceptions import InternalServerError
 
 from controllers.service_api import api
@@ -30,7 +30,7 @@ from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
 from libs import helper
 from libs.helper import TimestampField
 from models.model import App, AppMode, EndUser
-from models.workflow import WorkflowRun
+from repositories.factory import DifyAPIRepositoryFactory
 from services.app_generate_service import AppGenerateService
 from services.errors.llm import InvokeRateLimitError
 from services.workflow_app_service import WorkflowAppService
@@ -63,7 +63,15 @@ class WorkflowRunDetailApi(Resource):
         if app_mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]:
             raise NotWorkflowAppError()
 
-        workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
+        # Use repository to get workflow run
+        session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+        workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
+
+        workflow_run = workflow_run_repo.get_workflow_run_by_id(
+            tenant_id=app_model.tenant_id,
+            app_id=app_model.id,
+            run_id=workflow_run_id,
+        )
         return workflow_run
@@ -3,6 +3,8 @@ import logging
 import uuid
 from typing import Optional, Union, cast
 
+from sqlalchemy import select
+
 from core.agent.entities import AgentEntity, AgentToolEntity
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
@@ -417,12 +419,15 @@ class BaseAgentRunner(AppRunner):
             if isinstance(prompt_message, SystemPromptMessage):
                 result.append(prompt_message)
 
-        messages: list[Message] = (
-            db.session.query(Message)
-            .filter(
-                Message.conversation_id == self.message.conversation_id,
-            )
+        messages = (
+            (
+                db.session.execute(
+                    select(Message)
+                    .where(Message.conversation_id == self.message.conversation_id)
+                    .order_by(Message.created_at.desc())
+                )
+            )
+            .scalars()
+            .all()
+        )
@@ -25,8 +25,7 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.ops.ops_trace_manager import TraceQueueManager
 from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
-from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
 from core.workflow.repositories.draft_variable_repository import (
     DraftVariableSaverFactory,
 )
@@ -183,14 +182,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
         else:
             workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
-        workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
+        workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
             session_factory=session_factory,
             user=user,
             app_id=application_generate_entity.app_config.app_id,
             triggered_from=workflow_triggered_from,
         )
         # Create workflow node execution repository
-        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+        workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
             session_factory=session_factory,
             user=user,
             app_id=application_generate_entity.app_config.app_id,
@@ -260,14 +259,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         # Create session factory
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
         # Create workflow execution(aka workflow run) repository
-        workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
+        workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
            session_factory=session_factory,
            user=user,
            app_id=application_generate_entity.app_config.app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
        )
        # Create workflow node execution repository
-        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+        workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
            session_factory=session_factory,
            user=user,
            app_id=application_generate_entity.app_config.app_id,
@@ -343,14 +342,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         # Create session factory
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
         # Create workflow execution(aka workflow run) repository
-        workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
+        workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
            session_factory=session_factory,
            user=user,
            app_id=application_generate_entity.app_config.app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
        )
        # Create workflow node execution repository
-        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+        workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
            session_factory=session_factory,
            user=user,
            app_id=application_generate_entity.app_config.app_id,
@@ -23,8 +23,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat
 from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.ops.ops_trace_manager import TraceQueueManager
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
-from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
 from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
 from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
 from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@@ -156,14 +155,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
             workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
         else:
             workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
-        workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
+        workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
             session_factory=session_factory,
             user=user,
             app_id=application_generate_entity.app_config.app_id,
             triggered_from=workflow_triggered_from,
         )
         # Create workflow node execution repository
-        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+        workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
             session_factory=session_factory,
             user=user,
             app_id=application_generate_entity.app_config.app_id,
@@ -306,16 +305,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
         # Create session factory
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
         # Create workflow execution(aka workflow run) repository
-        workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
+        workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
             session_factory=session_factory,
             user=user,
             app_id=application_generate_entity.app_config.app_id,
             triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
         )
         # Create workflow node execution repository
-        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-
-        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+        workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
             session_factory=session_factory,
             user=user,
             app_id=application_generate_entity.app_config.app_id,
@@ -390,16 +387,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
         # Create session factory
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
         # Create workflow execution(aka workflow run) repository
-        workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
+        workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
             session_factory=session_factory,
             user=user,
             app_id=application_generate_entity.app_config.app_id,
             triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
         )
         # Create workflow node execution repository
-        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-
-        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+        workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
             session_factory=session_factory,
             user=user,
             app_id=application_generate_entity.app_config.app_id,
@@ -3,7 +3,6 @@ import time
 from collections.abc import Generator
 from typing import Optional, Union
 
-from sqlalchemy import select
 from sqlalchemy.orm import Session
 
 from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@@ -68,7 +67,6 @@ from models.workflow import (
     Workflow,
     WorkflowAppLog,
     WorkflowAppLogCreatedFrom,
-    WorkflowRun,
 )
 
 logger = logging.getLogger(__name__)
@@ -562,8 +560,6 @@ class WorkflowAppGenerateTaskPipeline:
             tts_publisher.publish(None)
 
     def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None:
-        workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_))
-        assert workflow_run is not None
         invoke_from = self._application_generate_entity.invoke_from
         if invoke_from == InvokeFrom.SERVICE_API:
             created_from = WorkflowAppLogCreatedFrom.SERVICE_API
@@ -576,10 +572,10 @@ class WorkflowAppGenerateTaskPipeline:
             return
 
         workflow_app_log = WorkflowAppLog()
-        workflow_app_log.tenant_id = workflow_run.tenant_id
-        workflow_app_log.app_id = workflow_run.app_id
-        workflow_app_log.workflow_id = workflow_run.workflow_id
-        workflow_app_log.workflow_run_id = workflow_run.id
+        workflow_app_log.tenant_id = self._application_generate_entity.app_config.tenant_id
+        workflow_app_log.app_id = self._application_generate_entity.app_config.app_id
+        workflow_app_log.workflow_id = workflow_execution.workflow_id
+        workflow_app_log.workflow_run_id = workflow_execution.id_
         workflow_app_log.created_from = created_from.value
         workflow_app_log.created_by_role = self._created_by_role
         workflow_app_log.created_by = self._user_id
@@ -1,6 +1,8 @@
 from collections.abc import Sequence
 from typing import Optional
 
+from sqlalchemy import select
+
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.file import file_manager
 from core.model_manager import ModelInstance
@@ -17,11 +19,15 @@ from core.prompt.utils.extract_thread_messages import extract_thread_messages
 from extensions.ext_database import db
 from factories import file_factory
 from models.model import AppMode, Conversation, Message, MessageFile
-from models.workflow import WorkflowRun
+from models.workflow import Workflow, WorkflowRun
 
 
 class TokenBufferMemory:
-    def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> None:
+    def __init__(
+        self,
+        conversation: Conversation,
+        model_instance: ModelInstance,
+    ) -> None:
         self.conversation = conversation
         self.model_instance = model_instance
@@ -36,20 +42,8 @@ class TokenBufferMemory:
         app_record = self.conversation.app
 
         # fetch limited messages, and return reversed
-        query = (
-            db.session.query(
-                Message.id,
-                Message.query,
-                Message.answer,
-                Message.created_at,
-                Message.workflow_run_id,
-                Message.parent_message_id,
-                Message.answer_tokens,
-            )
-            .filter(
-                Message.conversation_id == self.conversation.id,
-            )
-            .order_by(Message.created_at.desc())
+        stmt = (
+            select(Message).where(Message.conversation_id == self.conversation.id).order_by(Message.created_at.desc())
         )
 
        if message_limit and message_limit > 0:
@@ -57,7 +51,9 @@ class TokenBufferMemory:
         else:
             message_limit = 500
 
-        messages = query.limit(message_limit).all()
+        stmt = stmt.limit(message_limit)
+
+        messages = db.session.scalars(stmt).all()
 
         # instead of all messages from the conversation, we only need to extract messages
         # that belong to the thread of last message
@@ -74,18 +70,20 @@ class TokenBufferMemory:
             files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
             if files:
                 file_extra_config = None
-                if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
+                if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
                     file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
+                elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
+                    workflow_run = db.session.scalar(
+                        select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id)
+                    )
+                    if not workflow_run:
+                        raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
+                    workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
+                    if not workflow:
+                        raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
+                    file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
                 else:
-                    if message.workflow_run_id:
-                        workflow_run = (
-                            db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first()
-                        )
-
-                        if workflow_run and workflow_run.workflow:
-                            file_extra_config = FileUploadConfigManager.convert(
-                                workflow_run.workflow.features_dict, is_vision=False
-                            )
+                    raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
 
             detail = ImagePromptMessageContent.DETAIL.LOW
             if file_extra_config and app_record:
@@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
     UnitEnum,
 )
 from core.ops.utils import filter_none_values
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
 from core.workflow.nodes.enums import NodeType
 from extensions.ext_database import db
 from models import EndUser, WorkflowNodeExecutionTriggeredFrom
@@ -123,10 +123,10 @@ class LangFuseDataTrace(BaseTraceInstance):
 
         service_account = self.get_service_account_with_tenant(app_id)
 
-        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+        workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
             session_factory=session_factory,
             user=service_account,
-            app_id=trace_info.metadata.get("app_id"),
+            app_id=app_id,
             triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
         )
@@ -27,7 +27,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
     LangSmithRunUpdateModel,
 )
 from core.ops.utils import filter_none_values, generate_dotted_order
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
 from core.workflow.nodes.enums import NodeType
 from extensions.ext_database import db
@@ -145,10 +145,10 @@ class LangSmithDataTrace(BaseTraceInstance):
 
         service_account = self.get_service_account_with_tenant(app_id)
 
-        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+        workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
             session_factory=session_factory,
             user=service_account,
-            app_id=trace_info.metadata.get("app_id"),
+            app_id=app_id,
             triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
         )
@@ -21,7 +21,7 @@ from core.ops.entities.trace_entity import (
     TraceTaskName,
     WorkflowTraceInfo,
 )
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
 from core.workflow.nodes.enums import NodeType
 from extensions.ext_database import db
@@ -160,10 +160,10 @@ class OpikDataTrace(BaseTraceInstance):
 
         service_account = self.get_service_account_with_tenant(app_id)
 
-        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+        workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
             session_factory=session_factory,
             user=service_account,
-            app_id=trace_info.metadata.get("app_id"),
+            app_id=app_id,
             triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
         )
@@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import (
     WorkflowTraceInfo,
 )
 from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
 from core.workflow.nodes.enums import NodeType
 from extensions.ext_database import db
@@ -144,10 +144,10 @@ class WeaveDataTrace(BaseTraceInstance):
 
         service_account = self.get_service_account_with_tenant(app_id)
 
-        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+        workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
             session_factory=session_factory,
             user=service_account,
-            app_id=trace_info.metadata.get("app_id"),
+            app_id=app_id,
             triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
         )
@@ -1,10 +1,11 @@
-from typing import Any
+from collections.abc import Sequence
 
 from constants import UUID_NIL
+from models import Message
 
 
-def extract_thread_messages(messages: list[Any]):
-    thread_messages = []
+def extract_thread_messages(messages: Sequence[Message]):
+    thread_messages: list[Message] = []
     next_message = None
 
     for message in messages:
@@ -1,3 +1,5 @@
+from sqlalchemy import select
+
 from core.prompt.utils.extract_thread_messages import extract_thread_messages
 from extensions.ext_database import db
 from models.model import Message
@@ -8,19 +10,9 @@ def get_thread_messages_length(conversation_id: str) -> int:
     Get the number of thread messages based on the parent message id.
     """
     # Fetch all messages related to the conversation
-    query = (
-        db.session.query(
-            Message.id,
-            Message.parent_message_id,
-            Message.answer,
-        )
-        .filter(
-            Message.conversation_id == conversation_id,
-        )
-        .order_by(Message.created_at.desc())
-    )
+    stmt = select(Message).where(Message.conversation_id == conversation_id).order_by(Message.created_at.desc())
 
-    messages = query.all()
+    messages = db.session.scalars(stmt).all()
 
     # Extract thread messages
    thread_messages = extract_thread_messages(messages)
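The hunks above, like several others in this PR, migrate from the legacy `Query` API to SQLAlchemy 2.0-style `select()` statements executed through `session.scalars()`. A self-contained sketch of the pattern (the `User` model and in-memory engine are illustrative, not Dify code):

```python
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "users"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(50))


engine = create_engine("sqlite://")  # throwaway in-memory database
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(User(name="alice"))
    session.commit()

    # Legacy 1.x style (what the removed lines did):
    #   users = session.query(User).filter(User.name == "alice").all()
    # 2.0 style (what the added lines do): build a statement, then execute it.
    stmt = select(User).where(User.name == "alice").order_by(User.id.desc())
    users = session.scalars(stmt).all()  # returns ORM objects, not Row tuples
```

Selecting whole `Message` entities instead of individual columns is also what lets the new code hand ORM objects straight to `extract_thread_messages(messages: Sequence[Message])`.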
@@ -5,8 +5,11 @@ This package contains concrete implementations of the repository interfaces
 defined in the core.workflow.repository package.
 """
 
+from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
 from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
 
 __all__ = [
+    "DifyCoreRepositoryFactory",
+    "RepositoryImportError",
     "SQLAlchemyWorkflowNodeExecutionRepository",
 ]
api/core/repositories/factory.py (new file, 224 lines)
@@ -0,0 +1,224 @@
"""
Repository factory for dynamically creating repository instances based on configuration.

This module provides a Django-like settings system for repository implementations,
allowing users to configure different repository backends through string paths.
"""

import importlib
import inspect
import logging
from typing import Protocol, Union

from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

from configs import dify_config
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from models import Account, EndUser
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowNodeExecutionTriggeredFrom

logger = logging.getLogger(__name__)


class RepositoryImportError(Exception):
    """Raised when a repository implementation cannot be imported or instantiated."""

    pass


class DifyCoreRepositoryFactory:
    """
    Factory for creating repository instances based on configuration.

    This factory supports Django-like settings where repository implementations
    are specified as module paths (e.g., 'module.submodule.ClassName').
    """

    @staticmethod
    def _import_class(class_path: str) -> type:
        """
        Import a class from a module path string.

        Args:
            class_path: Full module path to the class (e.g., 'module.submodule.ClassName')

        Returns:
            The imported class

        Raises:
            RepositoryImportError: If the class cannot be imported
        """
        try:
            module_path, class_name = class_path.rsplit(".", 1)
            module = importlib.import_module(module_path)
            repo_class = getattr(module, class_name)
            assert isinstance(repo_class, type)
            return repo_class
        except (ValueError, ImportError, AttributeError) as e:
            raise RepositoryImportError(f"Cannot import repository class '{class_path}': {e}") from e

    @staticmethod
    def _validate_repository_interface(repository_class: type, expected_interface: type[Protocol]) -> None:  # type: ignore
        """
        Validate that a class implements the expected repository interface.

        Args:
            repository_class: The class to validate
            expected_interface: The expected interface/protocol

        Raises:
            RepositoryImportError: If the class doesn't implement the interface
        """
        # Check if the class has all required methods from the protocol
        required_methods = [
            method
            for method in dir(expected_interface)
            if not method.startswith("_") and callable(getattr(expected_interface, method, None))
        ]

        missing_methods = []
        for method_name in required_methods:
            if not hasattr(repository_class, method_name):
                missing_methods.append(method_name)

        if missing_methods:
            raise RepositoryImportError(
                f"Repository class '{repository_class.__name__}' does not implement required methods "
                f"{missing_methods} from interface '{expected_interface.__name__}'"
            )

    @staticmethod
    def _validate_constructor_signature(repository_class: type, required_params: list[str]) -> None:
        """
        Validate that a repository class constructor accepts required parameters.

        Args:
            repository_class: The class to validate
            required_params: List of required parameter names

        Raises:
            RepositoryImportError: If the constructor doesn't accept required parameters
        """
        try:
            # MyPy may flag the line below with the following error:
            #
            # > Accessing "__init__" on an instance is unsound, since
            # > instance.__init__ could be from an incompatible subclass.
            #
            # Despite this, we need to ensure that the constructor of `repository_class`
            # has a compatible signature.
            signature = inspect.signature(repository_class.__init__)  # type: ignore[misc]
            param_names = list(signature.parameters.keys())

            # Remove 'self' parameter
            if "self" in param_names:
                param_names.remove("self")

            missing_params = [param for param in required_params if param not in param_names]
            if missing_params:
                raise RepositoryImportError(
                    f"Repository class '{repository_class.__name__}' constructor does not accept required parameters: "
                    f"{missing_params}. Expected parameters: {required_params}"
                )
        except Exception as e:
            raise RepositoryImportError(
                f"Failed to validate constructor signature for '{repository_class.__name__}': {e}"
            ) from e

    @classmethod
    def create_workflow_execution_repository(
        cls,
        session_factory: Union[sessionmaker, Engine],
        user: Union[Account, EndUser],
        app_id: str,
        triggered_from: WorkflowRunTriggeredFrom,
    ) -> WorkflowExecutionRepository:
        """
        Create a WorkflowExecutionRepository instance based on configuration.

        Args:
            session_factory: SQLAlchemy sessionmaker or engine
            user: Account or EndUser object
            app_id: Application ID
            triggered_from: Source of the execution trigger

        Returns:
            Configured WorkflowExecutionRepository instance

        Raises:
            RepositoryImportError: If the configured repository cannot be created
        """
        class_path = dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY
        logger.debug(f"Creating WorkflowExecutionRepository from: {class_path}")

        try:
            repository_class = cls._import_class(class_path)
            cls._validate_repository_interface(repository_class, WorkflowExecutionRepository)
            cls._validate_constructor_signature(
                repository_class, ["session_factory", "user", "app_id", "triggered_from"]
            )

            return repository_class(  # type: ignore[no-any-return]
                session_factory=session_factory,
                user=user,
                app_id=app_id,
                triggered_from=triggered_from,
            )
        except RepositoryImportError:
            # Re-raise our custom errors as-is
            raise
        except Exception as e:
            logger.exception("Failed to create WorkflowExecutionRepository")
            raise RepositoryImportError(f"Failed to create WorkflowExecutionRepository from '{class_path}': {e}") from e

    @classmethod
    def create_workflow_node_execution_repository(
        cls,
        session_factory: Union[sessionmaker, Engine],
        user: Union[Account, EndUser],
        app_id: str,
        triggered_from: WorkflowNodeExecutionTriggeredFrom,
    ) -> WorkflowNodeExecutionRepository:
        """
        Create a WorkflowNodeExecutionRepository instance based on configuration.

        Args:
            session_factory: SQLAlchemy sessionmaker or engine
            user: Account or EndUser object
            app_id: Application ID
            triggered_from: Source of the execution trigger

        Returns:
            Configured WorkflowNodeExecutionRepository instance

        Raises:
            RepositoryImportError: If the configured repository cannot be created
        """
        class_path = dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY
        logger.debug(f"Creating WorkflowNodeExecutionRepository from: {class_path}")

        try:
            repository_class = cls._import_class(class_path)
            cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository)
            cls._validate_constructor_signature(
                repository_class, ["session_factory", "user", "app_id", "triggered_from"]
            )

            return repository_class(  # type: ignore[no-any-return]
                session_factory=session_factory,
                user=user,
                app_id=app_id,
                triggered_from=triggered_from,
            )
        except RepositoryImportError:
            # Re-raise our custom errors as-is
            raise
        except Exception as e:
            logger.exception("Failed to create WorkflowNodeExecutionRepository")
            raise RepositoryImportError(
                f"Failed to create WorkflowNodeExecutionRepository from '{class_path}': {e}"
            ) from e
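To pass this factory's checks, a custom implementation only needs to be importable, expose the interface's public methods, and accept the four validated constructor parameters. A minimal sketch of a pluggable backend, assuming the default SQLAlchemy class accepts these keyword arguments as the factory's validation implies (`LoggingWorkflowExecutionRepository` and its module are hypothetical):

```python
# Hypothetical module my_pkg/repos.py, loadable via
# CORE_WORKFLOW_EXECUTION_REPOSITORY=my_pkg.repos.LoggingWorkflowExecutionRepository
import logging
from typing import Union

from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from models import Account, EndUser
from models.enums import WorkflowRunTriggeredFrom

logger = logging.getLogger(__name__)


class LoggingWorkflowExecutionRepository(SQLAlchemyWorkflowExecutionRepository):
    """Wraps the default backend and logs repository creation."""

    def __init__(
        self,
        session_factory: Union[sessionmaker, Engine],  # parameter names must match the factory's check
        user: Union[Account, EndUser],
        app_id: str,
        triggered_from: WorkflowRunTriggeredFrom,
    ):
        super().__init__(
            session_factory=session_factory,
            user=user,
            app_id=app_id,
            triggered_from=triggered_from,
        )
        logger.info("workflow execution repository created for app %s", app_id)
```

Subclassing the default implementation keeps the protocol-method check trivially satisfied; a from-scratch implementation would instead have to provide every public method of `WorkflowExecutionRepository`.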
@@ -50,7 +50,6 @@ class AppMode(StrEnum):
     CHAT = "chat"
     ADVANCED_CHAT = "advanced-chat"
     AGENT_CHAT = "agent-chat"
-    CHANNEL = "channel"
 
     @classmethod
     def value_of(cls, value: str) -> "AppMode":
@@ -934,7 +933,7 @@ class Message(Base):
     created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
-    workflow_run_id = db.Column(StringUUID)
+    workflow_run_id: Mapped[str] = db.Column(StringUUID)
 
     @property
     def inputs(self):
api/repositories/__init__.py (new file, empty)

api/repositories/api_workflow_node_execution_repository.py (new file, 197 lines)
@@ -0,0 +1,197 @@
"""
Service-layer repository protocol for WorkflowNodeExecutionModel operations.

This module provides a protocol interface for service-layer operations on WorkflowNodeExecutionModel
that abstracts database queries currently done directly in service classes. This repository is
specifically designed for service-layer needs and is separate from the core domain repository.

The service repository handles operations that require access to database-specific fields like
tenant_id, app_id, triggered_from, etc., which are not part of the core domain model.
"""

from collections.abc import Sequence
from datetime import datetime
from typing import Optional, Protocol

from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from models.workflow import WorkflowNodeExecutionModel


class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol):
    """
    Protocol for service-layer operations on WorkflowNodeExecutionModel.

    This repository provides database access patterns specifically needed by service classes,
    handling queries that involve database-specific fields and multi-tenancy concerns.

    Key responsibilities:
    - Manages database operations for workflow node executions
    - Handles multi-tenant data isolation
    - Provides batch processing capabilities
    - Supports execution lifecycle management

    Implementation notes:
    - Returns database models directly (WorkflowNodeExecutionModel)
    - Handles tenant/app filtering automatically
    - Provides service-specific query patterns
    - Focuses on database operations without domain logic
    - Supports cleanup and maintenance operations
    """

    def get_node_last_execution(
        self,
        tenant_id: str,
        app_id: str,
        workflow_id: str,
        node_id: str,
    ) -> Optional[WorkflowNodeExecutionModel]:
        """
        Get the most recent execution for a specific node.

        This method finds the latest execution of a specific node within a workflow,
        ordered by creation time. Used primarily for debugging and inspection purposes.

        Args:
            tenant_id: The tenant identifier
            app_id: The application identifier
            workflow_id: The workflow identifier
            node_id: The node identifier

        Returns:
            The most recent WorkflowNodeExecutionModel for the node, or None if not found
        """
        ...

    def get_executions_by_workflow_run(
        self,
        tenant_id: str,
        app_id: str,
        workflow_run_id: str,
    ) -> Sequence[WorkflowNodeExecutionModel]:
        """
        Get all node executions for a specific workflow run.

        This method retrieves all node executions that belong to a specific workflow run,
        ordered by index in descending order for proper trace visualization.

        Args:
            tenant_id: The tenant identifier
            app_id: The application identifier
            workflow_run_id: The workflow run identifier

        Returns:
            A sequence of WorkflowNodeExecutionModel instances ordered by index (desc)
        """
        ...

    def get_execution_by_id(
        self,
        execution_id: str,
        tenant_id: Optional[str] = None,
    ) -> Optional[WorkflowNodeExecutionModel]:
        """
        Get a workflow node execution by its ID.

        This method retrieves a specific execution by its unique identifier.
        Tenant filtering is optional for cases where the execution ID is globally unique.

        When `tenant_id` is None, it's the caller's responsibility to ensure proper data isolation between tenants.
        If the `execution_id` comes from untrusted sources (e.g., retrieved from an API request), the caller should
        set `tenant_id` to prevent horizontal privilege escalation.

        Args:
            execution_id: The execution identifier
            tenant_id: Optional tenant identifier for additional filtering

        Returns:
            The WorkflowNodeExecutionModel if found, or None if not found
        """
        ...

    def delete_expired_executions(
        self,
        tenant_id: str,
        before_date: datetime,
        batch_size: int = 1000,
    ) -> int:
        """
        Delete workflow node executions that are older than the specified date.

        This method is used for cleanup operations to remove expired executions
        in batches to avoid overwhelming the database.

        Args:
            tenant_id: The tenant identifier
            before_date: Delete executions created before this date
            batch_size: Maximum number of executions to delete in one batch

        Returns:
            The number of executions deleted
        """
        ...

    def delete_executions_by_app(
        self,
        tenant_id: str,
        app_id: str,
        batch_size: int = 1000,
    ) -> int:
        """
        Delete all workflow node executions for a specific app.

        This method is used when removing an app and all its related data.
        Executions are deleted in batches to avoid overwhelming the database.

        Args:
            tenant_id: The tenant identifier
            app_id: The application identifier
            batch_size: Maximum number of executions to delete in one batch

        Returns:
            The total number of executions deleted
        """
        ...

    def get_expired_executions_batch(
        self,
        tenant_id: str,
        before_date: datetime,
        batch_size: int = 1000,
    ) -> Sequence[WorkflowNodeExecutionModel]:
        """
        Get a batch of expired workflow node executions for backup purposes.

        This method retrieves expired executions without deleting them,
        allowing the caller to backup the data before deletion.

        Args:
            tenant_id: The tenant identifier
            before_date: Get executions created before this date
            batch_size: Maximum number of executions to retrieve

        Returns:
            A sequence of WorkflowNodeExecutionModel instances
        """
        ...

    def delete_executions_by_ids(
        self,
        execution_ids: Sequence[str],
    ) -> int:
        """
        Delete workflow node executions by their IDs.

        This method deletes specific executions by their IDs,
        typically used after backing up the data.

        This method does not perform tenant isolation checks. The caller is responsible for ensuring proper
        data isolation between tenants. When execution IDs come from untrusted sources (e.g., API requests),
        additional tenant validation should be implemented to prevent unauthorized access.

        Args:
            execution_ids: List of execution IDs to delete

        Returns:
            The number of executions deleted
        """
        ...
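Because this interface is a `typing.Protocol`, implementations conform structurally rather than by inheritance, which is what makes swapping in test doubles straightforward. A toy illustration of the mechanism (`Store`/`MemoryStore` are illustrative, unrelated to Dify's classes):

```python
from typing import Optional, Protocol, runtime_checkable


@runtime_checkable
class Store(Protocol):
    def get_execution_by_id(self, execution_id: str, tenant_id: Optional[str] = None) -> Optional[dict]: ...


class MemoryStore:  # note: no inheritance from Store
    def __init__(self) -> None:
        self._rows: dict[str, dict] = {}

    def get_execution_by_id(self, execution_id: str, tenant_id: Optional[str] = None) -> Optional[dict]:
        row = self._rows.get(execution_id)
        if row is None:
            return None
        if tenant_id is not None and row.get("tenant_id") != tenant_id:
            return None  # honor tenant isolation when a tenant_id is supplied
        return row


store: Store = MemoryStore()  # type-checks: the signatures match structurally
assert isinstance(MemoryStore(), Store)  # runtime check enabled by @runtime_checkable
```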
api/repositories/api_workflow_run_repository.py (new file, 181 lines)
@@ -0,0 +1,181 @@
"""
API WorkflowRun Repository Protocol

This module defines the protocol for service-layer WorkflowRun operations.
The repository provides an abstraction layer for WorkflowRun database operations
used by service classes, separating service-layer concerns from core domain logic.

Key Features:
- Paginated workflow run queries with filtering
- Bulk deletion operations with OSS backup support
- Multi-tenant data isolation
- Expired record cleanup with data retention
- Service-layer specific query patterns

Usage:
    This protocol should be used by service classes that need to perform
    WorkflowRun database operations. It provides a clean interface that
    hides implementation details and supports dependency injection.

Example:
    ```python
    from repositories.factory import DifyAPIRepositoryFactory

    session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
    repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)

    # Get paginated workflow runs
    runs = repo.get_paginated_workflow_runs(
        tenant_id="tenant-123",
        app_id="app-456",
        triggered_from="debugging",
        limit=20
    )
    ```
"""

from collections.abc import Sequence
from datetime import datetime
from typing import Optional, Protocol

from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.workflow import WorkflowRun


class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
    """
    Protocol for service-layer WorkflowRun repository operations.

    This protocol defines the interface for WorkflowRun database operations
    that are specific to service-layer needs, including pagination, filtering,
    and bulk operations with data backup support.
    """

    def get_paginated_workflow_runs(
        self,
        tenant_id: str,
        app_id: str,
        triggered_from: str,
        limit: int = 20,
        last_id: Optional[str] = None,
    ) -> InfiniteScrollPagination:
        """
        Get paginated workflow runs with filtering.

        Retrieves workflow runs for a specific app and trigger source with
        cursor-based pagination support. Used primarily for debugging and
        workflow run listing in the UI.

        Args:
            tenant_id: Tenant identifier for multi-tenant isolation
            app_id: Application identifier
            triggered_from: Filter by trigger source (e.g., "debugging", "app-run")
            limit: Maximum number of records to return (default: 20)
            last_id: Cursor for pagination - ID of the last record from previous page

        Returns:
            InfiniteScrollPagination object containing:
            - data: List of WorkflowRun objects
            - limit: Applied limit
            - has_more: Boolean indicating if more records exist

        Raises:
            ValueError: If last_id is provided but the corresponding record doesn't exist
        """
        ...

    def get_workflow_run_by_id(
        self,
        tenant_id: str,
        app_id: str,
        run_id: str,
    ) -> Optional[WorkflowRun]:
        """
        Get a specific workflow run by ID.

        Retrieves a single workflow run with tenant and app isolation.
        Used for workflow run detail views and execution tracking.

        Args:
            tenant_id: Tenant identifier for multi-tenant isolation
            app_id: Application identifier
            run_id: Workflow run identifier

        Returns:
            WorkflowRun object if found, None otherwise
        """
        ...

    def get_expired_runs_batch(
        self,
        tenant_id: str,
        before_date: datetime,
        batch_size: int = 1000,
    ) -> Sequence[WorkflowRun]:
        """
        Get a batch of expired workflow runs for cleanup.

        Retrieves workflow runs created before the specified date for
        cleanup operations. Used by scheduled tasks to remove old data
        while maintaining data retention policies.

        Args:
            tenant_id: Tenant identifier for multi-tenant isolation
            before_date: Only return runs created before this date
            batch_size: Maximum number of records to return

        Returns:
            Sequence of WorkflowRun objects to be processed for cleanup
        """
        ...

    def delete_runs_by_ids(
        self,
        run_ids: Sequence[str],
    ) -> int:
        """
        Delete workflow runs by their IDs.

        Performs bulk deletion of workflow runs by ID. This method should
        be used after backing up the data to OSS storage for retention.

        Args:
            run_ids: Sequence of workflow run IDs to delete

        Returns:
            Number of records actually deleted

        Note:
            This method performs hard deletion. Ensure data is backed up
            to OSS storage before calling this method for compliance with
            data retention policies.
        """
        ...

    def delete_runs_by_app(
        self,
        tenant_id: str,
        app_id: str,
        batch_size: int = 1000,
    ) -> int:
        """
        Delete all workflow runs for a specific app.

        Performs bulk deletion of all workflow runs associated with an app.
        Used during app cleanup operations. Processes records in batches
        to avoid memory issues and long-running transactions.

        Args:
            tenant_id: Tenant identifier for multi-tenant isolation
            app_id: Application identifier
            batch_size: Number of records to process in each batch

        Returns:
            Total number of records deleted across all batches

        Note:
            This method performs hard deletion without backup. Use with caution
            and ensure proper data retention policies are followed.
        """
        ...
api/repositories/factory.py (new file, 103 lines)
@@ -0,0 +1,103 @@
"""
DifyAPI Repository Factory for creating repository instances.

This factory is specifically designed for DifyAPI repositories that handle
service-layer operations with dependency injection patterns.
"""

import logging

from sqlalchemy.orm import sessionmaker

from configs import dify_config
from core.repositories import DifyCoreRepositoryFactory, RepositoryImportError
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
from repositories.api_workflow_run_repository import APIWorkflowRunRepository

logger = logging.getLogger(__name__)


class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
    """
    Factory for creating DifyAPI repository instances based on configuration.

    This factory handles the creation of repositories that are specifically designed
    for service-layer operations and use dependency injection with sessionmaker
    for better testability and separation of concerns.
    """

    @classmethod
    def create_api_workflow_node_execution_repository(
        cls, session_maker: sessionmaker
    ) -> DifyAPIWorkflowNodeExecutionRepository:
        """
        Create a DifyAPIWorkflowNodeExecutionRepository instance based on configuration.

        This repository is designed for service-layer operations and uses dependency injection
        with a sessionmaker for better testability and separation of concerns. It provides
        database access patterns specifically needed by service classes, handling queries
        that involve database-specific fields and multi-tenancy concerns.

        Args:
            session_maker: SQLAlchemy sessionmaker to inject for database session management.

        Returns:
            Configured DifyAPIWorkflowNodeExecutionRepository instance

        Raises:
            RepositoryImportError: If the configured repository cannot be imported or instantiated
        """
        class_path = dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY
        logger.debug(f"Creating DifyAPIWorkflowNodeExecutionRepository from: {class_path}")

        try:
            repository_class = cls._import_class(class_path)
            cls._validate_repository_interface(repository_class, DifyAPIWorkflowNodeExecutionRepository)
            # Service repository requires session_maker parameter
            cls._validate_constructor_signature(repository_class, ["session_maker"])

            return repository_class(session_maker=session_maker)  # type: ignore[no-any-return]
        except RepositoryImportError:
            # Re-raise our custom errors as-is
            raise
        except Exception as e:
            logger.exception("Failed to create DifyAPIWorkflowNodeExecutionRepository")
            raise RepositoryImportError(
                f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}"
            ) from e

    @classmethod
    def create_api_workflow_run_repository(cls, session_maker: sessionmaker) -> APIWorkflowRunRepository:
        """
        Create an APIWorkflowRunRepository instance based on configuration.

        This repository is designed for service-layer WorkflowRun operations and uses dependency
        injection with a sessionmaker for better testability and separation of concerns. It provides
        database access patterns specifically needed by service classes for workflow run management,
        including pagination, filtering, and bulk operations.

        Args:
            session_maker: SQLAlchemy sessionmaker to inject for database session management.

        Returns:
            Configured APIWorkflowRunRepository instance

        Raises:
            RepositoryImportError: If the configured repository cannot be imported or instantiated
        """
        class_path = dify_config.API_WORKFLOW_RUN_REPOSITORY
        logger.debug(f"Creating APIWorkflowRunRepository from: {class_path}")

        try:
            repository_class = cls._import_class(class_path)
            cls._validate_repository_interface(repository_class, APIWorkflowRunRepository)
            # Service repository requires session_maker parameter
            cls._validate_constructor_signature(repository_class, ["session_maker"])

            return repository_class(session_maker=session_maker)  # type: ignore[no-any-return]
        except RepositoryImportError:
            # Re-raise our custom errors as-is
            raise
        except Exception as e:
            logger.exception("Failed to create APIWorkflowRunRepository")
            raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e
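Wiring matches the controller change earlier in this diff: build a `sessionmaker` once, then ask the factory for the configured implementations. A compact sketch (assumes the Flask-SQLAlchemy `db` object used throughout the diff; the ids are illustrative):

```python
from sqlalchemy.orm import sessionmaker

from extensions.ext_database import db
from repositories.factory import DifyAPIRepositoryFactory

session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
node_exec_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)

# Fetch a run with tenant/app isolation enforced by the repository.
run = run_repo.get_workflow_run_by_id(tenant_id="tenant-123", app_id="app-456", run_id="run-789")
```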
@@ -0,0 +1,290 @@
"""
SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository.

This module provides a concrete implementation of the service repository protocol
using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
"""

from collections.abc import Sequence
from datetime import datetime
from typing import Optional

from sqlalchemy import delete, desc, select
from sqlalchemy.orm import Session, sessionmaker

from models.workflow import WorkflowNodeExecutionModel
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository


class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
    """
    SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository.

    This repository provides service-layer database operations for WorkflowNodeExecutionModel
    using SQLAlchemy 2.0 style queries. It implements the DifyAPIWorkflowNodeExecutionRepository
    protocol with the following features:

    - Multi-tenancy data isolation through tenant_id filtering
    - Direct database model operations without domain conversion
    - Batch processing for efficient large-scale operations
    - Optimized query patterns for common access patterns
    - Dependency injection for better testability and maintainability
    - Session management and transaction handling with proper cleanup
    - Maintenance operations for data lifecycle management
    - Thread-safe database operations using a session-per-request pattern
    """

    def __init__(self, session_maker: sessionmaker[Session]):
        """
        Initialize the repository with a sessionmaker.

        Args:
            session_maker: SQLAlchemy sessionmaker for creating database sessions
        """
        self._session_maker = session_maker

    def get_node_last_execution(
        self,
        tenant_id: str,
        app_id: str,
        workflow_id: str,
        node_id: str,
    ) -> Optional[WorkflowNodeExecutionModel]:
        """
        Get the most recent execution for a specific node.

        This method replicates the query pattern from WorkflowService.get_node_last_run()
        using SQLAlchemy 2.0 style syntax.

        Args:
            tenant_id: The tenant identifier
            app_id: The application identifier
            workflow_id: The workflow identifier
            node_id: The node identifier

        Returns:
            The most recent WorkflowNodeExecutionModel for the node, or None if not found
        """
        stmt = (
            select(WorkflowNodeExecutionModel)
            .where(
                WorkflowNodeExecutionModel.tenant_id == tenant_id,
                WorkflowNodeExecutionModel.app_id == app_id,
                WorkflowNodeExecutionModel.workflow_id == workflow_id,
                WorkflowNodeExecutionModel.node_id == node_id,
            )
            .order_by(desc(WorkflowNodeExecutionModel.created_at))
            .limit(1)
        )

        with self._session_maker() as session:
            return session.scalar(stmt)

    def get_executions_by_workflow_run(
        self,
        tenant_id: str,
        app_id: str,
        workflow_run_id: str,
    ) -> Sequence[WorkflowNodeExecutionModel]:
        """
        Get all node executions for a specific workflow run.

        This method replicates the query pattern from WorkflowRunService.get_workflow_run_node_executions()
        using SQLAlchemy 2.0 style syntax.

        Args:
            tenant_id: The tenant identifier
            app_id: The application identifier
            workflow_run_id: The workflow run identifier

        Returns:
            A sequence of WorkflowNodeExecutionModel instances ordered by index (desc)
        """
        stmt = (
            select(WorkflowNodeExecutionModel)
            .where(
                WorkflowNodeExecutionModel.tenant_id == tenant_id,
                WorkflowNodeExecutionModel.app_id == app_id,
                WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
            )
            .order_by(desc(WorkflowNodeExecutionModel.index))
        )

        with self._session_maker() as session:
            return session.execute(stmt).scalars().all()

    def get_execution_by_id(
        self,
        execution_id: str,
        tenant_id: Optional[str] = None,
    ) -> Optional[WorkflowNodeExecutionModel]:
        """
        Get a workflow node execution by its ID.

        This method replicates the query pattern from WorkflowDraftVariableService
        and WorkflowService.single_step_run_workflow_node() using SQLAlchemy 2.0 style syntax.

        When `tenant_id` is None, it is the caller's responsibility to ensure proper data isolation between tenants.
        If the `execution_id` comes from an untrusted source (e.g., retrieved from an API request), the caller should
        set `tenant_id` to prevent horizontal privilege escalation.

        Args:
            execution_id: The execution identifier
            tenant_id: Optional tenant identifier for additional filtering

        Returns:
            The WorkflowNodeExecutionModel if found, or None if not found
        """
        stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution_id)

        # Add tenant filtering if provided
        if tenant_id is not None:
            stmt = stmt.where(WorkflowNodeExecutionModel.tenant_id == tenant_id)

        with self._session_maker() as session:
            return session.scalar(stmt)

    def delete_expired_executions(
        self,
        tenant_id: str,
        before_date: datetime,
        batch_size: int = 1000,
    ) -> int:
        """
        Delete workflow node executions that are older than the specified date.

        Args:
            tenant_id: The tenant identifier
            before_date: Delete executions created before this date
            batch_size: Maximum number of executions to delete in one batch

        Returns:
            The number of executions deleted
        """
        total_deleted = 0

        while True:
            with self._session_maker() as session:
                # Find executions to delete in batches
                stmt = (
                    select(WorkflowNodeExecutionModel.id)
                    .where(
                        WorkflowNodeExecutionModel.tenant_id == tenant_id,
                        WorkflowNodeExecutionModel.created_at < before_date,
                    )
                    .limit(batch_size)
                )

                execution_ids = session.execute(stmt).scalars().all()
                if not execution_ids:
                    break

                # Delete the batch
                delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
                result = session.execute(delete_stmt)
                session.commit()
                total_deleted += result.rowcount

                # If we deleted fewer than the batch size, we're done
                if len(execution_ids) < batch_size:
                    break

        return total_deleted

    def delete_executions_by_app(
        self,
        tenant_id: str,
        app_id: str,
        batch_size: int = 1000,
    ) -> int:
        """
        Delete all workflow node executions for a specific app.

        Args:
            tenant_id: The tenant identifier
            app_id: The application identifier
            batch_size: Maximum number of executions to delete in one batch

        Returns:
            The total number of executions deleted
        """
        total_deleted = 0

        while True:
            with self._session_maker() as session:
                # Find executions to delete in batches
                stmt = (
                    select(WorkflowNodeExecutionModel.id)
                    .where(
                        WorkflowNodeExecutionModel.tenant_id == tenant_id,
                        WorkflowNodeExecutionModel.app_id == app_id,
                    )
                    .limit(batch_size)
                )

                execution_ids = session.execute(stmt).scalars().all()
                if not execution_ids:
                    break

                # Delete the batch
                delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
                result = session.execute(delete_stmt)
                session.commit()
                total_deleted += result.rowcount

                # If we deleted fewer than the batch size, we're done
                if len(execution_ids) < batch_size:
                    break

        return total_deleted

    def get_expired_executions_batch(
        self,
        tenant_id: str,
        before_date: datetime,
        batch_size: int = 1000,
    ) -> Sequence[WorkflowNodeExecutionModel]:
        """
        Get a batch of expired workflow node executions for backup purposes.

        Args:
            tenant_id: The tenant identifier
            before_date: Get executions created before this date
            batch_size: Maximum number of executions to retrieve

        Returns:
            A sequence of WorkflowNodeExecutionModel instances
        """
        stmt = (
            select(WorkflowNodeExecutionModel)
            .where(
                WorkflowNodeExecutionModel.tenant_id == tenant_id,
                WorkflowNodeExecutionModel.created_at < before_date,
            )
            .limit(batch_size)
        )

        with self._session_maker() as session:
            return session.execute(stmt).scalars().all()

    def delete_executions_by_ids(
        self,
        execution_ids: Sequence[str],
    ) -> int:
        """
        Delete workflow node executions by their IDs.

        Args:
            execution_ids: List of execution IDs to delete

        Returns:
            The number of executions deleted
        """
        if not execution_ids:
            return 0

        with self._session_maker() as session:
            stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
            result = session.execute(stmt)
            session.commit()
            return result.rowcount
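A brief usage sketch of the repository above (IDs are placeholders; the sessionmaker setup mirrors the callers later in this commit):

```python
from sqlalchemy.orm import sessionmaker

from extensions.ext_database import db
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
    DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
)

# expire_on_commit=False keeps returned models readable after the session closes.
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
repo = DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=session_maker)

# Tenant-scoped lookup of the most recent execution for one node.
last_run = repo.get_node_last_execution(
    tenant_id="tenant-id",
    app_id="app-id",
    workflow_id="workflow-id",
    node_id="node-id",
)
```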
202  api/repositories/sqlalchemy_api_workflow_run_repository.py  Normal file
@@ -0,0 +1,202 @@
"""
|
||||
SQLAlchemy API WorkflowRun Repository Implementation
|
||||
|
||||
This module provides the SQLAlchemy-based implementation of the APIWorkflowRunRepository
|
||||
protocol. It handles service-layer WorkflowRun database operations using SQLAlchemy 2.0
|
||||
style queries with proper session management and multi-tenant data isolation.
|
||||
|
||||
Key Features:
|
||||
- SQLAlchemy 2.0 style queries for modern database operations
|
||||
- Cursor-based pagination for efficient large dataset handling
|
||||
- Bulk operations with batch processing for performance
|
||||
- Multi-tenant data isolation and security
|
||||
- Proper session management with dependency injection
|
||||
|
||||
Implementation Notes:
|
||||
- Uses sessionmaker for consistent session management
|
||||
- Implements cursor-based pagination using created_at timestamps
|
||||
- Provides efficient bulk deletion with batch processing
|
||||
- Maintains data consistency with proper transaction handling
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Optional, cast
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.workflow import WorkflowRun
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DifyAPISQLAlchemyWorkflowRunRepository:
|
||||
"""
|
||||
SQLAlchemy implementation of APIWorkflowRunRepository.
|
||||
|
||||
Provides service-layer WorkflowRun database operations using SQLAlchemy 2.0
|
||||
style queries. Supports dependency injection through sessionmaker and
|
||||
maintains proper multi-tenant data isolation.
|
||||
|
||||
Args:
|
||||
session_maker: SQLAlchemy sessionmaker instance for database connections
|
||||
"""
|
||||
|
||||
def __init__(self, session_maker: sessionmaker[Session]) -> None:
|
||||
"""
|
||||
Initialize the repository with a sessionmaker.
|
||||
|
||||
Args:
|
||||
session_maker: SQLAlchemy sessionmaker for database connections
|
||||
"""
|
||||
self._session_maker = session_maker
|
||||
|
||||
def get_paginated_workflow_runs(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
triggered_from: str,
|
||||
limit: int = 20,
|
||||
last_id: Optional[str] = None,
|
||||
) -> InfiniteScrollPagination:
|
||||
"""
|
||||
Get paginated workflow runs with filtering.
|
||||
|
||||
Implements cursor-based pagination using created_at timestamps for
|
||||
efficient handling of large datasets. Filters by tenant, app, and
|
||||
trigger source for proper data isolation.
|
||||
"""
|
||||
with self._session_maker() as session:
|
||||
# Build base query with filters
|
||||
base_stmt = select(WorkflowRun).where(
|
||||
WorkflowRun.tenant_id == tenant_id,
|
||||
WorkflowRun.app_id == app_id,
|
||||
WorkflowRun.triggered_from == triggered_from,
|
||||
)
|
||||
|
||||
if last_id:
|
||||
# Get the last workflow run for cursor-based pagination
|
||||
last_run_stmt = base_stmt.where(WorkflowRun.id == last_id)
|
||||
last_workflow_run = session.scalar(last_run_stmt)
|
||||
|
||||
if not last_workflow_run:
|
||||
raise ValueError("Last workflow run not exists")
|
||||
|
||||
# Get records created before the last run's timestamp
|
||||
base_stmt = base_stmt.where(
|
||||
WorkflowRun.created_at < last_workflow_run.created_at,
|
||||
WorkflowRun.id != last_workflow_run.id,
|
||||
)
|
||||
|
||||
# First page - get most recent records
|
||||
workflow_runs = session.scalars(base_stmt.order_by(WorkflowRun.created_at.desc()).limit(limit + 1)).all()
|
||||
|
||||
# Check if there are more records for pagination
|
||||
has_more = len(workflow_runs) > limit
|
||||
if has_more:
|
||||
workflow_runs = workflow_runs[:-1]
|
||||
|
||||
return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
|
||||
|
||||
def get_workflow_run_by_id(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
run_id: str,
|
||||
) -> Optional[WorkflowRun]:
|
||||
"""
|
||||
Get a specific workflow run by ID with tenant and app isolation.
|
||||
"""
|
||||
with self._session_maker() as session:
|
||||
stmt = select(WorkflowRun).where(
|
||||
WorkflowRun.tenant_id == tenant_id,
|
||||
WorkflowRun.app_id == app_id,
|
||||
WorkflowRun.id == run_id,
|
||||
)
|
||||
return cast(Optional[WorkflowRun], session.scalar(stmt))
|
||||
|
||||
def get_expired_runs_batch(
|
||||
self,
|
||||
tenant_id: str,
|
||||
before_date: datetime,
|
||||
batch_size: int = 1000,
|
||||
) -> Sequence[WorkflowRun]:
|
||||
"""
|
||||
Get a batch of expired workflow runs for cleanup operations.
|
||||
"""
|
||||
with self._session_maker() as session:
|
||||
stmt = (
|
||||
select(WorkflowRun)
|
||||
.where(
|
||||
WorkflowRun.tenant_id == tenant_id,
|
||||
WorkflowRun.created_at < before_date,
|
||||
)
|
||||
.limit(batch_size)
|
||||
)
|
||||
return cast(Sequence[WorkflowRun], session.scalars(stmt).all())
|
||||
|
||||
def delete_runs_by_ids(
|
||||
self,
|
||||
run_ids: Sequence[str],
|
||||
) -> int:
|
||||
"""
|
||||
Delete workflow runs by their IDs using bulk deletion.
|
||||
"""
|
||||
if not run_ids:
|
||||
return 0
|
||||
|
||||
with self._session_maker() as session:
|
||||
stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
deleted_count = cast(int, result.rowcount)
|
||||
logger.info(f"Deleted {deleted_count} workflow runs by IDs")
|
||||
return deleted_count
|
||||
|
||||
def delete_runs_by_app(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
batch_size: int = 1000,
|
||||
) -> int:
|
||||
"""
|
||||
Delete all workflow runs for a specific app in batches.
|
||||
"""
|
||||
total_deleted = 0
|
||||
|
||||
while True:
|
||||
with self._session_maker() as session:
|
||||
# Get a batch of run IDs to delete
|
||||
stmt = (
|
||||
select(WorkflowRun.id)
|
||||
.where(
|
||||
WorkflowRun.tenant_id == tenant_id,
|
||||
WorkflowRun.app_id == app_id,
|
||||
)
|
||||
.limit(batch_size)
|
||||
)
|
||||
run_ids = session.scalars(stmt).all()
|
||||
|
||||
if not run_ids:
|
||||
break
|
||||
|
||||
# Delete the batch
|
||||
delete_stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
|
||||
result = session.execute(delete_stmt)
|
||||
session.commit()
|
||||
|
||||
batch_deleted = result.rowcount
|
||||
total_deleted += batch_deleted
|
||||
|
||||
logger.info(f"Deleted batch of {batch_deleted} workflow runs for app {app_id}")
|
||||
|
||||
# If we deleted fewer records than the batch size, we're done
|
||||
if batch_deleted < batch_size:
|
||||
break
|
||||
|
||||
logger.info(f"Total deleted {total_deleted} workflow runs for app {app_id}")
|
||||
return total_deleted
|
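Usage sketch for the cursor-based pagination above (values are placeholders; the string literal stands in for `WorkflowRunTriggeredFrom.DEBUGGING.value`): the id of the last row on a page is passed back as `last_id` to fetch the next page.

```python
repo = DifyAPISQLAlchemyWorkflowRunRepository(session_maker=session_maker)

page = repo.get_paginated_workflow_runs(
    tenant_id="tenant-id",
    app_id="app-id",
    triggered_from="debugging",
    limit=20,
)
while page.has_more:
    page = repo.get_paginated_workflow_runs(
        tenant_id="tenant-id",
        app_id="app-id",
        triggered_from="debugging",
        limit=20,
        last_id=page.data[-1].id,  # cursor: last row of the previous page
    )
```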
@@ -47,8 +47,6 @@ class AppService:
            filters.append(App.mode == AppMode.ADVANCED_CHAT.value)
        elif args["mode"] == "agent-chat":
            filters.append(App.mode == AppMode.AGENT_CHAT.value)
        elif args["mode"] == "channel":
            filters.append(App.mode == AppMode.CHANNEL.value)

        if args.get("is_created_by_me", False):
            filters.append(App.created_by == user_id)
@@ -6,7 +6,7 @@ from concurrent.futures import ThreadPoolExecutor

import click
from flask import Flask, current_app
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker

from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -14,7 +14,7 @@ from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Tenant
from models.model import App, Conversation, Message
from models.workflow import WorkflowNodeExecutionModel, WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from services.billing_service import BillingService

logger = logging.getLogger(__name__)
@@ -105,23 +105,24 @@ class ClearFreePlanTenantExpiredLogs:
                )
            )

        # Process expired workflow node executions with backup
        session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
        node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
        before_date = datetime.datetime.now() - datetime.timedelta(days=days)
        total_deleted = 0

        while True:
            with Session(db.engine).no_autoflush as session:
                workflow_node_executions = (
                    session.query(WorkflowNodeExecutionModel)
                    .filter(
                        WorkflowNodeExecutionModel.tenant_id == tenant_id,
                        WorkflowNodeExecutionModel.created_at
                        < datetime.datetime.now() - datetime.timedelta(days=days),
                    )
                    .limit(batch)
                    .all()
            # Get a batch of expired executions for backup
            workflow_node_executions = node_execution_repo.get_expired_executions_batch(
                tenant_id=tenant_id,
                before_date=before_date,
                batch_size=batch,
            )

            if len(workflow_node_executions) == 0:
                break

            # save workflow node executions
            # Save workflow node executions to storage
            storage.save(
                f"free_plan_tenant_expired_logs/"
                f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}"
@@ -131,15 +132,14 @@ class ClearFreePlanTenantExpiredLogs:
                ).encode("utf-8"),
            )

            # Extract IDs for deletion
            workflow_node_execution_ids = [
                workflow_node_execution.id for workflow_node_execution in workflow_node_executions
            ]

            # delete workflow node executions
            session.query(WorkflowNodeExecutionModel).filter(
                WorkflowNodeExecutionModel.id.in_(workflow_node_execution_ids),
            ).delete(synchronize_session=False)
            session.commit()
            # Delete the backed up executions
            deleted_count = node_execution_repo.delete_executions_by_ids(workflow_node_execution_ids)
            total_deleted += deleted_count

            click.echo(
                click.style(
@@ -148,23 +148,28 @@ class ClearFreePlanTenantExpiredLogs:
                )
            )

            # If we got fewer than the batch size, we're done
            if len(workflow_node_executions) < batch:
                break

        # Process expired workflow runs with backup
        session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
        workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
        before_date = datetime.datetime.now() - datetime.timedelta(days=days)
        total_deleted = 0

        while True:
            with Session(db.engine).no_autoflush as session:
                workflow_runs = (
                    session.query(WorkflowRun)
                    .filter(
                        WorkflowRun.tenant_id == tenant_id,
                        WorkflowRun.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
                    )
                    .limit(batch)
                    .all()
            # Get a batch of expired workflow runs for backup
            workflow_runs = workflow_run_repo.get_expired_runs_batch(
                tenant_id=tenant_id,
                before_date=before_date,
                batch_size=batch,
            )

            if len(workflow_runs) == 0:
                break

            # save workflow runs

            # Save workflow runs to storage
            storage.save(
                f"free_plan_tenant_expired_logs/"
                f"{tenant_id}/workflow_runs/{datetime.datetime.now().strftime('%Y-%m-%d')}"
@@ -176,13 +181,23 @@ class ClearFreePlanTenantExpiredLogs:
                ).encode("utf-8"),
            )

            # Extract IDs for deletion
            workflow_run_ids = [workflow_run.id for workflow_run in workflow_runs]

            # delete workflow runs
            session.query(WorkflowRun).filter(
                WorkflowRun.id.in_(workflow_run_ids),
            ).delete(synchronize_session=False)
            session.commit()
            # Delete the backed up workflow runs
            deleted_count = workflow_run_repo.delete_runs_by_ids(workflow_run_ids)
            total_deleted += deleted_count

            click.echo(
                click.style(
                    f"[{datetime.datetime.now()}] Processed {len(workflow_run_ids)}"
                    f" workflow runs for tenant {tenant_id}"
                )
            )

            # If we got fewer than the batch size, we're done
            if len(workflow_runs) < batch:
                break

    @classmethod
    def process(cls, days: int, batch: int, tenant_ids: list[str]):
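Condensed, the cleanup above now follows a backup-then-delete contract against the repository. A hedged sketch of that loop (the storage payload is simplified to IDs for illustration; the real code serializes full records):

```python
import datetime
import json


def clear_expired_node_executions(node_execution_repo, storage, tenant_id: str, days: int, batch: int) -> int:
    # Each batch is backed up to object storage before it is deleted.
    before_date = datetime.datetime.now() - datetime.timedelta(days=days)
    total_deleted = 0
    while True:
        rows = node_execution_repo.get_expired_executions_batch(
            tenant_id=tenant_id, before_date=before_date, batch_size=batch
        )
        if not rows:
            break
        day = datetime.datetime.now().strftime("%Y-%m-%d")
        storage.save(
            f"free_plan_tenant_expired_logs/{tenant_id}/workflow_node_executions/{day}",
            json.dumps([row.id for row in rows]).encode("utf-8"),  # simplified payload
        )
        # Delete exactly what was backed up, never more.
        total_deleted += node_execution_repo.delete_executions_by_ids([row.id for row in rows])
        if len(rows) < batch:  # a short batch means no more expired rows
            break
    return total_deleted
```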
@@ -5,9 +5,9 @@ from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any, ClassVar

from sqlalchemy import Engine, orm, select
from sqlalchemy import Engine, orm
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.sql.expression import and_, or_

from core.app.entities.app_invoke_entities import InvokeFrom
@@ -25,7 +25,8 @@ from factories.file_factory import StorageKeyLoader
from factories.variable_factory import build_segment, segment_to_variable
from models import App, Conversation
from models.enums import DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
from models.workflow import Workflow, WorkflowDraftVariable, is_system_variable_editable
from repositories.factory import DifyAPIRepositoryFactory

_logger = logging.getLogger(__name__)
@@ -117,7 +118,24 @@ class WorkflowDraftVariableService:
    _session: Session

    def __init__(self, session: Session) -> None:
        """
        Initialize the WorkflowDraftVariableService with a SQLAlchemy session.

        Args:
            session (Session): The SQLAlchemy session used to execute database queries.
                The provided session must be bound to an `Engine` object, not a specific `Connection`.

        Raises:
            AssertionError: If the provided session is not bound to an `Engine` object.
        """
        self._session = session
        engine = session.get_bind()
        # Ensure the session is bound to an engine.
        assert isinstance(engine, Engine)
        session_maker = sessionmaker(bind=engine, expire_on_commit=False)
        self._api_node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
            session_maker
        )

    def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None:
        return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first()
@@ -248,8 +266,7 @@ class WorkflowDraftVariableService:
            _logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name)
            return None

        query = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == variable.node_execution_id)
        node_exec = self._session.scalars(query).first()
        node_exec = self._api_node_execution_repo.get_execution_by_id(variable.node_execution_id)
        if node_exec is None:
            _logger.warning(
                "Node execution not found for draft variable, id=%s, name=%s, node_execution_id=%s",
@@ -298,6 +315,8 @@ class WorkflowDraftVariableService:

    def reset_variable(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None:
        variable_type = variable.get_variable_type()
        if variable_type == DraftVariableType.SYS and not is_system_variable_editable(variable.name):
            raise VariableResetError(f"cannot reset system variable, variable_id={variable.id}")
        if variable_type == DraftVariableType.CONVERSATION:
            return self._reset_conv_var(workflow, variable)
        else:
@@ -2,9 +2,9 @@ import threading
from collections.abc import Sequence
from typing import Optional

from sqlalchemy.orm import sessionmaker

import contexts
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import (
@@ -15,10 +15,18 @@ from models import (
    WorkflowRun,
    WorkflowRunTriggeredFrom,
)
from models.workflow import WorkflowNodeExecutionTriggeredFrom
from repositories.factory import DifyAPIRepositoryFactory


class WorkflowRunService:
    def __init__(self):
        """Initialize WorkflowRunService with repository dependencies."""
        session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
        self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
            session_maker
        )
        self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)

    def get_paginate_advanced_chat_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination:
        """
        Get advanced chat app workflow run list
@@ -62,45 +70,16 @@ class WorkflowRunService:
        :param args: request args
        """
        limit = int(args.get("limit", 20))
        last_id = args.get("last_id")

        base_query = db.session.query(WorkflowRun).filter(
            WorkflowRun.tenant_id == app_model.tenant_id,
            WorkflowRun.app_id == app_model.id,
            WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value,
        return self._workflow_run_repo.get_paginated_workflow_runs(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value,
            limit=limit,
            last_id=last_id,
        )

        if args.get("last_id"):
            last_workflow_run = base_query.filter(
                WorkflowRun.id == args.get("last_id"),
            ).first()

            if not last_workflow_run:
                raise ValueError("Last workflow run not exists")

            workflow_runs = (
                base_query.filter(
                    WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
                )
                .order_by(WorkflowRun.created_at.desc())
                .limit(limit)
                .all()
            )
        else:
            workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()

        has_more = False
        if len(workflow_runs) == limit:
            current_page_first_workflow_run = workflow_runs[-1]
            rest_count = base_query.filter(
                WorkflowRun.created_at < current_page_first_workflow_run.created_at,
                WorkflowRun.id != current_page_first_workflow_run.id,
            ).count()

            if rest_count > 0:
                has_more = True

        return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)

    def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]:
        """
        Get workflow run detail
@@ -108,17 +87,11 @@ class WorkflowRunService:
        :param app_model: app model
        :param run_id: workflow run id
        """
        workflow_run = (
            db.session.query(WorkflowRun)
            .filter(
                WorkflowRun.tenant_id == app_model.tenant_id,
                WorkflowRun.app_id == app_model.id,
                WorkflowRun.id == run_id,
        return self._workflow_run_repo.get_workflow_run_by_id(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            run_id=run_id,
        )
            .first()
        )

        return workflow_run

    def get_workflow_run_node_executions(
        self,
@@ -137,17 +110,13 @@ class WorkflowRunService:
        if not workflow_run:
            return []

        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=db.engine,
            user=user,
        # Get tenant_id from user
        tenant_id = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id
        if tenant_id is None:
            raise ValueError("User tenant_id cannot be None")

        return self._node_execution_service_repo.get_executions_by_workflow_run(
            tenant_id=tenant_id,
            app_id=app_model.id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
            workflow_run_id=run_id,
        )

        # Use the repository to get the database models directly
        order_config = OrderConfig(order_by=["index"], order_direction="desc")
        workflow_node_executions = repository.get_db_models_by_workflow_run(
            workflow_run_id=run_id, order_config=order_config
        )

        return workflow_node_executions
@@ -7,13 +7,13 @@ from typing import Any, Optional
from uuid import uuid4

from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker

from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
@@ -41,6 +41,7 @@ from models.workflow import (
    WorkflowNodeExecutionTriggeredFrom,
    WorkflowType,
)
from repositories.factory import DifyAPIRepositoryFactory
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError
from services.workflow.workflow_converter import WorkflowConverter
@@ -57,21 +58,32 @@ class WorkflowService:
    Workflow Service
    """

    def __init__(self, session_maker: sessionmaker | None = None):
        """Initialize WorkflowService with repository dependencies."""
        if session_maker is None:
            session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
        self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
            session_maker
        )

    def get_node_last_run(self, app_model: App, workflow: Workflow, node_id: str) -> WorkflowNodeExecutionModel | None:
        # TODO(QuantumGhost): This query is not fully covered by index.
        criteria = (
            WorkflowNodeExecutionModel.tenant_id == app_model.tenant_id,
            WorkflowNodeExecutionModel.app_id == app_model.id,
            WorkflowNodeExecutionModel.workflow_id == workflow.id,
            WorkflowNodeExecutionModel.node_id == node_id,
        """
        Get the most recent execution for a specific node.

        Args:
            app_model: The application model
            workflow: The workflow model
            node_id: The node identifier

        Returns:
            The most recent WorkflowNodeExecutionModel for the node, or None if not found
        """
        return self._node_execution_service_repo.get_node_last_execution(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            workflow_id=workflow.id,
            node_id=node_id,
        )
        node_exec = (
            db.session.query(WorkflowNodeExecutionModel)
            .filter(*criteria)
            .order_by(WorkflowNodeExecutionModel.created_at.desc())
            .first()
        )
        return node_exec

    def is_workflow_exist(self, app_model: App) -> bool:
        return (
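Since the sessionmaker is now injectable, WorkflowService can be constructed without a live `db.engine`. A hedged sketch of both options (the mock form mirrors the updated unit test fixture later in this diff):

```python
from unittest.mock import MagicMock

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from services.workflow_service import WorkflowService

# Option 1: a real (here in-memory) engine, e.g. for integration-style tests.
engine = create_engine("sqlite://")
service = WorkflowService(sessionmaker(bind=engine, expire_on_commit=False))

# Option 2: a plain mock sessionmaker, as the updated unit tests do.
service = WorkflowService(MagicMock())
```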
@@ -396,7 +408,7 @@ class WorkflowService:
        node_execution.workflow_id = draft_workflow.id

        # Create repository and save the node execution
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
        repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
            session_factory=db.engine,
            user=account,
            app_id=app_model.id,
@@ -404,8 +416,9 @@ class WorkflowService:
        )
        repository.save(node_execution)

        # Convert node_execution to WorkflowNodeExecution after save
        workflow_node_execution = repository.to_db_model(node_execution)
        workflow_node_execution = self._node_execution_service_repo.get_execution_by_id(node_execution.id)
        if workflow_node_execution is None:
            raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving")

        with Session(bind=db.engine) as session, session.begin():
            draft_var_saver = DraftVariableSaver(
@@ -418,6 +431,7 @@ class WorkflowService:
            )
            draft_var_saver.save(process_data=node_execution.process_data, outputs=node_execution.outputs)
            session.commit()

        return workflow_node_execution

    def run_free_workflow_node(
@@ -429,7 +443,7 @@ class WorkflowService:
        # run draft workflow node
        start_at = time.perf_counter()

        workflow_node_execution = self._handle_node_run_result(
        node_execution = self._handle_node_run_result(
            invoke_node_fn=lambda: WorkflowEntry.run_free_node(
                node_id=node_id,
                node_data=node_data,
@@ -441,7 +455,7 @@ class WorkflowService:
            node_id=node_id,
        )

        return workflow_node_execution
        return node_execution

    def _handle_node_run_result(
        self,
@@ -6,6 +6,7 @@ import click
from celery import shared_task  # type: ignore
from sqlalchemy import delete
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker

from extensions.ext_database import db
from models import (
@@ -31,7 +32,8 @@ from models import (
)
from models.tools import WorkflowToolProvider
from models.web import PinnedConversation, SavedMessage
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecutionModel, WorkflowRun
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog
from repositories.factory import DifyAPIRepositoryFactory


@shared_task(queue="app_deletion", bind=True, max_retries=3)
@@ -189,30 +191,32 @@ def _delete_app_workflows(tenant_id: str, app_id: str):


def _delete_app_workflow_runs(tenant_id: str, app_id: str):
    def del_workflow_run(workflow_run_id: str):
        db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).delete(synchronize_session=False)
    """Delete all workflow runs for an app using the service repository."""
    session_maker = sessionmaker(bind=db.engine)
    workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)

    _delete_records(
        """select id from workflow_runs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
        {"tenant_id": tenant_id, "app_id": app_id},
        del_workflow_run,
        "workflow run",
    deleted_count = workflow_run_repo.delete_runs_by_app(
        tenant_id=tenant_id,
        app_id=app_id,
        batch_size=1000,
    )

    logging.info(f"Deleted {deleted_count} workflow runs for app {app_id}")


def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
    def del_workflow_node_execution(workflow_node_execution_id: str):
        db.session.query(WorkflowNodeExecutionModel).filter(
            WorkflowNodeExecutionModel.id == workflow_node_execution_id
        ).delete(synchronize_session=False)
    """Delete all workflow node executions for an app using the service repository."""
    session_maker = sessionmaker(bind=db.engine)
    node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)

    _delete_records(
        """select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
        {"tenant_id": tenant_id, "app_id": app_id},
        del_workflow_node_execution,
        "workflow node execution",
    deleted_count = node_execution_repo.delete_executions_by_app(
        tenant_id=tenant_id,
        app_id=app_id,
        batch_size=1000,
    )

    logging.info(f"Deleted {deleted_count} workflow node executions for app {app_id}")


def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
    def del_workflow_app_log(workflow_app_log_id: str):
1  api/tests/unit_tests/core/repositories/__init__.py  Normal file
@@ -0,0 +1 @@
# Unit tests for core repositories module

455  api/tests/unit_tests/core/repositories/test_factory.py  Normal file
@@ -0,0 +1,455 @@
"""
|
||||
Unit tests for the RepositoryFactory.
|
||||
|
||||
This module tests the factory pattern implementation for creating repository instances
|
||||
based on configuration, including error handling and validation.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from models import Account, EndUser
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
||||
class TestRepositoryFactory:
|
||||
"""Test cases for RepositoryFactory."""
|
||||
|
||||
def test_import_class_success(self):
|
||||
"""Test successful class import."""
|
||||
# Test importing a real class
|
||||
class_path = "unittest.mock.MagicMock"
|
||||
result = DifyCoreRepositoryFactory._import_class(class_path)
|
||||
assert result is MagicMock
|
||||
|
||||
def test_import_class_invalid_path(self):
|
||||
"""Test import with invalid module path."""
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory._import_class("invalid.module.path")
|
||||
assert "Cannot import repository class" in str(exc_info.value)
|
||||
|
||||
def test_import_class_invalid_class_name(self):
|
||||
"""Test import with invalid class name."""
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory._import_class("unittest.mock.NonExistentClass")
|
||||
assert "Cannot import repository class" in str(exc_info.value)
|
||||
|
||||
def test_import_class_malformed_path(self):
|
||||
"""Test import with malformed path (no dots)."""
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory._import_class("invalidpath")
|
||||
assert "Cannot import repository class" in str(exc_info.value)
|
||||
|
||||
def test_validate_repository_interface_success(self):
|
||||
"""Test successful interface validation."""
|
||||
|
||||
# Create a mock class that implements the required methods
|
||||
class MockRepository:
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def get_by_id(self):
|
||||
pass
|
||||
|
||||
# Create a mock interface with the same methods
|
||||
class MockInterface:
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def get_by_id(self):
|
||||
pass
|
||||
|
||||
# Should not raise an exception
|
||||
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
|
||||
|
||||
def test_validate_repository_interface_missing_methods(self):
|
||||
"""Test interface validation with missing methods."""
|
||||
|
||||
# Create a mock class that doesn't implement all required methods
|
||||
class IncompleteRepository:
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
# Missing get_by_id method
|
||||
|
||||
# Create a mock interface with required methods
|
||||
class MockInterface:
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def get_by_id(self):
|
||||
pass
|
||||
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface)
|
||||
assert "does not implement required methods" in str(exc_info.value)
|
||||
assert "get_by_id" in str(exc_info.value)
|
||||
|
||||
def test_validate_constructor_signature_success(self):
|
||||
"""Test successful constructor signature validation."""
|
||||
|
||||
class MockRepository:
|
||||
def __init__(self, session_factory, user, app_id, triggered_from):
|
||||
pass
|
||||
|
||||
# Should not raise an exception
|
||||
DifyCoreRepositoryFactory._validate_constructor_signature(
|
||||
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
|
||||
)
|
||||
|
||||
def test_validate_constructor_signature_missing_params(self):
|
||||
"""Test constructor validation with missing parameters."""
|
||||
|
||||
class IncompleteRepository:
|
||||
def __init__(self, session_factory, user):
|
||||
# Missing app_id and triggered_from parameters
|
||||
pass
|
||||
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory._validate_constructor_signature(
|
||||
IncompleteRepository, ["session_factory", "user", "app_id", "triggered_from"]
|
||||
)
|
||||
assert "does not accept required parameters" in str(exc_info.value)
|
||||
assert "app_id" in str(exc_info.value)
|
||||
assert "triggered_from" in str(exc_info.value)
|
||||
|
||||
def test_validate_constructor_signature_inspection_error(self, mocker: MockerFixture):
|
||||
"""Test constructor validation when inspection fails."""
|
||||
# Mock inspect.signature to raise an exception
|
||||
mocker.patch("inspect.signature", side_effect=Exception("Inspection failed"))
|
||||
|
||||
class MockRepository:
|
||||
def __init__(self, session_factory):
|
||||
pass
|
||||
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"])
|
||||
assert "Failed to validate constructor signature" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_execution_repository_success(self, mock_config, mocker: MockerFixture):
|
||||
"""Test successful creation of WorkflowExecutionRepository."""
|
||||
# Setup mock configuration
|
||||
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
# Create mock dependencies
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=Account)
|
||||
app_id = "test-app-id"
|
||||
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
|
||||
# Mock the imported class to be a valid repository
|
||||
mock_repository_class = MagicMock()
|
||||
mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
|
||||
mock_repository_class.return_value = mock_repository_instance
|
||||
|
||||
# Mock the validation methods
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
|
||||
):
|
||||
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id=app_id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
|
||||
# Verify the repository was created with correct parameters
|
||||
mock_repository_class.assert_called_once_with(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id=app_id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
assert result is mock_repository_instance
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_execution_repository_import_error(self, mock_config):
|
||||
"""Test WorkflowExecutionRepository creation with import error."""
|
||||
# Setup mock configuration with invalid class path
|
||||
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
|
||||
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=Account)
|
||||
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
assert "Cannot import repository class" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_execution_repository_validation_error(self, mock_config, mocker: MockerFixture):
|
||||
"""Test WorkflowExecutionRepository creation with validation error."""
|
||||
# Setup mock configuration
|
||||
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=Account)
|
||||
|
||||
# Mock import to succeed but validation to fail
|
||||
mock_repository_class = MagicMock()
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(
|
||||
DifyCoreRepositoryFactory,
|
||||
"_validate_repository_interface",
|
||||
side_effect=RepositoryImportError("Interface validation failed"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
assert "Interface validation failed" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_execution_repository_instantiation_error(self, mock_config, mocker: MockerFixture):
|
||||
"""Test WorkflowExecutionRepository creation with instantiation error."""
|
||||
# Setup mock configuration
|
||||
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=Account)
|
||||
|
||||
# Mock import and validation to succeed but instantiation to fail
|
||||
mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
|
||||
):
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_node_execution_repository_success(self, mock_config, mocker: MockerFixture):
|
||||
"""Test successful creation of WorkflowNodeExecutionRepository."""
|
||||
# Setup mock configuration
|
||||
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
# Create mock dependencies
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=EndUser)
|
||||
app_id = "test-app-id"
|
||||
triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
|
||||
|
||||
# Mock the imported class to be a valid repository
|
||||
mock_repository_class = MagicMock()
|
||||
mock_repository_instance = MagicMock(spec=WorkflowNodeExecutionRepository)
|
||||
mock_repository_class.return_value = mock_repository_instance
|
||||
|
||||
# Mock the validation methods
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
|
||||
):
|
||||
result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id=app_id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
|
||||
# Verify the repository was created with correct parameters
|
||||
mock_repository_class.assert_called_once_with(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id=app_id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
assert result is mock_repository_instance
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_node_execution_repository_import_error(self, mock_config):
|
||||
"""Test WorkflowNodeExecutionRepository creation with import error."""
|
||||
# Setup mock configuration with invalid class path
|
||||
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
|
||||
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=EndUser)
|
||||
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
assert "Cannot import repository class" in str(exc_info.value)
|
||||
|
||||
def test_repository_import_error_exception(self):
|
||||
"""Test RepositoryImportError exception."""
|
||||
error_message = "Test error message"
|
||||
exception = RepositoryImportError(error_message)
|
||||
assert str(exception) == error_message
|
||||
assert isinstance(exception, Exception)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_with_engine_instead_of_sessionmaker(self, mock_config, mocker: MockerFixture):
|
||||
"""Test repository creation with Engine instead of sessionmaker."""
|
||||
# Setup mock configuration
|
||||
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
# Create mock dependencies with Engine instead of sessionmaker
|
||||
mock_engine = MagicMock(spec=Engine)
|
||||
mock_user = MagicMock(spec=Account)
|
||||
|
||||
# Mock the imported class to be a valid repository
|
||||
mock_repository_class = MagicMock()
|
||||
mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
|
||||
mock_repository_class.return_value = mock_repository_instance
|
||||
|
||||
# Mock the validation methods
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
|
||||
):
|
||||
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=mock_engine, # Using Engine instead of sessionmaker
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
# Verify the repository was created with the Engine
|
||||
mock_repository_class.assert_called_once_with(
|
||||
session_factory=mock_engine,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
assert result is mock_repository_instance
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_node_execution_repository_validation_error(self, mock_config):
|
||||
"""Test WorkflowNodeExecutionRepository creation with validation error."""
|
||||
# Setup mock configuration
|
||||
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=EndUser)
|
||||
|
||||
# Mock import to succeed but validation to fail
|
||||
mock_repository_class = MagicMock()
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(
|
||||
DifyCoreRepositoryFactory,
|
||||
"_validate_repository_interface",
|
||||
side_effect=RepositoryImportError("Interface validation failed"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
assert "Interface validation failed" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
|
||||
"""Test WorkflowNodeExecutionRepository creation with instantiation error."""
|
||||
# Setup mock configuration
|
||||
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=EndUser)
|
||||
|
||||
# Mock import and validation to succeed but instantiation to fail
|
||||
mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
|
||||
):
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)
|
||||
|
||||
def test_validate_repository_interface_with_private_methods(self):
|
||||
"""Test interface validation ignores private methods."""
|
||||
|
||||
# Create a mock class with private methods
|
||||
class MockRepository:
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def get_by_id(self):
|
||||
pass
|
||||
|
||||
def _private_method(self):
|
||||
pass
|
||||
|
||||
# Create a mock interface with private methods
|
||||
class MockInterface:
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def get_by_id(self):
|
||||
pass
|
||||
|
||||
def _private_method(self):
|
||||
pass
|
||||
|
||||
# Should not raise an exception (private methods are ignored)
|
||||
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
|
||||
|
||||
def test_validate_constructor_signature_with_extra_params(self):
|
||||
"""Test constructor validation with extra parameters (should pass)."""
|
||||
|
||||
class MockRepository:
|
||||
def __init__(self, session_factory, user, app_id, triggered_from, extra_param=None):
|
||||
pass
|
||||
|
||||
# Should not raise an exception (extra parameters are allowed)
|
||||
DifyCoreRepositoryFactory._validate_constructor_signature(
|
||||
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
|
||||
)
|
||||
|
||||
def test_validate_constructor_signature_with_kwargs(self):
|
||||
"""Test constructor validation with **kwargs (current implementation doesn't support this)."""
|
||||
|
||||
class MockRepository:
|
||||
def __init__(self, session_factory, user, **kwargs):
|
||||
pass
|
||||
|
||||
# Current implementation doesn't handle **kwargs, so this should raise an exception
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory._validate_constructor_signature(
|
||||
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
|
||||
)
|
||||
assert "does not accept required parameters" in str(exc_info.value)
|
||||
assert "app_id" in str(exc_info.value)
|
||||
assert "triggered_from" in str(exc_info.value)
|
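For reference, a minimal sketch of the signature check these tests exercise — an assumed shape, not the factory's exact code — which explains why a `**kwargs`-only constructor fails: only explicitly named parameters are compared.

```python
import inspect


def validate_constructor_signature(cls: type, required_params: list[str]) -> None:
    # Compare required names against the constructor's explicit parameters;
    # VAR_KEYWORD (**kwargs) is not treated as accepting arbitrary names.
    params = inspect.signature(cls.__init__).parameters
    missing = [name for name in required_params if name not in params]
    if missing:
        raise TypeError(f"{cls.__name__} does not accept required parameters: {missing}")
```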
@@ -10,7 +10,8 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_setup():
|
||||
workflow_service = WorkflowService()
|
||||
mock_session_maker = MagicMock()
|
||||
workflow_service = WorkflowService(mock_session_maker)
|
||||
session = MagicMock(spec=Session)
|
||||
tenant_id = "test-tenant-id"
|
||||
workflow_id = "test-workflow-id"
|
||||
|
@@ -1,14 +1,14 @@
import dataclasses
import secrets
from unittest import mock
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch

import pytest
from sqlalchemy import Engine
from sqlalchemy.orm import Session

from core.variables import StringSegment
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes import NodeType
from core.workflow.nodes.enums import NodeType
from models.enums import DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
from services.workflow_draft_variable_service import (
@@ -18,13 +18,25 @@ from services.workflow_draft_variable_service import (
)


@pytest.fixture
def mock_engine() -> Engine:
    return Mock(spec=Engine)


@pytest.fixture
def mock_session(mock_engine) -> Session:
    mock_session = Mock(spec=Session)
    mock_session.get_bind.return_value = mock_engine
    return mock_session


class TestDraftVariableSaver:
    def _get_test_app_id(self):
        suffix = secrets.token_hex(6)
        return f"test_app_id_{suffix}"

    def test__should_variable_be_visible(self):
        mock_session = mock.MagicMock(spec=Session)
        mock_session = MagicMock(spec=Session)
        test_app_id = self._get_test_app_id()
        saver = DraftVariableSaver(
            session=mock_session,
@@ -70,7 +82,7 @@ class TestDraftVariableSaver:
            ),
        ]

        mock_session = mock.MagicMock(spec=Session)
        mock_session = MagicMock(spec=Session)
        test_app_id = self._get_test_app_id()
        saver = DraftVariableSaver(
            session=mock_session,
@@ -105,9 +117,8 @@ class TestWorkflowDraftVariableService:
            conversation_variables=[],
        )

    def test_reset_conversation_variable(self):
    def test_reset_conversation_variable(self, mock_session):
        """Test resetting a conversation variable"""
        mock_session = Mock(spec=Session)
        service = WorkflowDraftVariableService(mock_session)

        test_app_id = self._get_test_app_id()
@@ -131,9 +142,8 @@ class TestWorkflowDraftVariableService:
        mock_reset_conv.assert_called_once_with(workflow, variable)
        assert result == expected_result

    def test_reset_node_variable_with_no_execution_id(self):
    def test_reset_node_variable_with_no_execution_id(self, mock_session):
        """Test resetting a node variable with no execution ID - should delete variable"""
        mock_session = Mock(spec=Session)
        service = WorkflowDraftVariableService(mock_session)

        test_app_id = self._get_test_app_id()
@@ -158,11 +168,26 @@ class TestWorkflowDraftVariableService:
        mock_session.flush.assert_called_once()
        assert result is None

    def test_reset_node_variable_with_missing_execution_record(self):
    def test_reset_node_variable_with_missing_execution_record(
        self,
        mock_engine,
        mock_session,
        monkeypatch,
    ):
        """Test resetting a node variable when execution record doesn't exist"""
        mock_session = Mock(spec=Session)
        mock_repo_session = Mock(spec=Session)

        mock_session_maker = MagicMock()
        # Mock the context manager protocol for sessionmaker
        mock_session_maker.return_value.__enter__.return_value = mock_repo_session
        mock_session_maker.return_value.__exit__.return_value = None
        monkeypatch.setattr("services.workflow_draft_variable_service.sessionmaker", mock_session_maker)
        service = WorkflowDraftVariableService(mock_session)

        # Mock the repository to return None (no execution record found)
        service._api_node_execution_repo = Mock()
        service._api_node_execution_repo.get_execution_by_id.return_value = None

        test_app_id = self._get_test_app_id()
        workflow = self._create_test_workflow(test_app_id)

@@ -171,24 +196,41 @@ class TestWorkflowDraftVariableService:
        variable = WorkflowDraftVariable.new_node_variable(
            app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
        )

        # Mock session.scalars to return None (no execution record found)
        mock_scalars = Mock()
        mock_scalars.first.return_value = None
        mock_session.scalars.return_value = mock_scalars
        # Variable is editable by default from factory method

        result = service._reset_node_var_or_sys_var(workflow, variable)

        mock_session_maker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)
        # Should delete the variable and return None
        mock_session.delete.assert_called_once_with(instance=variable)
        mock_session.flush.assert_called_once()
        assert result is None

    def test_reset_node_variable_with_valid_execution_record(self):
    def test_reset_node_variable_with_valid_execution_record(
        self,
        mock_session,
        monkeypatch,
    ):
        """Test resetting a node variable with valid execution record - should restore from execution"""
        mock_session = Mock(spec=Session)
        mock_repo_session = Mock(spec=Session)

        mock_session_maker = MagicMock()
        # Mock the context manager protocol for sessionmaker
        mock_session_maker.return_value.__enter__.return_value = mock_repo_session
        mock_session_maker.return_value.__exit__.return_value = None
        monkeypatch.setattr("services.workflow_draft_variable_service.sessionmaker", mock_session_maker)
        service = WorkflowDraftVariableService(mock_session)

        # Create mock execution record
        mock_execution = Mock(spec=WorkflowNodeExecutionModel)
        mock_execution.outputs_dict = {"test_var": "output_value"}

        # Mock the repository to return the execution record
        service._api_node_execution_repo = Mock()
        service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution

        test_app_id = self._get_test_app_id()
        workflow = self._create_test_workflow(test_app_id)

@@ -197,16 +239,7 @@ class TestWorkflowDraftVariableService:
        variable = WorkflowDraftVariable.new_node_variable(
            app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
        )

        # Create mock execution record
        mock_execution = Mock(spec=WorkflowNodeExecutionModel)
        mock_execution.process_data_dict = {"test_var": "process_value"}
        mock_execution.outputs_dict = {"test_var": "output_value"}

        # Mock session.scalars to return the execution record
        mock_scalars = Mock()
        mock_scalars.first.return_value = mock_execution
        mock_session.scalars.return_value = mock_scalars
        # Variable is editable by default from factory method

        # Mock workflow methods
        mock_node_config = {"type": "test_node"}
@@ -224,9 +257,8 @@ class TestWorkflowDraftVariableService:
        # Should return the updated variable
        assert result == variable

    def test_reset_non_editable_system_variable_raises_error(self):
    def test_reset_non_editable_system_variable_raises_error(self, mock_session):
        """Test that resetting a non-editable system variable raises an error"""
        mock_session = Mock(spec=Session)
        service = WorkflowDraftVariableService(mock_session)

        test_app_id = self._get_test_app_id()
@@ -242,24 +274,13 @@ class TestWorkflowDraftVariableService:
            editable=False,  # Non-editable system variable
        )

        # Mock the service to properly check system variable editability
        with patch.object(service, "reset_variable") as mock_reset:

            def side_effect(wf, var):
                if var.get_variable_type() == DraftVariableType.SYS and not is_system_variable_editable(var.name):
                    raise VariableResetError(f"cannot reset system variable, variable_id={var.id}")
                return var

            mock_reset.side_effect = side_effect

        with pytest.raises(VariableResetError) as exc_info:
            service.reset_variable(workflow, variable)
        assert "cannot reset system variable" in str(exc_info.value)
        assert f"variable_id={variable.id}" in str(exc_info.value)

    def test_reset_editable_system_variable_succeeds(self):
    def test_reset_editable_system_variable_succeeds(self, mock_session):
        """Test that resetting an editable system variable succeeds"""
        mock_session = Mock(spec=Session)
        service = WorkflowDraftVariableService(mock_session)

        test_app_id = self._get_test_app_id()
@@ -279,10 +300,9 @@ class TestWorkflowDraftVariableService:
        mock_execution = Mock(spec=WorkflowNodeExecutionModel)
        mock_execution.outputs_dict = {"sys.files": "[]"}

        # Mock session.scalars to return the execution record
        mock_scalars = Mock()
        mock_scalars.first.return_value = mock_execution
        mock_session.scalars.return_value = mock_scalars
        # Mock the repository to return the execution record
        service._api_node_execution_repo = Mock()
        service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution

        result = service._reset_node_var_or_sys_var(workflow, variable)

@@ -291,9 +311,8 @@ class TestWorkflowDraftVariableService:
        assert variable.last_edited_at is None
        mock_session.flush.assert_called()

    def test_reset_query_system_variable_succeeds(self):
    def test_reset_query_system_variable_succeeds(self, mock_session):
        """Test that resetting query system variable (another editable one) succeeds"""
        mock_session = Mock(spec=Session)
        service = WorkflowDraftVariableService(mock_session)

        test_app_id = self._get_test_app_id()
@@ -313,10 +332,9 @@ class TestWorkflowDraftVariableService:
        mock_execution = Mock(spec=WorkflowNodeExecutionModel)
        mock_execution.outputs_dict = {"sys.query": "reset query"}

        # Mock session.scalars to return the execution record
        mock_scalars = Mock()
        mock_scalars.first.return_value = mock_execution
        mock_session.scalars.return_value = mock_scalars
        # Mock the repository to return the execution record
        service._api_node_execution_repo = Mock()
        service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution

        result = service._reset_node_var_or_sys_var(workflow, variable)

@@ -0,0 +1,288 @@
from datetime import datetime
from unittest.mock import MagicMock
from uuid import uuid4

import pytest
from sqlalchemy.orm import Session

from models.workflow import WorkflowNodeExecutionModel
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
    DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
)


class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
    @pytest.fixture
    def repository(self):
        mock_session_maker = MagicMock()
        return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=mock_session_maker)

    @pytest.fixture
    def mock_execution(self):
        execution = MagicMock(spec=WorkflowNodeExecutionModel)
        execution.id = str(uuid4())
        execution.tenant_id = "tenant-123"
        execution.app_id = "app-456"
        execution.workflow_id = "workflow-789"
        execution.workflow_run_id = "run-101"
        execution.node_id = "node-202"
        execution.index = 1
        execution.created_at = "2023-01-01T00:00:00Z"
        return execution

    def test_get_node_last_execution_found(self, repository, mock_execution):
        """Test getting the last execution for a node when it exists."""
        # Arrange
        mock_session = MagicMock(spec=Session)
        repository._session_maker.return_value.__enter__.return_value = mock_session
        mock_session.scalar.return_value = mock_execution

        # Act
        result = repository.get_node_last_execution(
            tenant_id="tenant-123",
            app_id="app-456",
            workflow_id="workflow-789",
            node_id="node-202",
        )

        # Assert
        assert result == mock_execution
        mock_session.scalar.assert_called_once()
        # Verify the query was constructed correctly
        call_args = mock_session.scalar.call_args[0][0]
        assert hasattr(call_args, "compile")  # It's a SQLAlchemy statement

    def test_get_node_last_execution_not_found(self, repository):
        """Test getting the last execution for a node when it doesn't exist."""
        # Arrange
        mock_session = MagicMock(spec=Session)
        repository._session_maker.return_value.__enter__.return_value = mock_session
        mock_session.scalar.return_value = None

        # Act
        result = repository.get_node_last_execution(
            tenant_id="tenant-123",
            app_id="app-456",
            workflow_id="workflow-789",
            node_id="node-202",
        )

        # Assert
        assert result is None
        mock_session.scalar.assert_called_once()

    def test_get_executions_by_workflow_run(self, repository, mock_execution):
        """Test getting all executions for a workflow run."""
        # Arrange
        mock_session = MagicMock(spec=Session)
        repository._session_maker.return_value.__enter__.return_value = mock_session
        executions = [mock_execution]
        mock_session.execute.return_value.scalars.return_value.all.return_value = executions

        # Act
        result = repository.get_executions_by_workflow_run(
            tenant_id="tenant-123",
            app_id="app-456",
            workflow_run_id="run-101",
        )

        # Assert
        assert result == executions
        mock_session.execute.assert_called_once()
        # Verify the query was constructed correctly
        call_args = mock_session.execute.call_args[0][0]
        assert hasattr(call_args, "compile")  # It's a SQLAlchemy statement

    def test_get_executions_by_workflow_run_empty(self, repository):
        """Test getting executions for a workflow run when none exist."""
        # Arrange
        mock_session = MagicMock(spec=Session)
        repository._session_maker.return_value.__enter__.return_value = mock_session
        mock_session.execute.return_value.scalars.return_value.all.return_value = []

        # Act
        result = repository.get_executions_by_workflow_run(
            tenant_id="tenant-123",
            app_id="app-456",
            workflow_run_id="run-101",
        )

        # Assert
        assert result == []
        mock_session.execute.assert_called_once()

    def test_get_execution_by_id_found(self, repository, mock_execution):
        """Test getting execution by ID when it exists."""
        # Arrange
        mock_session = MagicMock(spec=Session)
        repository._session_maker.return_value.__enter__.return_value = mock_session
        mock_session.scalar.return_value = mock_execution

        # Act
        result = repository.get_execution_by_id(mock_execution.id)

        # Assert
        assert result == mock_execution
        mock_session.scalar.assert_called_once()

    def test_get_execution_by_id_not_found(self, repository):
        """Test getting execution by ID when it doesn't exist."""
        # Arrange
        mock_session = MagicMock(spec=Session)
        repository._session_maker.return_value.__enter__.return_value = mock_session
        mock_session.scalar.return_value = None

        # Act
        result = repository.get_execution_by_id("non-existent-id")

        # Assert
        assert result is None
        mock_session.scalar.assert_called_once()

    def test_repository_implements_protocol(self, repository):
        """Test that the repository implements the required protocol methods."""
        # Verify all protocol methods are implemented
        assert hasattr(repository, "get_node_last_execution")
        assert hasattr(repository, "get_executions_by_workflow_run")
        assert hasattr(repository, "get_execution_by_id")

        # Verify methods are callable
        assert callable(repository.get_node_last_execution)
        assert callable(repository.get_executions_by_workflow_run)
        assert callable(repository.get_execution_by_id)
        assert callable(repository.delete_expired_executions)
        assert callable(repository.delete_executions_by_app)
        assert callable(repository.get_expired_executions_batch)
        assert callable(repository.delete_executions_by_ids)

    def test_delete_expired_executions(self, repository):
        """Test deleting expired executions."""
        # Arrange
        mock_session = MagicMock(spec=Session)
        repository._session_maker.return_value.__enter__.return_value = mock_session

        # Mock the select query to return some IDs first time, then empty to stop loop
        execution_ids = ["id1", "id2"]  # Less than batch_size to trigger break

        # Mock execute method to handle both select and delete statements
        def mock_execute(stmt):
            mock_result = MagicMock()
            # For select statements, return execution IDs
            if hasattr(stmt, "limit"):  # This is our select statement
                mock_result.scalars.return_value.all.return_value = execution_ids
            else:  # This is our delete statement
                mock_result.rowcount = 2
            return mock_result

        mock_session.execute.side_effect = mock_execute

        before_date = datetime(2023, 1, 1)

        # Act
        result = repository.delete_expired_executions(
            tenant_id="tenant-123",
            before_date=before_date,
            batch_size=1000,
        )

        # Assert
        assert result == 2
        assert mock_session.execute.call_count == 2  # One select call, one delete call
        mock_session.commit.assert_called_once()

    def test_delete_executions_by_app(self, repository):
        """Test deleting executions by app."""
        # Arrange
        mock_session = MagicMock(spec=Session)
        repository._session_maker.return_value.__enter__.return_value = mock_session

        # Mock the select query to return some IDs first time, then empty to stop loop
        execution_ids = ["id1", "id2"]

        # Mock execute method to handle both select and delete statements
        def mock_execute(stmt):
            mock_result = MagicMock()
            # For select statements, return execution IDs
            if hasattr(stmt, "limit"):  # This is our select statement
                mock_result.scalars.return_value.all.return_value = execution_ids
            else:  # This is our delete statement
                mock_result.rowcount = 2
            return mock_result

        mock_session.execute.side_effect = mock_execute

        # Act
        result = repository.delete_executions_by_app(
            tenant_id="tenant-123",
            app_id="app-456",
            batch_size=1000,
        )

        # Assert
        assert result == 2
        assert mock_session.execute.call_count == 2  # One select call, one delete call
        mock_session.commit.assert_called_once()

    def test_get_expired_executions_batch(self, repository):
        """Test getting expired executions batch for backup."""
        # Arrange
        mock_session = MagicMock(spec=Session)
        repository._session_maker.return_value.__enter__.return_value = mock_session

        # Create mock execution objects
        mock_execution1 = MagicMock()
        mock_execution1.id = "exec-1"
        mock_execution2 = MagicMock()
        mock_execution2.id = "exec-2"

        mock_session.execute.return_value.scalars.return_value.all.return_value = [mock_execution1, mock_execution2]

        before_date = datetime(2023, 1, 1)

        # Act
        result = repository.get_expired_executions_batch(
            tenant_id="tenant-123",
            before_date=before_date,
            batch_size=1000,
        )

        # Assert
        assert len(result) == 2
        assert result[0].id == "exec-1"
        assert result[1].id == "exec-2"
        mock_session.execute.assert_called_once()

    def test_delete_executions_by_ids(self, repository):
        """Test deleting executions by IDs."""
        # Arrange
        mock_session = MagicMock(spec=Session)
        repository._session_maker.return_value.__enter__.return_value = mock_session

        # Mock the delete query result
        mock_result = MagicMock()
        mock_result.rowcount = 3
        mock_session.execute.return_value = mock_result

        execution_ids = ["id1", "id2", "id3"]

        # Act
        result = repository.delete_executions_by_ids(execution_ids)

        # Assert
        assert result == 3
        mock_session.execute.assert_called_once()
        mock_session.commit.assert_called_once()

    def test_delete_executions_by_ids_empty_list(self, repository):
        """Test deleting executions with empty ID list."""
        # Arrange
        mock_session = MagicMock(spec=Session)
        repository._session_maker.return_value.__enter__.return_value = mock_session

        # Act
        result = repository.delete_executions_by_ids([])

        # Assert
        assert result == 0
        mock_session.query.assert_not_called()
        mock_session.commit.assert_not_called()
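The two batch-deletion tests above encode a select-then-delete loop: fetch up to batch_size matching IDs, remove them in a single delete statement, and stop once a batch comes back short. A rough sketch of that shape, assuming SQLAlchemy 2.x select()/delete() and the WorkflowNodeExecutionModel columns used elsewhere in this diff (simplified relative to the real repository method):

from datetime import datetime

from sqlalchemy import delete, select
from sqlalchemy.orm import sessionmaker

from models.workflow import WorkflowNodeExecutionModel


def delete_expired_executions(
    session_maker: sessionmaker, tenant_id: str, before_date: datetime, batch_size: int = 1000
) -> int:
    total = 0
    with session_maker() as session:
        while True:
            # Select one batch of expired execution IDs.
            ids = session.execute(
                select(WorkflowNodeExecutionModel.id)
                .where(
                    WorkflowNodeExecutionModel.tenant_id == tenant_id,
                    WorkflowNodeExecutionModel.created_at < before_date,
                )
                .limit(batch_size)
            ).scalars().all()
            if not ids:
                break
            # Delete the whole batch in a single statement.
            result = session.execute(
                delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(ids))
            )
            total += result.rowcount
            if len(ids) < batch_size:
                break
        session.commit()
    return total

Batching keeps each delete statement bounded, so expiry cleanup never holds locks for the duration of a full-table delete; the tests above exercise exactly one iteration by returning fewer IDs than batch_size.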
@@ -10,7 +10,8 @@ from services.workflow_service import WorkflowService

class TestWorkflowService:
    @pytest.fixture
    def workflow_service(self):
        return WorkflowService()
        mock_session_maker = MagicMock()
        return WorkflowService(mock_session_maker)

    @pytest.fixture
    def mock_app(self):
@@ -799,6 +799,19 @@ WORKFLOW_FILE_UPLOAD_LIMIT=10
# hybrid: Save new data to object storage, read from both object storage and RDBMS
WORKFLOW_NODE_EXECUTION_STORAGE=rdbms

# Repository configuration
# Core workflow execution repository implementation
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository

# Core workflow node execution repository implementation
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository

# API workflow node execution repository implementation
API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository

# API workflow run repository implementation
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository

# HTTP request node in workflow configuration
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
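Each of these settings is a dotted import path whose final segment is the class name; the factory resolves the string to a class at startup before instantiating it. A minimal sketch of that resolution step (the helper name is illustrative, and the real factory additionally validates the interface and constructor signature, as the tests earlier in this diff show):

import importlib


def import_repository_class(class_path: str) -> type:
    # "package.module.ClassName" -> ("package.module", "ClassName")
    module_path, _, class_name = class_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

Any of the four variables can therefore point at a custom implementation, provided its module is importable by the API service.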
@@ -354,6 +354,10 @@ x-shared-env: &shared-api-worker-env
  WORKFLOW_PARALLEL_DEPTH_LIMIT: ${WORKFLOW_PARALLEL_DEPTH_LIMIT:-3}
  WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10}
  WORKFLOW_NODE_EXECUTION_STORAGE: ${WORKFLOW_NODE_EXECUTION_STORAGE:-rdbms}
  CORE_WORKFLOW_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository}
  CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository}
  API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository}
  API_WORKFLOW_RUN_REPOSITORY: ${API_WORKFLOW_RUN_REPOSITORY:-repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository}
  HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
  HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
  HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True}
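Because each compose entry falls back to the stock implementation, swapping one in a deployment only requires setting the variable in .env; a hypothetical override (module and class names here are placeholders, not shipped code):

API_WORKFLOW_RUN_REPOSITORY=my_extensions.repositories.CustomWorkflowRunRepository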