@@ -1,32 +1,51 @@
+import json
 import logging
 import time
+from collections import defaultdict
 from collections.abc import Mapping, Sequence
-from typing import Any, cast
+from typing import Any, Optional, cast
 
-from sqlalchemy import func
+from sqlalchemy import Integer, and_, func, or_, text
+from sqlalchemy import cast as sqlalchemy_cast
 
 from core.app.app_config.entities import DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.agent_entities import PlanningStrategy
 from core.entities.model_entities import ModelStatus
 from core.model_manager import ModelInstance, ModelManager
+from core.model_runtime.entities.message_entities import PromptMessageRole
 from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.simple_prompt_transform import ModelMode
 from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.entities.metadata_entities import Condition, MetadataCondition
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.variables import StringSegment
 from core.workflow.entities.node_entities import NodeRunResult
-from core.workflow.nodes.base import BaseNode
 from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
+from core.workflow.nodes.knowledge_retrieval.template_prompts import (
+    METADATA_FILTER_ASSISTANT_PROMPT_1,
+    METADATA_FILTER_ASSISTANT_PROMPT_2,
+    METADATA_FILTER_COMPLETION_PROMPT,
+    METADATA_FILTER_SYSTEM_PROMPT,
+    METADATA_FILTER_USER_PROMPT_1,
+    METADATA_FILTER_USER_PROMPT_3,
+)
+from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
+from core.workflow.nodes.llm.node import LLMNode
+from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
-from models.dataset import Dataset, Document, RateLimitLog
+from libs.json_in_md_parser import parse_and_check_json_markdown
+from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
 from models.workflow import WorkflowNodeExecutionStatus
 from services.feature_service import FeatureService
 
-from .entities import KnowledgeRetrievalNodeData
+from .entities import KnowledgeRetrievalNodeData, ModelConfig
 from .exc import (
     InvalidModelTypeError,
     KnowledgeRetrievalNodeError,
     ModelCredentialsNotInitializedError,
     ModelNotExistError,
@@ -45,13 +64,14 @@ default_retrieval_model = {
 }
 
 
-class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
-    _node_data_cls = KnowledgeRetrievalNodeData
+class KnowledgeRetrievalNode(LLMNode):
+    _node_data_cls = KnowledgeRetrievalNodeData  # type: ignore
     _node_type = NodeType.KNOWLEDGE_RETRIEVAL
 
-    def _run(self) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:  # type: ignore
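+        # self.node_data is declared on the base class; cast so type checkers see this node's data class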
+        node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
         # extract variables
-        variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector)
+        variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector)
         if not isinstance(variable, StringSegment):
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
@@ -91,7 +111,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
 
         # retrieve knowledge
         try:
-            results = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
+            results = self._fetch_dataset_retriever(node_data=node_data, query=query)
             outputs = {"result": results}
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
@@ -145,11 +165,14 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
             if not dataset:
                 continue
             available_datasets.append(dataset)
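+        # resolve metadata filters once for all selected datasets; both retrieval modes pass the result through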
+        metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
+            [dataset.id for dataset in available_datasets], query, node_data
+        )
         all_documents = []
         dataset_retrieval = DatasetRetrieval()
         if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
             # fetch model config
-            model_instance, model_config = self._fetch_model_config(node_data)
+            model_instance, model_config = self._fetch_model_config(node_data.single_retrieval_config.model)  # type: ignore
             # check model is support tool calling
             model_type_instance = model_config.provider_model_bundle.model_type_instance
             model_type_instance = cast(LargeLanguageModel, model_type_instance)
@@ -174,6 +197,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
                 model_config=model_config,
                 model_instance=model_instance,
                 planning_strategy=planning_strategy,
+                metadata_filter_document_ids=metadata_filter_document_ids,
+                metadata_condition=metadata_condition,
             )
         elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
             if node_data.multiple_retrieval_config is None:
@@ -220,6 +245,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
                 reranking_model=reranking_model,
                 weights=weights,
                 reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
+                metadata_filter_document_ids=metadata_filter_document_ids,
+                metadata_condition=metadata_condition,
             )
         dify_documents = [item for item in all_documents if item.provider == "dify"]
         external_documents = [item for item in all_documents if item.provider == "external"]
@@ -287,13 +314,187 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
                 item["metadata"]["position"] = position
         return retrieval_resource_list
 
+    def _get_metadata_filter_condition(
+        self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
+    ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
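+        # returns (dataset_id -> matching document ids, structured condition for external providers);
+        # (None, None) means no metadata filtering is applied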
+        document_query = db.session.query(Document).filter(
+            Document.dataset_id.in_(dataset_ids),
+            Document.indexing_status == "completed",
+            Document.enabled == True,
+            Document.archived == False,
+        )
+        filters = []  # type: ignore
+        metadata_condition = None
+        if node_data.metadata_filtering_mode == "disabled":
+            return None, None
+        elif node_data.metadata_filtering_mode == "automatic":
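+            # let the configured LLM infer filter conditions from the user query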
+            automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data)
+            if automatic_metadata_filters:
+                conditions = []
+                for filter in automatic_metadata_filters:
+                    self._process_metadata_filter_func(
+                        filter.get("condition", ""),
+                        filter.get("metadata_name", ""),
+                        filter.get("value"),
+                        filters,  # type: ignore
+                    )
+                    conditions.append(
+                        Condition(
+                            name=filter.get("metadata_name"),  # type: ignore
+                            comparison_operator=filter.get("condition"),  # type: ignore
+                            value=filter.get("value"),
+                        )
+                    )
+                metadata_condition = MetadataCondition(
+                    logical_operator=node_data.metadata_filtering_conditions.logical_operator,  # type: ignore
+                    conditions=conditions,
+                )
+        elif node_data.metadata_filtering_mode == "manual":
+            if node_data.metadata_filtering_conditions:
+                metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump())
+            if node_data.metadata_filtering_conditions:
+                for condition in node_data.metadata_filtering_conditions.conditions:  # type: ignore
+                    metadata_name = condition.name
+                    expected_value = condition.value
+                    if expected_value or condition.comparison_operator in ("empty", "not empty"):
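+                        # a string value may contain variable templates; render it against the variable pool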
+                        if isinstance(expected_value, str):
+                            expected_value = self.graph_runtime_state.variable_pool.convert_template(
+                                expected_value
+                            ).text
+                        filters = self._process_metadata_filter_func(
+                            condition.comparison_operator, metadata_name, expected_value, filters
+                        )
+        else:
+            raise ValueError("Invalid metadata filtering mode")
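+        # combine the accumulated SQL predicates under the configured logical operator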
+        if filters:
+            if node_data.metadata_filtering_conditions.logical_operator == "and":  # type: ignore
+                document_query = document_query.filter(and_(*filters))
+            else:
+                document_query = document_query.filter(or_(*filters))
+        documents = document_query.all()
+        # group by dataset_id
+        metadata_filter_document_ids = defaultdict(list) if documents else None  # type: ignore
+        for document in documents:
+            metadata_filter_document_ids[document.dataset_id].append(document.id)  # type: ignore
+        return metadata_filter_document_ids, metadata_condition
+
+    def _automatic_metadata_filter_func(
+        self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
+    ) -> list[dict[str, Any]]:
+        # get all metadata fields
+        metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
+        all_metadata_fields = [metadata_field.field_name for metadata_field in metadata_fields]
+        # get metadata model config
+        metadata_model_config = node_data.metadata_model_config
+        if metadata_model_config is None:
+            raise ValueError("metadata_model_config is required")
+        # fetch the metadata model instance and its config
+        model_instance, model_config = self._fetch_model_config(node_data.metadata_model_config)  # type: ignore
+        # fetch prompt messages
+        prompt_template = self._get_prompt_template(
+            node_data=node_data,
+            metadata_fields=all_metadata_fields,
+            query=query or "",
+        )
+        prompt_messages, stop = self._fetch_prompt_messages(
+            prompt_template=prompt_template,
+            sys_query=query,
+            memory=None,
+            model_config=model_config,
+            sys_files=[],
+            vision_enabled=node_data.vision.enabled,
+            vision_detail=node_data.vision.configs.detail,
+            variable_pool=self.graph_runtime_state.variable_pool,
+            jinja2_variables=[],
+        )
+
+        result_text = ""
+        try:
+            # handle invoke result
+            generator = self._invoke_llm(
+                node_data_model=node_data.metadata_model_config,  # type: ignore
+                model_instance=model_instance,
+                prompt_messages=prompt_messages,
+                stop=stop,
+            )
+
+            for event in generator:
+                if isinstance(event, ModelInvokeCompletedEvent):
+                    result_text = event.text
+                    break
+
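+            # the model is expected to reply with a JSON markdown block carrying a "metadata_map" list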
+            result_text_json = parse_and_check_json_markdown(result_text, [])
+            automatic_metadata_filters = []
+            if "metadata_map" in result_text_json:
+                metadata_map = result_text_json["metadata_map"]
+                for item in metadata_map:
+                    if item.get("metadata_field_name") in all_metadata_fields:
+                        automatic_metadata_filters.append(
+                            {
+                                "metadata_name": item.get("metadata_field_name"),
+                                "value": item.get("metadata_field_value"),
+                                "condition": item.get("comparison_operator"),
+                            }
+                        )
+        except Exception:
+            return []
+        return automatic_metadata_filters
+
+    def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: Optional[str], filters: list):
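+        # build a SQLAlchemy predicate against the JSONB doc_metadata column: text operators
+        # use the ->> extractor, ordering comparisons cast the value to Integer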
+        match condition:
+            case "contains":
+                filters.append(
+                    (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}%")
+                )
+            case "not contains":
+                filters.append(
+                    (text("documents.doc_metadata ->> :key NOT LIKE :value")).params(
+                        key=metadata_name, value=f"%{value}%"
+                    )
+                )
+            case "start with":
+                filters.append(
+                    (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"{value}%")
+                )
+            case "end with":
+                filters.append(
+                    (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}")
+                )
+            case "=" | "is":
+                if isinstance(value, str):
+                    filters.append(Document.doc_metadata[metadata_name] == f'"{value}"')
+                else:
+                    filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) == value)
+            case "is not" | "≠":
+                if isinstance(value, str):
+                    filters.append(Document.doc_metadata[metadata_name] != f'"{value}"')
+                else:
+                    filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) != value)
+            case "empty":
+                filters.append(Document.doc_metadata[metadata_name].is_(None))
+            case "not empty":
+                filters.append(Document.doc_metadata[metadata_name].isnot(None))
+            case "before" | "<":
+                filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) < value)
+            case "after" | ">":
+                filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) > value)
+            case "≤" | "<=":
+                filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) <= value)
+            case "≥" | ">=":
+                filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) >= value)
+            case _:
+                pass
+        return filters
+
     @classmethod
     def _extract_variable_selector_to_variable_mapping(
         cls,
         *,
         graph_config: Mapping[str, Any],
         node_id: str,
-        node_data: KnowledgeRetrievalNodeData,
+        node_data: KnowledgeRetrievalNodeData,  # type: ignore
     ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
@@ -306,18 +507,16 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
         variable_mapping[node_id + ".query"] = node_data.query_variable_selector
         return variable_mapping
 
-    def _fetch_model_config(
-        self, node_data: KnowledgeRetrievalNodeData
-    ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+    def _fetch_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:  # type: ignore
         """
         Fetch model config
-        :param node_data: node data
+        :param model: model
         :return:
         """
-        if node_data.single_retrieval_config is None:
-            raise ValueError("single_retrieval_config is required")
-        model_name = node_data.single_retrieval_config.model.name
-        provider_name = node_data.single_retrieval_config.model.provider
+        if model is None:
+            raise ValueError("model is required")
+        model_name = model.name
+        provider_name = model.provider
 
         model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
@@ -346,14 +545,14 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
             raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.")
 
         # model config
-        completion_params = node_data.single_retrieval_config.model.completion_params
+        completion_params = model.completion_params
         stop = []
         if "stop" in completion_params:
             stop = completion_params["stop"]
             del completion_params["stop"]
 
         # get model mode
-        model_mode = node_data.single_retrieval_config.model.mode
+        model_mode = model.mode
         if not model_mode:
             raise ModelNotExistError("LLM mode is required.")
 
@@ -372,3 +571,50 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
             parameters=completion_params,
             stop=stop,
         )
+
+    def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
+        model_mode = ModelMode.value_of(node_data.metadata_model_config.mode)  # type: ignore
+        input_text = query
+        memory_str = ""
+
+        prompt_messages: list[LLMNodeChatModelMessage] = []
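+        # chat models get a few-shot transcript: two worked examples before the real query and field list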
+        if model_mode == ModelMode.CHAT:
+            system_prompt_messages = LLMNodeChatModelMessage(
+                role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT
+            )
+            prompt_messages.append(system_prompt_messages)
+            user_prompt_message_1 = LLMNodeChatModelMessage(
+                role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1
+            )
+            prompt_messages.append(user_prompt_message_1)
+            assistant_prompt_message_1 = LLMNodeChatModelMessage(
+                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
+            )
+            prompt_messages.append(assistant_prompt_message_1)
+            user_prompt_message_2 = LLMNodeChatModelMessage(
+                role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2
+            )
+            prompt_messages.append(user_prompt_message_2)
+            assistant_prompt_message_2 = LLMNodeChatModelMessage(
+                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
+            )
+            prompt_messages.append(assistant_prompt_message_2)
+            user_prompt_message_3 = LLMNodeChatModelMessage(
+                role=PromptMessageRole.USER,
+                text=METADATA_FILTER_USER_PROMPT_3.format(
+                    input_text=input_text,
+                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
+                ),
+            )
+            prompt_messages.append(user_prompt_message_3)
+            return prompt_messages
+        elif model_mode == ModelMode.COMPLETION:
+            return LLMNodeCompletionModelPromptTemplate(
+                text=METADATA_FILTER_COMPLETION_PROMPT.format(
+                    input_text=input_text,
+                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
+                )
+            )
+        else:
+            raise InvalidModelTypeError(f"Model mode {model_mode} not supported.")