Modify params for bedrock retrieve generate (#13182)
This commit is contained in:
@@ -1,114 +0,0 @@
|
||||
"""
|
||||
Configuration classes for AWS Bedrock retrieve and generate API
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextInferenceConfig:
|
||||
"""Text inference configuration"""
|
||||
|
||||
maxTokens: Optional[int] = None
|
||||
stopSequences: Optional[list[str]] = None
|
||||
temperature: Optional[float] = None
|
||||
topP: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerformanceConfig:
|
||||
"""Performance configuration"""
|
||||
|
||||
latency: Literal["standard", "optimized"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptTemplate:
|
||||
"""Prompt template configuration"""
|
||||
|
||||
textPromptTemplate: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class GuardrailConfig:
|
||||
"""Guardrail configuration"""
|
||||
|
||||
guardrailId: str
|
||||
guardrailVersion: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerationConfig:
|
||||
"""Generation configuration"""
|
||||
|
||||
additionalModelRequestFields: Optional[dict[str, Any]] = None
|
||||
guardrailConfiguration: Optional[GuardrailConfig] = None
|
||||
inferenceConfig: Optional[dict[str, TextInferenceConfig]] = None
|
||||
performanceConfig: Optional[PerformanceConfig] = None
|
||||
promptTemplate: Optional[PromptTemplate] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class VectorSearchConfig:
|
||||
"""Vector search configuration"""
|
||||
|
||||
filter: Optional[dict[str, Any]] = None
|
||||
numberOfResults: Optional[int] = None
|
||||
overrideSearchType: Optional[Literal["HYBRID", "SEMANTIC"]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalConfig:
|
||||
"""Retrieval configuration"""
|
||||
|
||||
vectorSearchConfiguration: VectorSearchConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrchestrationConfig:
|
||||
"""Orchestration configuration"""
|
||||
|
||||
additionalModelRequestFields: Optional[dict[str, Any]] = None
|
||||
inferenceConfig: Optional[dict[str, TextInferenceConfig]] = None
|
||||
performanceConfig: Optional[PerformanceConfig] = None
|
||||
promptTemplate: Optional[PromptTemplate] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class KnowledgeBaseConfig:
|
||||
"""Knowledge base configuration"""
|
||||
|
||||
generationConfiguration: GenerationConfig
|
||||
knowledgeBaseId: str
|
||||
modelArn: str
|
||||
orchestrationConfiguration: Optional[OrchestrationConfig] = None
|
||||
retrievalConfiguration: Optional[RetrievalConfig] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionConfig:
|
||||
"""Session configuration"""
|
||||
|
||||
kmsKeyArn: Optional[str] = None
|
||||
sessionId: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrieveAndGenerateConfiguration:
|
||||
"""Retrieve and generate configuration
|
||||
The use of knowledgeBaseConfiguration or externalSourcesConfiguration depends on the type value
|
||||
"""
|
||||
|
||||
type: str = "KNOWLEDGE_BASE"
|
||||
knowledgeBaseConfiguration: Optional[KnowledgeBaseConfig] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrieveAndGenerateConfig:
|
||||
"""Retrieve and generate main configuration"""
|
||||
|
||||
input: dict[str, str]
|
||||
retrieveAndGenerateConfiguration: RetrieveAndGenerateConfiguration
|
||||
sessionConfiguration: Optional[SessionConfig] = None
|
||||
sessionId: Optional[str] = None
|
@@ -77,15 +77,27 @@ class BedrockRetrieveTool(BuiltinTool):
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
line = 0
|
||||
try:
|
||||
line = 0
|
||||
# Initialize Bedrock client if not already initialized
|
||||
if not self.bedrock_client:
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
if aws_region:
|
||||
self.bedrock_client = boto3.client("bedrock-agent-runtime", region_name=aws_region)
|
||||
else:
|
||||
self.bedrock_client = boto3.client("bedrock-agent-runtime")
|
||||
aws_access_key_id = tool_parameters.get("aws_access_key_id")
|
||||
aws_secret_access_key = tool_parameters.get("aws_secret_access_key")
|
||||
|
||||
client_kwargs = {"service_name": "bedrock-agent-runtime", "region_name": aws_region or None}
|
||||
|
||||
# Only add credentials if both access key and secret key are provided
|
||||
if aws_access_key_id and aws_secret_access_key:
|
||||
client_kwargs.update(
|
||||
{"aws_access_key_id": aws_access_key_id, "aws_secret_access_key": aws_secret_access_key}
|
||||
)
|
||||
|
||||
self.bedrock_client = boto3.client(**client_kwargs)
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Failed to initialize Bedrock client: {str(e)}")
|
||||
|
||||
try:
|
||||
line = 1
|
||||
if not self.knowledge_base_id:
|
||||
self.knowledge_base_id = tool_parameters.get("knowledge_base_id")
|
||||
@@ -123,7 +135,14 @@ class BedrockRetrieveTool(BuiltinTool):
|
||||
sorted_docs = sorted(retrieved_docs, key=operator.itemgetter("score"), reverse=True)
|
||||
|
||||
line = 6
|
||||
return [self.create_json_message(res) for res in sorted_docs]
|
||||
result_type = tool_parameters.get("result_type")
|
||||
if result_type == "json":
|
||||
return [self.create_json_message(res) for res in sorted_docs]
|
||||
else:
|
||||
text = ""
|
||||
for i, res in enumerate(sorted_docs):
|
||||
text += f"{i + 1}: {res['content']}\n"
|
||||
return self.create_text_message(text)
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Exception {str(e)}, line : {line}")
|
||||
@@ -138,7 +157,6 @@ class BedrockRetrieveTool(BuiltinTool):
|
||||
if not parameters.get("query"):
|
||||
raise ValueError("query is required")
|
||||
|
||||
# 可选:可以验证元数据过滤条件是否为有效的 JSON 字符串(如果提供)
|
||||
metadata_filter_str = parameters.get("metadata_filter")
|
||||
if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
|
||||
raise ValueError("metadata_filter must be a valid JSON object")
|
||||
|
@@ -15,6 +15,60 @@ description:
|
||||
llm: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool
|
||||
|
||||
parameters:
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: AWS Region
|
||||
zh_Hans: AWS区域
|
||||
human_description:
|
||||
en_US: AWS region for the Bedrock service
|
||||
zh_Hans: Bedrock服务的AWS区域
|
||||
form: form
|
||||
|
||||
- name: aws_access_key_id
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: AWS Access Key ID
|
||||
zh_Hans: AWS访问密钥ID
|
||||
human_description:
|
||||
en_US: AWS access key ID for authentication (optional)
|
||||
zh_Hans: 用于身份验证的AWS访问密钥ID(可选)
|
||||
form: form
|
||||
|
||||
- name: aws_secret_access_key
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: AWS Secret Access Key
|
||||
zh_Hans: AWS秘密访问密钥
|
||||
human_description:
|
||||
en_US: AWS secret access key for authentication (optional)
|
||||
zh_Hans: 用于身份验证的AWS秘密访问密钥(可选)
|
||||
form: form
|
||||
|
||||
- name: result_type
|
||||
type: select
|
||||
required: true
|
||||
label:
|
||||
en_US: result type
|
||||
zh_Hans: 结果类型
|
||||
human_description:
|
||||
en_US: return a list of json or texts
|
||||
zh_Hans: 返回一个列表,内容是json还是纯文本
|
||||
default: text
|
||||
options:
|
||||
- value: json
|
||||
label:
|
||||
en_US: JSON
|
||||
zh_Hans: JSON
|
||||
- value: text
|
||||
label:
|
||||
en_US: Text
|
||||
zh_Hans: 文本
|
||||
form: form
|
||||
|
||||
- name: knowledge_base_id
|
||||
type: string
|
||||
required: true
|
||||
@@ -95,6 +149,7 @@ parameters:
|
||||
zh_Hans: 重拍模型ID
|
||||
pt_BR: rerank model id
|
||||
llm_description: rerank model id
|
||||
default: default
|
||||
options:
|
||||
- value: default
|
||||
label:
|
||||
@@ -110,20 +165,6 @@ parameters:
|
||||
zh_Hans: amazon.rerank-v1:0
|
||||
form: form
|
||||
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: AWS Region
|
||||
zh_Hans: AWS 区域
|
||||
pt_BR: AWS Region
|
||||
human_description:
|
||||
en_US: AWS region where the Bedrock Knowledge Base is located
|
||||
zh_Hans: Bedrock知识库所在的AWS区域
|
||||
pt_BR: AWS region where the Bedrock Knowledge Base is located
|
||||
llm_description: AWS region where the Bedrock Knowledge Base is located
|
||||
form: form
|
||||
|
||||
- name: metadata_filter # Additional parameter for metadata filtering
|
||||
type: string # String type, expects JSON-formatted filter conditions
|
||||
required: false # Optional field - can be omitted
|
||||
|
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import boto3
|
||||
|
||||
@@ -10,193 +10,63 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
class BedrockRetrieveAndGenerateTool(BuiltinTool):
|
||||
bedrock_client: Any = None
|
||||
|
||||
def _create_text_inference_config(
|
||||
def _invoke(
|
||||
self,
|
||||
max_tokens: Optional[int] = None,
|
||||
stop_sequences: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
) -> Optional[dict]:
|
||||
"""Create text inference configuration"""
|
||||
if any([max_tokens, stop_sequences, temperature, top_p]):
|
||||
config = {}
|
||||
if max_tokens is not None:
|
||||
config["maxTokens"] = max_tokens
|
||||
if stop_sequences:
|
||||
try:
|
||||
config["stopSequences"] = json.loads(stop_sequences)
|
||||
except json.JSONDecodeError:
|
||||
config["stopSequences"] = []
|
||||
if temperature is not None:
|
||||
config["temperature"] = temperature
|
||||
if top_p is not None:
|
||||
config["topP"] = top_p
|
||||
return config
|
||||
return None
|
||||
|
||||
def _create_guardrail_config(
|
||||
self,
|
||||
guardrail_id: Optional[str] = None,
|
||||
guardrail_version: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""Create guardrail configuration"""
|
||||
if guardrail_id and guardrail_version:
|
||||
return {"guardrailId": guardrail_id, "guardrailVersion": guardrail_version}
|
||||
return None
|
||||
|
||||
def _create_generation_config(
|
||||
self,
|
||||
additional_model_fields: Optional[str] = None,
|
||||
guardrail_config: Optional[dict] = None,
|
||||
text_inference_config: Optional[dict] = None,
|
||||
performance_mode: Optional[str] = None,
|
||||
prompt_template: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Create generation configuration"""
|
||||
config = {}
|
||||
|
||||
if additional_model_fields:
|
||||
try:
|
||||
config["additionalModelRequestFields"] = json.loads(additional_model_fields)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
if guardrail_config:
|
||||
config["guardrailConfiguration"] = guardrail_config
|
||||
|
||||
if text_inference_config:
|
||||
config["inferenceConfig"] = {"textInferenceConfig": text_inference_config}
|
||||
|
||||
if performance_mode:
|
||||
config["performanceConfig"] = {"latency": performance_mode}
|
||||
|
||||
if prompt_template:
|
||||
config["promptTemplate"] = {"textPromptTemplate": prompt_template}
|
||||
|
||||
return config
|
||||
|
||||
def _create_orchestration_config(
|
||||
self,
|
||||
orchestration_additional_model_fields: Optional[str] = None,
|
||||
orchestration_text_inference_config: Optional[dict] = None,
|
||||
orchestration_performance_mode: Optional[str] = None,
|
||||
orchestration_prompt_template: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Create orchestration configuration"""
|
||||
config = {}
|
||||
|
||||
if orchestration_additional_model_fields:
|
||||
try:
|
||||
config["additionalModelRequestFields"] = json.loads(orchestration_additional_model_fields)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
if orchestration_text_inference_config:
|
||||
config["inferenceConfig"] = {"textInferenceConfig": orchestration_text_inference_config}
|
||||
|
||||
if orchestration_performance_mode:
|
||||
config["performanceConfig"] = {"latency": orchestration_performance_mode}
|
||||
|
||||
if orchestration_prompt_template:
|
||||
config["promptTemplate"] = {"textPromptTemplate": orchestration_prompt_template}
|
||||
|
||||
return config
|
||||
|
||||
def _create_vector_search_config(
|
||||
self,
|
||||
number_of_results: int = 5,
|
||||
search_type: str = "SEMANTIC",
|
||||
metadata_filter: Optional[dict] = None,
|
||||
) -> dict:
|
||||
"""Create vector search configuration"""
|
||||
config = {
|
||||
"numberOfResults": number_of_results,
|
||||
"overrideSearchType": search_type,
|
||||
}
|
||||
|
||||
# Only add filter if metadata_filter is not empty
|
||||
if metadata_filter:
|
||||
config["filter"] = metadata_filter
|
||||
|
||||
return config
|
||||
|
||||
def _bedrock_retrieve_and_generate(
|
||||
self,
|
||||
query: str,
|
||||
knowledge_base_id: str,
|
||||
model_arn: str,
|
||||
# Generation Configuration
|
||||
additional_model_fields: Optional[str] = None,
|
||||
guardrail_id: Optional[str] = None,
|
||||
guardrail_version: Optional[str] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
stop_sequences: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
performance_mode: str = "standard",
|
||||
prompt_template: Optional[str] = None,
|
||||
# Orchestration Configuration
|
||||
orchestration_additional_model_fields: Optional[str] = None,
|
||||
orchestration_max_tokens: Optional[int] = None,
|
||||
orchestration_stop_sequences: Optional[str] = None,
|
||||
orchestration_temperature: Optional[float] = None,
|
||||
orchestration_top_p: Optional[float] = None,
|
||||
orchestration_performance_mode: Optional[str] = None,
|
||||
orchestration_prompt_template: Optional[str] = None,
|
||||
# Retrieval Configuration
|
||||
number_of_results: int = 5,
|
||||
search_type: str = "SEMANTIC",
|
||||
metadata_filter: Optional[dict] = None,
|
||||
# Additional Configuration
|
||||
session_id: Optional[str] = None,
|
||||
) -> dict[str, Any]:
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> ToolInvokeMessage:
|
||||
try:
|
||||
# Create text inference configurations
|
||||
text_inference_config = self._create_text_inference_config(max_tokens, stop_sequences, temperature, top_p)
|
||||
orchestration_text_inference_config = self._create_text_inference_config(
|
||||
orchestration_max_tokens, orchestration_stop_sequences, orchestration_temperature, orchestration_top_p
|
||||
)
|
||||
# Initialize Bedrock client if not already initialized
|
||||
if not self.bedrock_client:
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
aws_access_key_id = tool_parameters.get("aws_access_key_id")
|
||||
aws_secret_access_key = tool_parameters.get("aws_secret_access_key")
|
||||
|
||||
# Create guardrail configuration
|
||||
guardrail_config = self._create_guardrail_config(guardrail_id, guardrail_version)
|
||||
client_kwargs = {"service_name": "bedrock-agent-runtime", "region_name": aws_region or None}
|
||||
|
||||
# Create vector search configuration
|
||||
vector_search_config = self._create_vector_search_config(number_of_results, search_type, metadata_filter)
|
||||
# Only add credentials if both access key and secret key are provided
|
||||
if aws_access_key_id and aws_secret_access_key:
|
||||
client_kwargs.update(
|
||||
{"aws_access_key_id": aws_access_key_id, "aws_secret_access_key": aws_secret_access_key}
|
||||
)
|
||||
|
||||
# Create generation configuration
|
||||
generation_config = self._create_generation_config(
|
||||
additional_model_fields, guardrail_config, text_inference_config, performance_mode, prompt_template
|
||||
)
|
||||
self.bedrock_client = boto3.client(**client_kwargs)
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Failed to initialize Bedrock client: {str(e)}")
|
||||
|
||||
# Create orchestration configuration
|
||||
orchestration_config = self._create_orchestration_config(
|
||||
orchestration_additional_model_fields,
|
||||
orchestration_text_inference_config,
|
||||
orchestration_performance_mode,
|
||||
orchestration_prompt_template,
|
||||
)
|
||||
try:
|
||||
request_config = {}
|
||||
|
||||
# Create knowledge base configuration
|
||||
knowledge_base_config = {
|
||||
"knowledgeBaseId": knowledge_base_id,
|
||||
"modelArn": model_arn,
|
||||
"generationConfiguration": generation_config,
|
||||
"orchestrationConfiguration": orchestration_config,
|
||||
"retrievalConfiguration": {"vectorSearchConfiguration": vector_search_config},
|
||||
}
|
||||
# Set input configuration
|
||||
input_text = tool_parameters.get("input")
|
||||
if input_text:
|
||||
request_config["input"] = {"text": input_text}
|
||||
|
||||
# Create request configuration
|
||||
request_config = {
|
||||
"input": {"text": query},
|
||||
"retrieveAndGenerateConfiguration": {
|
||||
"type": "KNOWLEDGE_BASE",
|
||||
"knowledgeBaseConfiguration": knowledge_base_config,
|
||||
},
|
||||
}
|
||||
# Build retrieve and generate configuration
|
||||
config_type = tool_parameters.get("type")
|
||||
retrieve_generate_config = {"type": config_type}
|
||||
|
||||
# Add session configuration if provided
|
||||
if session_id and len(session_id) >= 2:
|
||||
request_config["sessionConfiguration"] = {"sessionId": session_id}
|
||||
# Add configuration based on type
|
||||
if config_type == "KNOWLEDGE_BASE":
|
||||
kb_config_str = tool_parameters.get("knowledge_base_configuration")
|
||||
kb_config = json.loads(kb_config_str) if kb_config_str else None
|
||||
retrieve_generate_config["knowledgeBaseConfiguration"] = kb_config
|
||||
else: # EXTERNAL_SOURCES
|
||||
es_config_str = tool_parameters.get("external_sources_configuration")
|
||||
es_config = json.loads(kb_config_str) if es_config_str else None
|
||||
retrieve_generate_config["externalSourcesConfiguration"] = es_config
|
||||
|
||||
request_config["retrieveAndGenerateConfiguration"] = retrieve_generate_config
|
||||
|
||||
# Parse session configuration
|
||||
session_config_str = tool_parameters.get("session_configuration")
|
||||
session_config = json.loads(session_config_str) if session_config_str else None
|
||||
if session_config:
|
||||
request_config["sessionConfiguration"] = session_config
|
||||
|
||||
# Add session ID if provided
|
||||
session_id = tool_parameters.get("session_id")
|
||||
if session_id:
|
||||
request_config["sessionId"] = session_id
|
||||
|
||||
# Send request
|
||||
@@ -226,99 +96,42 @@ class BedrockRetrieveAndGenerateTool(BuiltinTool):
|
||||
citation_info["references"].append(reference)
|
||||
|
||||
result["citations"].append(citation_info)
|
||||
|
||||
return result
|
||||
|
||||
result_type = tool_parameters.get("result_type")
|
||||
if result_type == "json":
|
||||
return self.create_json_message(result)
|
||||
elif result_type == "text-with-citations":
|
||||
return self.create_text_message(result)
|
||||
else:
|
||||
return self.create_text_message(result.get("output"))
|
||||
except json.JSONDecodeError as e:
|
||||
return self.create_text_message(f"Invalid JSON format: {str(e)}")
|
||||
except Exception as e:
|
||||
raise Exception(f"Error calling Bedrock service: {str(e)}")
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> ToolInvokeMessage:
|
||||
try:
|
||||
# Initialize Bedrock client if not already initialized
|
||||
if not self.bedrock_client:
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
aws_access_key_id = tool_parameters.get("aws_access_key_id")
|
||||
aws_secret_access_key = tool_parameters.get("aws_secret_access_key")
|
||||
|
||||
client_kwargs = {
|
||||
"service_name": "bedrock-agent-runtime",
|
||||
}
|
||||
if aws_region:
|
||||
client_kwargs["region_name"] = aws_region
|
||||
# Only add credentials if both access key and secret key are provided
|
||||
if aws_access_key_id and aws_secret_access_key:
|
||||
client_kwargs.update(
|
||||
{"aws_access_key_id": aws_access_key_id, "aws_secret_access_key": aws_secret_access_key}
|
||||
)
|
||||
|
||||
try:
|
||||
self.bedrock_client = boto3.client(**client_kwargs)
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Failed to initialize Bedrock client: {str(e)}")
|
||||
|
||||
# Parse metadata filter if provided
|
||||
metadata_filter = None
|
||||
if metadata_filter_str := tool_parameters.get("metadata_filter"):
|
||||
try:
|
||||
parsed_filter = json.loads(metadata_filter_str)
|
||||
if parsed_filter: # Only set if not empty
|
||||
metadata_filter = parsed_filter
|
||||
except json.JSONDecodeError:
|
||||
return self.create_text_message("metadata_filter must be a valid JSON string")
|
||||
|
||||
try:
|
||||
response = self._bedrock_retrieve_and_generate(
|
||||
query=tool_parameters["query"],
|
||||
knowledge_base_id=tool_parameters["knowledge_base_id"],
|
||||
model_arn=tool_parameters["model_arn"],
|
||||
# Generation Configuration
|
||||
additional_model_fields=tool_parameters.get("additional_model_fields"),
|
||||
guardrail_id=tool_parameters.get("guardrail_id"),
|
||||
guardrail_version=tool_parameters.get("guardrail_version"),
|
||||
max_tokens=tool_parameters.get("max_tokens"),
|
||||
stop_sequences=tool_parameters.get("stop_sequences"),
|
||||
temperature=tool_parameters.get("temperature"),
|
||||
top_p=tool_parameters.get("top_p"),
|
||||
performance_mode=tool_parameters.get("performance_mode", "standard"),
|
||||
prompt_template=tool_parameters.get("prompt_template"),
|
||||
# Orchestration Configuration
|
||||
orchestration_additional_model_fields=tool_parameters.get("orchestration_additional_model_fields"),
|
||||
orchestration_max_tokens=tool_parameters.get("orchestration_max_tokens"),
|
||||
orchestration_stop_sequences=tool_parameters.get("orchestration_stop_sequences"),
|
||||
orchestration_temperature=tool_parameters.get("orchestration_temperature"),
|
||||
orchestration_top_p=tool_parameters.get("orchestration_top_p"),
|
||||
orchestration_performance_mode=tool_parameters.get("orchestration_performance_mode"),
|
||||
orchestration_prompt_template=tool_parameters.get("orchestration_prompt_template"),
|
||||
# Retrieval Configuration
|
||||
number_of_results=tool_parameters.get("number_of_results", 5),
|
||||
search_type=tool_parameters.get("search_type", "SEMANTIC"),
|
||||
metadata_filter=metadata_filter,
|
||||
# Additional Configuration
|
||||
session_id=tool_parameters.get("session_id"),
|
||||
)
|
||||
return self.create_json_message(response)
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Tool invocation error: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Tool execution error: {str(e)}")
|
||||
return self.create_text_message(f"Tool invocation error: {str(e)}")
|
||||
|
||||
def validate_parameters(self, parameters: dict[str, Any]) -> None:
|
||||
"""Validate the parameters"""
|
||||
required_params = ["query", "model_arn", "knowledge_base_id"]
|
||||
for param in required_params:
|
||||
if not parameters.get(param):
|
||||
raise ValueError(f"{param} is required")
|
||||
# Validate required parameters
|
||||
if not parameters.get("input"):
|
||||
raise ValueError("input is required")
|
||||
if not parameters.get("type"):
|
||||
raise ValueError("type is required")
|
||||
|
||||
# Validate metadata filter if provided
|
||||
if metadata_filter_str := parameters.get("metadata_filter"):
|
||||
try:
|
||||
if not isinstance(json.loads(metadata_filter_str), dict):
|
||||
raise ValueError("metadata_filter must be a valid JSON object")
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("metadata_filter must be a valid JSON string")
|
||||
# Validate JSON configurations
|
||||
json_configs = ["knowledge_base_configuration", "external_sources_configuration", "session_configuration"]
|
||||
for config in json_configs:
|
||||
if config_value := parameters.get(config):
|
||||
try:
|
||||
json.loads(config_value)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"{config} must be a valid JSON string")
|
||||
|
||||
# Validate configuration type
|
||||
config_type = parameters.get("type")
|
||||
if config_type not in ["KNOWLEDGE_BASE", "EXTERNAL_SOURCES"]:
|
||||
raise ValueError("type must be either KNOWLEDGE_BASE or EXTERNAL_SOURCES")
|
||||
|
||||
# Validate type-specific configuration
|
||||
if config_type == "KNOWLEDGE_BASE" and not parameters.get("knowledge_base_configuration"):
|
||||
raise ValueError("knowledge_base_configuration is required when type is KNOWLEDGE_BASE")
|
||||
elif config_type == "EXTERNAL_SOURCES" and not parameters.get("external_sources_configuration"):
|
||||
raise ValueError("external_sources_configuration is required when type is EXTERNAL_SOURCES")
|
||||
|
@@ -8,24 +8,11 @@ identity:
|
||||
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for retrieving and generating information using Amazon Bedrock Knowledge Base
|
||||
zh_Hans: 使用Amazon Bedrock知识库进行信息检索和生成的工具
|
||||
en_US: "This is an advanced usage of Bedrock Retrieve. Please refer to the API documentation for detailed parameters and paste them into the corresponding Knowledge Base Configuration or External Sources Configuration"
|
||||
zh_Hans: "这个工具为Bedrock Retrieve的高级用法,请参考API设置详细的参数,并粘贴到对应的知识库配置或者外部源配置"
|
||||
llm: A tool for retrieving and generating information using Amazon Bedrock Knowledge Base
|
||||
|
||||
parameters:
|
||||
# Additional Configuration
|
||||
- name: session_id
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Session ID
|
||||
zh_Hans: 会话ID
|
||||
human_description:
|
||||
en_US: Optional session ID for continuous conversations
|
||||
zh_Hans: 用于连续对话的可选会话ID
|
||||
form: form
|
||||
|
||||
# AWS Configuration
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
@@ -59,300 +46,103 @@ parameters:
|
||||
zh_Hans: 用于身份验证的AWS秘密访问密钥(可选)
|
||||
form: form
|
||||
|
||||
# Knowledge Base Configuration
|
||||
- name: knowledge_base_id
|
||||
type: string
|
||||
- name: result_type
|
||||
type: select
|
||||
required: true
|
||||
label:
|
||||
en_US: Knowledge Base ID
|
||||
zh_Hans: 知识库ID
|
||||
en_US: result type
|
||||
zh_Hans: 结果类型
|
||||
human_description:
|
||||
en_US: ID of the Bedrock Knowledge Base
|
||||
zh_Hans: Bedrock知识库的ID
|
||||
en_US: return a list of json or texts
|
||||
zh_Hans: 返回一个列表,内容是json还是纯文本
|
||||
default: text
|
||||
options:
|
||||
- value: json
|
||||
label:
|
||||
en_US: JSON
|
||||
zh_Hans: JSON
|
||||
- value: text
|
||||
label:
|
||||
en_US: Text
|
||||
zh_Hans: 文本
|
||||
- value: text-with-citations
|
||||
label:
|
||||
en_US: Text With Citations
|
||||
zh_Hans: 文本(包含引用)
|
||||
form: form
|
||||
|
||||
- name: model_arn
|
||||
- name: input
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Model ARN
|
||||
zh_Hans: 模型ARN
|
||||
en_US: Input Text
|
||||
zh_Hans: 输入文本
|
||||
human_description:
|
||||
en_US: The ARN of the model to use
|
||||
zh_Hans: 要使用的模型ARN
|
||||
form: form
|
||||
|
||||
# Retrieval Configuration
|
||||
- name: query
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Query
|
||||
zh_Hans: 查询
|
||||
human_description:
|
||||
en_US: The search query to retrieve information
|
||||
zh_Hans: 用于检索信息的查询语句
|
||||
en_US: The text query to retrieve information
|
||||
zh_Hans: 用于检索信息的文本查询
|
||||
form: llm
|
||||
|
||||
- name: number_of_results
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Number of Results
|
||||
zh_Hans: 结果数量
|
||||
human_description:
|
||||
en_US: Number of results to retrieve (1-10)
|
||||
zh_Hans: 要检索的结果数量(1-10)
|
||||
default: 5
|
||||
min: 1
|
||||
max: 10
|
||||
form: form
|
||||
|
||||
- name: search_type
|
||||
- name: type
|
||||
type: select
|
||||
required: false
|
||||
required: true
|
||||
label:
|
||||
en_US: Search Type
|
||||
zh_Hans: 搜索类型
|
||||
en_US: Configuration Type
|
||||
zh_Hans: 配置类型
|
||||
human_description:
|
||||
en_US: Type of search to perform
|
||||
zh_Hans: 要执行的搜索类型
|
||||
default: SEMANTIC
|
||||
en_US: Type of retrieve and generate configuration
|
||||
zh_Hans: 检索和生成配置的类型
|
||||
options:
|
||||
- value: SEMANTIC
|
||||
- value: KNOWLEDGE_BASE
|
||||
label:
|
||||
en_US: Semantic Search
|
||||
zh_Hans: 语义搜索
|
||||
- value: HYBRID
|
||||
en_US: Knowledge Base
|
||||
zh_Hans: 知识库
|
||||
- value: EXTERNAL_SOURCES
|
||||
label:
|
||||
en_US: Hybrid Search
|
||||
zh_Hans: 混合搜索
|
||||
en_US: External Sources
|
||||
zh_Hans: 外部源
|
||||
form: form
|
||||
|
||||
- name: metadata_filter
|
||||
- name: knowledge_base_configuration
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Metadata Filter
|
||||
zh_Hans: 元数据过滤器
|
||||
en_US: Knowledge Base Configuration
|
||||
zh_Hans: 知识库配置
|
||||
human_description:
|
||||
en_US: JSON formatted filter conditions for metadata, supporting operations like equals, greaterThan, lessThan, etc.
|
||||
zh_Hans: 元数据的JSON格式过滤条件,支持等于、大于、小于等操作
|
||||
default: "{}"
|
||||
en_US: Please refer to @https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-agent-runtime/client/retrieve_and_generate.html#retrieve-and-generate for complete parameters and paste them here
|
||||
zh_Hans: 请参考 https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-agent-runtime/client/retrieve_and_generate.html#retrieve-and-generate 配置完整的参数并粘贴到这里
|
||||
form: form
|
||||
|
||||
# Generation Configuration
|
||||
- name: guardrail_id
|
||||
- name: external_sources_configuration
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Guardrail ID
|
||||
zh_Hans: 防护栏ID
|
||||
en_US: External Sources Configuration
|
||||
zh_Hans: 外部源配置
|
||||
human_description:
|
||||
en_US: ID of the guardrail to apply
|
||||
zh_Hans: 要应用的防护栏ID
|
||||
en_US: Please refer to https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-agent-runtime/client/retrieve_and_generate.html#retrieve-and-generate for complete parameters and paste them here
|
||||
zh_Hans: 请参考 https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-agent-runtime/client/retrieve_and_generate.html#retrieve-and-generate 配置完整的参数并粘贴到这里
|
||||
form: form
|
||||
|
||||
- name: guardrail_version
|
||||
- name: session_configuration
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Guardrail Version
|
||||
zh_Hans: 防护栏版本
|
||||
en_US: Session Configuration
|
||||
zh_Hans: 会话配置
|
||||
human_description:
|
||||
en_US: Version of the guardrail to apply
|
||||
zh_Hans: 要应用的防护栏版本
|
||||
en_US: JSON formatted session configuration
|
||||
zh_Hans: JSON格式的会话配置
|
||||
default: ""
|
||||
form: form
|
||||
|
||||
- name: max_tokens
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Maximum Tokens
|
||||
zh_Hans: 最大令牌数
|
||||
human_description:
|
||||
en_US: Maximum number of tokens to generate
|
||||
zh_Hans: 生成的最大令牌数
|
||||
default: 2048
|
||||
form: form
|
||||
|
||||
- name: stop_sequences
|
||||
- name: session_id
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Stop Sequences
|
||||
zh_Hans: 停止序列
|
||||
en_US: Session ID
|
||||
zh_Hans: 会话ID
|
||||
human_description:
|
||||
en_US: JSON array of strings that will stop generation when encountered
|
||||
zh_Hans: JSON数组格式的字符串,遇到这些序列时将停止生成
|
||||
default: "[]"
|
||||
form: form
|
||||
|
||||
- name: temperature
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Temperature
|
||||
zh_Hans: 温度
|
||||
human_description:
|
||||
en_US: Controls randomness in the output (0-1)
|
||||
zh_Hans: 控制输出的随机性(0-1)
|
||||
default: 0.7
|
||||
min: 0
|
||||
max: 1
|
||||
form: form
|
||||
|
||||
- name: top_p
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Top P
|
||||
zh_Hans: Top P值
|
||||
human_description:
|
||||
en_US: Controls diversity via nucleus sampling (0-1)
|
||||
zh_Hans: 通过核采样控制多样性(0-1)
|
||||
default: 0.95
|
||||
min: 0
|
||||
max: 1
|
||||
form: form
|
||||
|
||||
- name: performance_mode
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: Performance Mode
|
||||
zh_Hans: 性能模式
|
||||
human_description:
|
||||
en_US: Select performance optimization mode(performanceConfig.latency)
|
||||
zh_Hans: 选择性能优化模式(performanceConfig.latency)
|
||||
default: standard
|
||||
options:
|
||||
- value: standard
|
||||
label:
|
||||
en_US: Standard
|
||||
zh_Hans: 标准
|
||||
- value: optimized
|
||||
label:
|
||||
en_US: Optimized
|
||||
zh_Hans: 优化
|
||||
form: form
|
||||
|
||||
- name: prompt_template
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Prompt Template
|
||||
zh_Hans: 提示模板
|
||||
human_description:
|
||||
en_US: Custom prompt template for generation
|
||||
zh_Hans: 用于生成的自定义提示模板
|
||||
form: form
|
||||
|
||||
- name: additional_model_fields
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Additional Model Fields
|
||||
zh_Hans: 额外模型字段
|
||||
human_description:
|
||||
en_US: JSON formatted additional fields for model configuration
|
||||
zh_Hans: JSON格式的额外模型配置字段
|
||||
default: "{}"
|
||||
form: form
|
||||
|
||||
# Orchestration Configuration
|
||||
- name: orchestration_max_tokens
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Orchestration Maximum Tokens
|
||||
zh_Hans: 编排最大令牌数
|
||||
human_description:
|
||||
en_US: Maximum number of tokens for orchestration
|
||||
zh_Hans: 编排过程的最大令牌数
|
||||
default: 2048
|
||||
form: form
|
||||
|
||||
- name: orchestration_stop_sequences
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Orchestration Stop Sequences
|
||||
zh_Hans: 编排停止序列
|
||||
human_description:
|
||||
en_US: JSON array of strings that will stop orchestration when encountered
|
||||
zh_Hans: JSON数组格式的字符串,遇到这些序列时将停止编排
|
||||
default: "[]"
|
||||
form: form
|
||||
|
||||
- name: orchestration_temperature
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Orchestration Temperature
|
||||
zh_Hans: 编排温度
|
||||
human_description:
|
||||
en_US: Controls randomness in the orchestration output (0-1)
|
||||
zh_Hans: 控制编排输出的随机性(0-1)
|
||||
default: 0.7
|
||||
min: 0
|
||||
max: 1
|
||||
form: form
|
||||
|
||||
- name: orchestration_top_p
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Orchestration Top P
|
||||
zh_Hans: 编排Top P值
|
||||
human_description:
|
||||
en_US: Controls diversity via nucleus sampling in orchestration (0-1)
|
||||
zh_Hans: 通过核采样控制编排的多样性(0-1)
|
||||
default: 0.95
|
||||
min: 0
|
||||
max: 1
|
||||
form: form
|
||||
|
||||
- name: orchestration_performance_mode
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: Orchestration Performance Mode
|
||||
zh_Hans: 编排性能模式
|
||||
human_description:
|
||||
en_US: Select performance optimization mode for orchestration
|
||||
zh_Hans: 选择编排的性能优化模式
|
||||
default: standard
|
||||
options:
|
||||
- value: standard
|
||||
label:
|
||||
en_US: Standard
|
||||
zh_Hans: 标准
|
||||
- value: optimized
|
||||
label:
|
||||
en_US: Optimized
|
||||
zh_Hans: 优化
|
||||
form: form
|
||||
|
||||
- name: orchestration_prompt_template
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Orchestration Prompt Template
|
||||
zh_Hans: 编排提示模板
|
||||
human_description:
|
||||
en_US: Custom prompt template for orchestration
|
||||
zh_Hans: 用于编排的自定义提示模板
|
||||
form: form
|
||||
|
||||
- name: orchestration_additional_model_fields
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Orchestration Additional Model Fields
|
||||
zh_Hans: 编排额外模型字段
|
||||
human_description:
|
||||
en_US: JSON formatted additional fields for orchestration model configuration
|
||||
zh_Hans: JSON格式的编排模型额外配置字段
|
||||
default: "{}"
|
||||
en_US: Session ID for continuous conversations
|
||||
zh_Hans: 用于连续对话的会话ID
|
||||
form: form
|
||||
|
Reference in New Issue
Block a user