Remove langchain dataset retrival agent logic (#3311)
This commit is contained in:
@@ -1,28 +1,21 @@
|
||||
import threading
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rerank.rerank import RerankRunner
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
|
||||
from core.workflow.nodes.knowledge_retrieval.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
|
||||
from core.workflow.nodes.knowledge_retrieval.multi_dataset_react_route import ReactMultiDatasetRouter
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
default_retrieval_model = {
|
||||
@@ -106,10 +99,45 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
|
||||
available_datasets.append(dataset)
|
||||
all_documents = []
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
|
||||
all_documents = self._single_retrieve(available_datasets, node_data, query)
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(node_data)
|
||||
# check model is support tool calling
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
# get model schema
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
model=model_config.model,
|
||||
credentials=model_config.credentials
|
||||
)
|
||||
|
||||
if model_schema:
|
||||
planning_strategy = PlanningStrategy.REACT_ROUTER
|
||||
features = model_schema.features
|
||||
if features:
|
||||
if ModelFeature.TOOL_CALL in features \
|
||||
or ModelFeature.MULTI_TOOL_CALL in features:
|
||||
planning_strategy = PlanningStrategy.ROUTER
|
||||
all_documents = dataset_retrieval.single_retrieve(
|
||||
available_datasets=available_datasets,
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
app_id=self.app_id,
|
||||
user_from=self.user_from.value,
|
||||
query=query,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
planning_strategy=planning_strategy
|
||||
)
|
||||
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
|
||||
all_documents = self._multiple_retrieve(available_datasets, node_data, query)
|
||||
all_documents = dataset_retrieval.multiple_retrieve(self.app_id, self.tenant_id, self.user_id,
|
||||
self.user_from.value,
|
||||
available_datasets, query,
|
||||
node_data.multiple_retrieval_config.top_k,
|
||||
node_data.multiple_retrieval_config.score_threshold,
|
||||
node_data.multiple_retrieval_config.reranking_model.provider,
|
||||
node_data.multiple_retrieval_config.reranking_model.model)
|
||||
|
||||
context_list = []
|
||||
if all_documents:
|
||||
@@ -184,87 +212,6 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
variable_mapping['query'] = node_data.query_variable_selector
|
||||
return variable_mapping
|
||||
|
||||
def _single_retrieve(self, available_datasets, node_data, query):
|
||||
tools = []
|
||||
for dataset in available_datasets:
|
||||
description = dataset.description
|
||||
if not description:
|
||||
description = 'useful for when you want to answer queries about the ' + dataset.name
|
||||
|
||||
description = description.replace('\n', '').replace('\r', '')
|
||||
message_tool = PromptMessageTool(
|
||||
name=dataset.id,
|
||||
description=description,
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
)
|
||||
tools.append(message_tool)
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(node_data)
|
||||
# check model is support tool calling
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
# get model schema
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
model=model_config.model,
|
||||
credentials=model_config.credentials
|
||||
)
|
||||
|
||||
if not model_schema:
|
||||
return None
|
||||
planning_strategy = PlanningStrategy.REACT_ROUTER
|
||||
features = model_schema.features
|
||||
if features:
|
||||
if ModelFeature.TOOL_CALL in features \
|
||||
or ModelFeature.MULTI_TOOL_CALL in features:
|
||||
planning_strategy = PlanningStrategy.ROUTER
|
||||
dataset_id = None
|
||||
if planning_strategy == PlanningStrategy.REACT_ROUTER:
|
||||
react_multi_dataset_router = ReactMultiDatasetRouter()
|
||||
dataset_id = react_multi_dataset_router.invoke(query, tools, node_data, model_config, model_instance,
|
||||
self.user_id, self.tenant_id)
|
||||
|
||||
elif planning_strategy == PlanningStrategy.ROUTER:
|
||||
function_call_router = FunctionCallMultiDatasetRouter()
|
||||
dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)
|
||||
if dataset_id:
|
||||
# get retrieval model config
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
if dataset:
|
||||
retrieval_model_config = dataset.retrieval_model \
|
||||
if dataset.retrieval_model else default_retrieval_model
|
||||
|
||||
# get top k
|
||||
top_k = retrieval_model_config['top_k']
|
||||
# get retrieval method
|
||||
if dataset.indexing_technique == "economy":
|
||||
retrival_method = 'keyword_search'
|
||||
else:
|
||||
retrival_method = retrieval_model_config['search_method']
|
||||
# get reranking model
|
||||
reranking_model=retrieval_model_config['reranking_model'] \
|
||||
if retrieval_model_config['reranking_enable'] else None
|
||||
# get score threshold
|
||||
score_threshold = .0
|
||||
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
|
||||
if score_threshold_enabled:
|
||||
score_threshold = retrieval_model_config.get("score_threshold")
|
||||
|
||||
results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k, score_threshold=score_threshold,
|
||||
reranking_model=reranking_model)
|
||||
self._on_query(query, [dataset_id])
|
||||
if results:
|
||||
self._on_retrival_end(results)
|
||||
return results
|
||||
return []
|
||||
|
||||
def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[
|
||||
ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
"""
|
||||
@@ -335,112 +282,3 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
parameters=completion_params,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
def _multiple_retrieve(self, available_datasets, node_data, query):
|
||||
threads = []
|
||||
all_documents = []
|
||||
dataset_ids = [dataset.id for dataset in available_datasets]
|
||||
for dataset in available_datasets:
|
||||
retrieval_thread = threading.Thread(target=self._retriever, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': dataset.id,
|
||||
'query': query,
|
||||
'top_k': node_data.multiple_retrieval_config.top_k,
|
||||
'all_documents': all_documents,
|
||||
})
|
||||
threads.append(retrieval_thread)
|
||||
retrieval_thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
# do rerank for searched documents
|
||||
model_manager = ModelManager()
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=node_data.multiple_retrieval_config.reranking_model.provider,
|
||||
model_type=ModelType.RERANK,
|
||||
model=node_data.multiple_retrieval_config.reranking_model.model
|
||||
)
|
||||
|
||||
rerank_runner = RerankRunner(rerank_model_instance)
|
||||
all_documents = rerank_runner.run(query, all_documents,
|
||||
node_data.multiple_retrieval_config.score_threshold,
|
||||
node_data.multiple_retrieval_config.top_k)
|
||||
self._on_query(query, dataset_ids)
|
||||
if all_documents:
|
||||
self._on_retrival_end(all_documents)
|
||||
return all_documents
|
||||
|
||||
def _on_retrival_end(self, documents: list[Document]) -> None:
|
||||
"""Handle retrival end."""
|
||||
for document in documents:
|
||||
query = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.index_node_id == document.metadata['doc_id']
|
||||
)
|
||||
|
||||
# if 'dataset_id' in document.metadata:
|
||||
if 'dataset_id' in document.metadata:
|
||||
query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
|
||||
|
||||
# add hit count to document segment
|
||||
query.update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
def _on_query(self, query: str, dataset_ids: list[str]) -> None:
|
||||
"""
|
||||
Handle query.
|
||||
"""
|
||||
if not query:
|
||||
return
|
||||
for dataset_id in dataset_ids:
|
||||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset_id,
|
||||
content=query,
|
||||
source='app',
|
||||
source_app_id=self.app_id,
|
||||
created_by_role=self.user_from.value,
|
||||
created_by=self.user_id
|
||||
)
|
||||
db.session.add(dataset_query)
|
||||
db.session.commit()
|
||||
|
||||
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
|
||||
with flask_app.app_context():
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == self.tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
|
||||
if not dataset:
|
||||
return []
|
||||
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
||||
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(retrival_method='keyword_search',
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k
|
||||
)
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
else:
|
||||
if top_k > 0:
|
||||
# retrieval source
|
||||
documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=retrieval_model['score_threshold']
|
||||
if retrieval_model['score_threshold_enabled'] else None,
|
||||
reranking_model=retrieval_model['reranking_model']
|
||||
if retrieval_model['reranking_enable'] else None
|
||||
)
|
||||
|
||||
all_documents.extend(documents)
|
||||
|
||||
|
@@ -1,47 +0,0 @@
|
||||
from typing import Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage
|
||||
|
||||
|
||||
class FunctionCallMultiDatasetRouter:
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
query: str,
|
||||
dataset_tools: list[PromptMessageTool],
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
model_instance: ModelInstance,
|
||||
|
||||
) -> Union[str, None]:
|
||||
"""Given input, decided what to do.
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
if len(dataset_tools) == 0:
|
||||
return None
|
||||
elif len(dataset_tools) == 1:
|
||||
return dataset_tools[0].name
|
||||
|
||||
try:
|
||||
prompt_messages = [
|
||||
SystemPromptMessage(content='You are a helpful AI assistant.'),
|
||||
UserPromptMessage(content=query)
|
||||
]
|
||||
result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
tools=dataset_tools,
|
||||
stream=False,
|
||||
model_parameters={
|
||||
'temperature': 0.2,
|
||||
'top_p': 0.3,
|
||||
'max_tokens': 1500
|
||||
}
|
||||
)
|
||||
if result.message.tool_calls:
|
||||
# get retrieval model config
|
||||
return result.message.tool_calls[0].function.name
|
||||
return None
|
||||
except Exception as e:
|
||||
return None
|
@@ -1,254 +0,0 @@
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Optional, Union
|
||||
|
||||
from langchain import PromptTemplate
|
||||
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
|
||||
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
||||
from langchain.schema import AgentAction
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage
|
||||
from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
|
||||
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||
|
||||
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
|
||||
Valid "action" values: "Final Answer" or {tool_names}
|
||||
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
```
|
||||
{{
|
||||
"action": $TOOL_NAME,
|
||||
"action_input": $INPUT
|
||||
}}
|
||||
```
|
||||
|
||||
Follow this format:
|
||||
|
||||
Question: input question to answer
|
||||
Thought: consider previous and subsequent steps
|
||||
Action:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
Observation: action result
|
||||
... (repeat Thought/Action/Observation N times)
|
||||
Thought: I know what to respond
|
||||
Action:
|
||||
```
|
||||
{{
|
||||
"action": "Final Answer",
|
||||
"action_input": "Final response to human"
|
||||
}}
|
||||
```"""
|
||||
|
||||
|
||||
class ReactMultiDatasetRouter:
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
query: str,
|
||||
dataset_tools: list[PromptMessageTool],
|
||||
node_data: KnowledgeRetrievalNodeData,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
model_instance: ModelInstance,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
|
||||
) -> Union[str, None]:
|
||||
"""Given input, decided what to do.
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
if len(dataset_tools) == 0:
|
||||
return None
|
||||
elif len(dataset_tools) == 1:
|
||||
return dataset_tools[0].name
|
||||
|
||||
try:
|
||||
return self._react_invoke(query=query, node_data=node_data, model_config=model_config, model_instance=model_instance,
|
||||
tools=dataset_tools, user_id=user_id, tenant_id=tenant_id)
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def _react_invoke(
|
||||
self,
|
||||
query: str,
|
||||
node_data: KnowledgeRetrievalNodeData,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
model_instance: ModelInstance,
|
||||
tools: Sequence[PromptMessageTool],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
prefix: str = PREFIX,
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
) -> Union[str, None]:
|
||||
if model_config.mode == "chat":
|
||||
prompt = self.create_chat_prompt(
|
||||
query=query,
|
||||
tools=tools,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
human_message_template=human_message_template,
|
||||
format_instructions=format_instructions,
|
||||
)
|
||||
else:
|
||||
prompt = self.create_completion_prompt(
|
||||
tools=tools,
|
||||
prefix=prefix,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=None
|
||||
)
|
||||
stop = ['Observation:']
|
||||
# handle invoke result
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt,
|
||||
inputs={},
|
||||
query='',
|
||||
files=[],
|
||||
context='',
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config
|
||||
)
|
||||
result_text, usage = self._invoke_llm(
|
||||
node_data=node_data,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
output_parser = StructuredChatOutputParser()
|
||||
agent_decision = output_parser.parse(result_text)
|
||||
if isinstance(agent_decision, AgentAction):
|
||||
return agent_decision.tool
|
||||
return None
|
||||
|
||||
def _invoke_llm(self, node_data: KnowledgeRetrievalNodeData,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: list[PromptMessage],
|
||||
stop: list[str], user_id: str, tenant_id: str) -> tuple[str, LLMUsage]:
|
||||
"""
|
||||
Invoke large language model
|
||||
:param node_data: node data
|
||||
:param model_instance: model instance
|
||||
:param prompt_messages: prompt messages
|
||||
:param stop: stop
|
||||
:return:
|
||||
"""
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=node_data.single_retrieval_config.model.completion_params,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
user=user_id,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
text, usage = self._handle_invoke_result(
|
||||
invoke_result=invoke_result
|
||||
)
|
||||
|
||||
# deduct quota
|
||||
LLMNode.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
|
||||
return text, usage
|
||||
|
||||
def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
|
||||
"""
|
||||
Handle invoke result
|
||||
:param invoke_result: invoke result
|
||||
:return:
|
||||
"""
|
||||
model = None
|
||||
prompt_messages = []
|
||||
full_text = ''
|
||||
usage = None
|
||||
for result in invoke_result:
|
||||
text = result.delta.message.content
|
||||
full_text += text
|
||||
|
||||
if not model:
|
||||
model = result.model
|
||||
|
||||
if not prompt_messages:
|
||||
prompt_messages = result.prompt_messages
|
||||
|
||||
if not usage and result.delta.usage:
|
||||
usage = result.delta.usage
|
||||
|
||||
if not usage:
|
||||
usage = LLMUsage.empty_usage()
|
||||
|
||||
return full_text, usage
|
||||
|
||||
def create_chat_prompt(
|
||||
self,
|
||||
query: str,
|
||||
tools: Sequence[PromptMessageTool],
|
||||
prefix: str = PREFIX,
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
) -> list[ChatModelMessage]:
|
||||
tool_strings = []
|
||||
for tool in tools:
|
||||
tool_strings.append(f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}")
|
||||
formatted_tools = "\n".join(tool_strings)
|
||||
unique_tool_names = set(tool.name for tool in tools)
|
||||
tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
|
||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
|
||||
prompt_messages = []
|
||||
system_prompt_messages = ChatModelMessage(
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
text=template
|
||||
)
|
||||
prompt_messages.append(system_prompt_messages)
|
||||
user_prompt_message = ChatModelMessage(
|
||||
role=PromptMessageRole.USER,
|
||||
text=query
|
||||
)
|
||||
prompt_messages.append(user_prompt_message)
|
||||
return prompt_messages
|
||||
|
||||
def create_completion_prompt(
|
||||
self,
|
||||
tools: Sequence[PromptMessageTool],
|
||||
prefix: str = PREFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[list[str]] = None,
|
||||
) -> PromptTemplate:
|
||||
"""Create prompt in the style of the zero shot agent.
|
||||
|
||||
Args:
|
||||
tools: List of tools the agent will have access to, used to format the
|
||||
prompt.
|
||||
prefix: String to put before the list of tools.
|
||||
input_variables: List of input variables the final prompt will expect.
|
||||
|
||||
Returns:
|
||||
A PromptTemplate with the template assembled from the pieces here.
|
||||
"""
|
||||
suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
Question: {input}
|
||||
Thought: {agent_scratchpad}
|
||||
"""
|
||||
|
||||
tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
|
||||
tool_names = ", ".join([tool.name for tool in tools])
|
||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
|
||||
if input_variables is None:
|
||||
input_variables = ["input", "agent_scratchpad"]
|
||||
return PromptTemplate(template=template, input_variables=input_variables)
|
Reference in New Issue
Block a user