refactor(api): Decouple ParameterExtractorNode from LLMNode (#20843)
- Extract methods used by `ParameterExtractorNode` from `LLMNode` into a separate file.
- Convert `ParameterExtractorNode` into a subclass of `BaseNode`.
- Refactor code referencing the extracted methods to ensure functionality and clarity.
- Fix the issue that `ParameterExtractorNode` returns an error when executed.
- Fix relevant test cases.

Closes #20840.
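As a quick orientation for reviewers, here is a minimal sketch of the new call pattern. The function names and signatures come from the diff below; the wrapper function and its arguments (`tenant_id`, `app_id`, `node_data`, the variable pool) are illustrative and assumed to be supplied by the calling node.

```python
from core.workflow.nodes.llm import llm_utils


def run_with_llm_utils(tenant_id, app_id, node_data, variable_pool):
    # Model resolution is now a module-level helper instead of an LLMNode method.
    model_instance, model_config = llm_utils.fetch_model_config(
        tenant_id=tenant_id, node_data_model=node_data.model
    )
    # File and memory helpers take the variable pool explicitly rather than
    # reading it from a node instance, so non-LLM nodes can call them too.
    files = llm_utils.fetch_files(
        variable_pool=variable_pool,
        selector=node_data.vision.configs.variable_selector,
    )
    memory = llm_utils.fetch_memory(
        variable_pool=variable_pool,
        app_id=app_id,
        node_data_memory=node_data.memory,
        model_instance=model_instance,
    )
    return model_instance, model_config, files, memory
```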
@@ -21,7 +21,7 @@ from core.plugin.entities.request import (
 )
 from core.tools.entities.tool_entities import ToolProviderType
 from core.tools.utils.model_invocation_utils import ModelInvocationUtils
-from core.workflow.nodes.llm.node import LLMNode
+from core.workflow.nodes.llm import llm_utils
 from models.account import Tenant
 
 
@@ -55,7 +55,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
             def handle() -> Generator[LLMResultChunk, None, None]:
                 for chunk in response:
                     if chunk.delta.usage:
-                        LLMNode.deduct_llm_quota(
+                        llm_utils.deduct_llm_quota(
                             tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
                         )
                     chunk.prompt_messages = []
@@ -64,7 +64,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
             return handle()
         else:
             if response.usage:
-                LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
+                llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
 
             def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
                 yield LLMResultChunk(
@@ -9,7 +9,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
 from core.rag.retrieval.output_parser.react_output import ReactAction
 from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
-from core.workflow.nodes.llm import LLMNode
+from core.workflow.nodes.llm import llm_utils
 
 PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
 
@@ -165,7 +165,7 @@ class ReactMultiDatasetRouter:
         text, usage = self._handle_invoke_result(invoke_result=invoke_result)
 
         # deduct quota
-        LLMNode.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
+        llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
 
         return text, usage
 
api/core/workflow/nodes/llm/llm_utils.py (new file, 156 lines)
@@ -0,0 +1,156 @@
+from collections.abc import Sequence
+from datetime import UTC, datetime
+from typing import Optional, cast
+
+from sqlalchemy import select, update
+from sqlalchemy.orm import Session
+
+from configs import dify_config
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.entities.provider_entities import QuotaUnit
+from core.file.models import File
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_manager import ModelInstance, ModelManager
+from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.plugin.entities.plugin import ModelProviderID
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
+from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariableKey
+from core.workflow.nodes.llm.entities import ModelConfig
+from models import db
+from models.model import Conversation
+from models.provider import Provider, ProviderType
+
+from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError
+
+
+def fetch_model_config(
+    tenant_id: str, node_data_model: ModelConfig
+) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+    if not node_data_model.mode:
+        raise LLMModeRequiredError("LLM mode is required.")
+
+    model = ModelManager().get_model_instance(
+        tenant_id=tenant_id,
+        model_type=ModelType.LLM,
+        provider=node_data_model.provider,
+        model=node_data_model.name,
+    )
+
+    model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
+
+    # check model
+    provider_model = model.provider_model_bundle.configuration.get_provider_model(
+        model=node_data_model.name, model_type=ModelType.LLM
+    )
+
+    if provider_model is None:
+        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
+    provider_model.raise_for_status()
+
+    # model config
+    stop: list[str] = []
+    if "stop" in node_data_model.completion_params:
+        stop = node_data_model.completion_params.pop("stop")
+
+    model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
+    if not model_schema:
+        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
+
+    return model, ModelConfigWithCredentialsEntity(
+        provider=node_data_model.provider,
+        model=node_data_model.name,
+        model_schema=model_schema,
+        mode=node_data_model.mode,
+        provider_model_bundle=model.provider_model_bundle,
+        credentials=model.credentials,
+        parameters=node_data_model.completion_params,
+        stop=stop,
+    )
+
+
+def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:
+    variable = variable_pool.get(selector)
+    if variable is None:
+        return []
+    elif isinstance(variable, FileSegment):
+        return [variable.value]
+    elif isinstance(variable, ArrayFileSegment):
+        return variable.value
+    elif isinstance(variable, NoneSegment | ArrayAnySegment):
+        return []
+    raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
+
+
+def fetch_memory(
+    variable_pool: VariablePool, app_id: str, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
+) -> Optional[TokenBufferMemory]:
+    if not node_data_memory:
+        return None
+
+    # get conversation id
+    conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID.value])
+    if not isinstance(conversation_id_variable, StringSegment):
+        return None
+    conversation_id = conversation_id_variable.value
+
+    with Session(db.engine, expire_on_commit=False) as session:
+        stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
+        conversation = session.scalar(stmt)
+        if not conversation:
+            return None
+
+    memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
+    return memory
+
+
+def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
+    provider_model_bundle = model_instance.provider_model_bundle
+    provider_configuration = provider_model_bundle.configuration
+
+    if provider_configuration.using_provider_type != ProviderType.SYSTEM:
+        return
+
+    system_configuration = provider_configuration.system_configuration
+
+    quota_unit = None
+    for quota_configuration in system_configuration.quota_configurations:
+        if quota_configuration.quota_type == system_configuration.current_quota_type:
+            quota_unit = quota_configuration.quota_unit
+
+            if quota_configuration.quota_limit == -1:
+                return
+
+            break
+
+    used_quota = None
+    if quota_unit:
+        if quota_unit == QuotaUnit.TOKENS:
+            used_quota = usage.total_tokens
+        elif quota_unit == QuotaUnit.CREDITS:
+            used_quota = dify_config.get_model_credits(model_instance.model)
+        else:
+            used_quota = 1
+
+    if used_quota is not None and system_configuration.current_quota_type is not None:
+        with Session(db.engine) as session:
+            stmt = (
+                update(Provider)
+                .where(
+                    Provider.tenant_id == tenant_id,
+                    # TODO: Use provider name with prefix after the data migration.
+                    Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
+                    Provider.provider_type == ProviderType.SYSTEM.value,
+                    Provider.quota_type == system_configuration.current_quota_type.value,
+                    Provider.quota_limit > Provider.quota_used,
+                )
+                .values(
+                    quota_used=Provider.quota_used + used_quota,
+                    last_used=datetime.now(tz=UTC).replace(tzinfo=None),
+                )
+            )
+            session.execute(stmt)
+            session.commit()
@@ -3,16 +3,11 @@ import io
 import json
 import logging
 from collections.abc import Generator, Mapping, Sequence
-from datetime import UTC, datetime
 from typing import TYPE_CHECKING, Any, Optional, cast
 
 import json_repair
-from sqlalchemy import select, update
-from sqlalchemy.orm import Session
 
-from configs import dify_config
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.entities.provider_entities import QuotaUnit
 from core.file import FileType, file_manager
 from core.helper.code_executor import CodeExecutor, CodeLanguage
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -40,12 +35,10 @@ from core.model_runtime.entities.model_entities import (
 )
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.plugin.entities.plugin import ModelProviderID
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
 from core.variables import (
-    ArrayAnySegment,
     ArrayFileSegment,
     ArraySegment,
     FileSegment,
@@ -75,10 +68,8 @@ from core.workflow.utils.structured_output.entities import (
 )
 from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
-from extensions.ext_database import db
-from models.model import Conversation
-from models.provider import Provider, ProviderType
 
+from . import llm_utils
 from .entities import (
     LLMNodeChatModelMessage,
     LLMNodeCompletionModelPromptTemplate,
@@ -88,7 +79,6 @@ from .entities import (
 from .exc import (
     InvalidContextStructureError,
     InvalidVariableTypeError,
-    LLMModeRequiredError,
     LLMNodeError,
     MemoryRolePrefixRequiredError,
     ModelNotExistError,
@@ -160,6 +150,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         result_text = ""
         usage = LLMUsage.empty_usage()
         finish_reason = None
+        variable_pool = self.graph_runtime_state.variable_pool
 
         try:
             # init messages template
@@ -178,7 +169,10 @@ class LLMNode(BaseNode[LLMNodeData]):
 
             # fetch files
             files = (
-                self._fetch_files(selector=self.node_data.vision.configs.variable_selector)
+                llm_utils.fetch_files(
+                    variable_pool=variable_pool,
+                    selector=self.node_data.vision.configs.variable_selector,
+                )
                 if self.node_data.vision.enabled
                 else []
             )
@@ -200,15 +194,18 @@ class LLMNode(BaseNode[LLMNodeData]):
             model_instance, model_config = self._fetch_model_config(self.node_data.model)
 
             # fetch memory
-            memory = self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance)
+            memory = llm_utils.fetch_memory(
+                variable_pool=variable_pool,
+                app_id=self.app_id,
+                node_data_memory=self.node_data.memory,
+                model_instance=model_instance,
+            )
 
             query = None
             if self.node_data.memory:
                 query = self.node_data.memory.query_prompt_template
                 if not query and (
-                    query_variable := self.graph_runtime_state.variable_pool.get(
-                        (SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)
-                    )
+                    query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
                 ):
                     query = query_variable.text
 
@@ -222,7 +219,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                 memory_config=self.node_data.memory,
                 vision_enabled=self.node_data.vision.enabled,
                 vision_detail=self.node_data.vision.configs.detail,
-                variable_pool=self.graph_runtime_state.variable_pool,
+                variable_pool=variable_pool,
                 jinja2_variables=self.node_data.prompt_config.jinja2_variables,
             )
 
@@ -251,7 +248,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                     usage = event.usage
                     finish_reason = event.finish_reason
                     # deduct quota
-                    self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
+                    llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
                     break
             outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
             structured_output = process_structured_output(result_text)
@@ -447,18 +444,6 @@ class LLMNode(BaseNode[LLMNodeData]):
 
         return inputs
 
-    def _fetch_files(self, *, selector: Sequence[str]) -> Sequence["File"]:
-        variable = self.graph_runtime_state.variable_pool.get(selector)
-        if variable is None:
-            return []
-        elif isinstance(variable, FileSegment):
-            return [variable.value]
-        elif isinstance(variable, ArrayFileSegment):
-            return variable.value
-        elif isinstance(variable, NoneSegment | ArrayAnySegment):
-            return []
-        raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
-
     def _fetch_context(self, node_data: LLMNodeData):
         if not node_data.context.enabled:
             return
@@ -524,31 +509,10 @@ class LLMNode(BaseNode[LLMNodeData]):
     def _fetch_model_config(
         self, node_data_model: ModelConfig
     ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
-        if not node_data_model.mode:
-            raise LLMModeRequiredError("LLM mode is required.")
-
-        model = ModelManager().get_model_instance(
-            tenant_id=self.tenant_id,
-            model_type=ModelType.LLM,
-            provider=node_data_model.provider,
-            model=node_data_model.name,
-        )
-
-        model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
-
-        # check model
-        provider_model = model.provider_model_bundle.configuration.get_provider_model(
-            model=node_data_model.name, model_type=ModelType.LLM
-        )
-
-        if provider_model is None:
-            raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
-        provider_model.raise_for_status()
-
-        # model config
-        stop: list[str] = []
-        if "stop" in node_data_model.completion_params:
-            stop = node_data_model.completion_params.pop("stop")
+        model, model_config_with_cred = llm_utils.fetch_model_config(
+            tenant_id=self.tenant_id, node_data_model=node_data_model
+        )
+        completion_params = model_config_with_cred.parameters
 
         model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
         if not model_schema:
@@ -556,47 +520,12 @@ class LLMNode(BaseNode[LLMNodeData]):
 
         if self.node_data.structured_output_enabled:
             if model_schema.support_structure_output:
-                node_data_model.completion_params = self._handle_native_json_schema(
-                    node_data_model.completion_params, model_schema.parameter_rules
-                )
+                completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
             else:
                 # Set appropriate response format based on model capabilities
-                self._set_response_format(node_data_model.completion_params, model_schema.parameter_rules)
-
-        return model, ModelConfigWithCredentialsEntity(
-            provider=node_data_model.provider,
-            model=node_data_model.name,
-            model_schema=model_schema,
-            mode=node_data_model.mode,
-            provider_model_bundle=model.provider_model_bundle,
-            credentials=model.credentials,
-            parameters=node_data_model.completion_params,
-            stop=stop,
-        )
-
-    def _fetch_memory(
-        self, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
-    ) -> Optional[TokenBufferMemory]:
-        if not node_data_memory:
-            return None
-
-        # get conversation id
-        conversation_id_variable = self.graph_runtime_state.variable_pool.get(
-            ["sys", SystemVariableKey.CONVERSATION_ID.value]
-        )
-        if not isinstance(conversation_id_variable, StringSegment):
-            return None
-        conversation_id = conversation_id_variable.value
-
-        with Session(db.engine, expire_on_commit=False) as session:
-            stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
-            conversation = session.scalar(stmt)
-            if not conversation:
-                return None
-
-        memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
-
-        return memory
+                self._set_response_format(completion_params, model_schema.parameter_rules)
+        model_config_with_cred.parameters = completion_params
+        return model, model_config_with_cred
 
     def _fetch_prompt_messages(
         self,
@@ -810,55 +739,6 @@ class LLMNode(BaseNode[LLMNodeData]):
             structured_output = parsed
         return structured_output
 
-    @classmethod
-    def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
-        provider_model_bundle = model_instance.provider_model_bundle
-        provider_configuration = provider_model_bundle.configuration
-
-        if provider_configuration.using_provider_type != ProviderType.SYSTEM:
-            return
-
-        system_configuration = provider_configuration.system_configuration
-
-        quota_unit = None
-        for quota_configuration in system_configuration.quota_configurations:
-            if quota_configuration.quota_type == system_configuration.current_quota_type:
-                quota_unit = quota_configuration.quota_unit
-
-                if quota_configuration.quota_limit == -1:
-                    return
-
-                break
-
-        used_quota = None
-        if quota_unit:
-            if quota_unit == QuotaUnit.TOKENS:
-                used_quota = usage.total_tokens
-            elif quota_unit == QuotaUnit.CREDITS:
-                used_quota = dify_config.get_model_credits(model_instance.model)
-            else:
-                used_quota = 1
-
-        if used_quota is not None and system_configuration.current_quota_type is not None:
-            with Session(db.engine) as session:
-                stmt = (
-                    update(Provider)
-                    .where(
-                        Provider.tenant_id == tenant_id,
-                        # TODO: Use provider name with prefix after the data migration.
-                        Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
-                        Provider.provider_type == ProviderType.SYSTEM.value,
-                        Provider.quota_type == system_configuration.current_quota_type.value,
-                        Provider.quota_limit > Provider.quota_used,
-                    )
-                    .values(
-                        quota_used=Provider.quota_used + used_quota,
-                        last_used=datetime.now(tz=UTC).replace(tzinfo=None),
-                    )
-                )
-                session.execute(stmt)
-                session.commit()
-
     @classmethod
     def _extract_variable_selector_to_variable_mapping(
         cls,
@@ -28,8 +28,9 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
+from core.workflow.nodes.base.node import BaseNode
 from core.workflow.nodes.enums import NodeType
-from core.workflow.nodes.llm import LLMNode, ModelConfig
+from core.workflow.nodes.llm import ModelConfig, llm_utils
 from core.workflow.utils import variable_template_parser
 
 from .entities import ParameterExtractorNodeData
@@ -83,7 +84,7 @@ def extract_json(text):
     return None
 
 
-class ParameterExtractorNode(LLMNode):
+class ParameterExtractorNode(BaseNode):
     """
     Parameter Extractor Node.
     """
@@ -116,8 +117,11 @@ class ParameterExtractorNode(LLMNode):
         variable = self.graph_runtime_state.variable_pool.get(node_data.query)
         query = variable.text if variable else ""
 
+        variable_pool = self.graph_runtime_state.variable_pool
+
         files = (
-            self._fetch_files(
+            llm_utils.fetch_files(
+                variable_pool=variable_pool,
                 selector=node_data.vision.configs.variable_selector,
             )
             if node_data.vision.enabled
@@ -137,7 +141,9 @@ class ParameterExtractorNode(LLMNode):
             raise ModelSchemaNotFoundError("Model schema not found")
 
         # fetch memory
-        memory = self._fetch_memory(
+        memory = llm_utils.fetch_memory(
+            variable_pool=variable_pool,
+            app_id=self.app_id,
             node_data_memory=node_data.memory,
             model_instance=model_instance,
         )
@@ -279,7 +285,7 @@ class ParameterExtractorNode(LLMNode):
         tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
 
         # deduct quota
-        self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
+        llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
 
         if text is None:
             text = ""
@@ -794,7 +800,9 @@ class ParameterExtractorNode(LLMNode):
         Fetch model config.
         """
         if not self._model_instance or not self._model_config:
-            self._model_instance, self._model_config = super()._fetch_model_config(node_data_model)
+            self._model_instance, self._model_config = llm_utils.fetch_model_config(
+                tenant_id=self.tenant_id, node_data_model=node_data_model
+            )
 
         return self._model_instance, self._model_config
 
@@ -18,6 +18,7 @@ from core.workflow.nodes.llm import (
     LLMNode,
     LLMNodeChatModelMessage,
     LLMNodeCompletionModelPromptTemplate,
+    llm_utils,
 )
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from libs.json_in_md_parser import parse_and_check_json_markdown
@@ -50,7 +51,9 @@ class QuestionClassifierNode(LLMNode):
         # fetch model config
         model_instance, model_config = self._fetch_model_config(node_data.model)
         # fetch memory
-        memory = self._fetch_memory(
+        memory = llm_utils.fetch_memory(
+            variable_pool=variable_pool,
+            app_id=self.app_id,
             node_data_memory=node_data.memory,
             model_instance=model_instance,
         )
@@ -59,7 +62,8 @@ class QuestionClassifierNode(LLMNode):
         node_data.instruction = variable_pool.convert_template(node_data.instruction).text
 
         files = (
-            self._fetch_files(
+            llm_utils.fetch_files(
+                variable_pool=variable_pool,
                 selector=node_data.vision.configs.variable_selector,
             )
             if node_data.vision.enabled
@@ -353,7 +353,7 @@ def test_extract_json_from_tool_call():
     assert result["location"] == "kawaii"
 
 
-def test_chat_parameter_extractor_with_memory(setup_model_mock):
+def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
     """
     Test chat parameter extractor with memory.
     """
@@ -384,7 +384,8 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock):
         mode="chat",
         credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
     )
-    node._fetch_memory = get_mocked_fetch_memory("customized memory")
+    # Test the mock before running the actual test
+    monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory"))
    db.session.close = MagicMock()
 
    result = node._run()
@@ -25,6 +25,7 @@ from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
 from core.workflow.nodes.answer import AnswerStreamGenerateRoute
 from core.workflow.nodes.end import EndStreamParam
+from core.workflow.nodes.llm import llm_utils
 from core.workflow.nodes.llm.entities import (
     ContextConfig,
     LLMNodeChatModelMessage,
@@ -170,7 +171,7 @@ def model_config():
     )
 
 
-def test_fetch_files_with_file_segment(llm_node):
+def test_fetch_files_with_file_segment():
     file = File(
         id="1",
         tenant_id="test",
@@ -180,13 +181,14 @@ def test_fetch_files_with_file_segment(llm_node):
         related_id="1",
         storage_key="",
     )
-    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
+    variable_pool = VariablePool()
+    variable_pool.add(["sys", "files"], file)
 
-    result = llm_node._fetch_files(selector=["sys", "files"])
+    result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
     assert result == [file]
 
 
-def test_fetch_files_with_array_file_segment(llm_node):
+def test_fetch_files_with_array_file_segment():
     files = [
         File(
             id="1",
@@ -207,28 +209,32 @@ def test_fetch_files_with_array_file_segment(llm_node):
             storage_key="",
         ),
     ]
-    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
+    variable_pool = VariablePool()
+    variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
 
-    result = llm_node._fetch_files(selector=["sys", "files"])
+    result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
     assert result == files
 
 
-def test_fetch_files_with_none_segment(llm_node):
-    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
+def test_fetch_files_with_none_segment():
+    variable_pool = VariablePool()
+    variable_pool.add(["sys", "files"], NoneSegment())
 
-    result = llm_node._fetch_files(selector=["sys", "files"])
+    result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
     assert result == []
 
 
-def test_fetch_files_with_array_any_segment(llm_node):
-    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
+def test_fetch_files_with_array_any_segment():
+    variable_pool = VariablePool()
+    variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
 
-    result = llm_node._fetch_files(selector=["sys", "files"])
+    result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
     assert result == []
 
 
-def test_fetch_files_with_non_existent_variable(llm_node):
-    result = llm_node._fetch_files(selector=["sys", "files"])
+def test_fetch_files_with_non_existent_variable():
+    variable_pool = VariablePool()
+    result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
     assert result == []