feat: add ops trace (#5483)

Co-authored-by: takatost <takatost@gmail.com>
Author: Joe
Date: 2024-06-26 17:33:29 +08:00 (committed by GitHub)
parent 31a061ebaa
commit 4e2de638af
58 changed files with 3553 additions and 622 deletions


@@ -66,44 +66,43 @@ class ParameterExtractorNode(LLMNode):
}
}
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
"""
Run the node.
"""
node_data = cast(ParameterExtractorNodeData, self.node_data)
query = variable_pool.get_variable_value(node_data.query)
if not query:
raise ValueError("Query not found")
inputs={
raise ValueError("Input variable content not found or is empty")
inputs = {
'query': query,
'parameters': jsonable_encoder(node_data.parameters),
'instruction': jsonable_encoder(node_data.instruction),
}
model_instance, model_config = self._fetch_model_config(node_data.model)
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
raise ValueError("Model is not a Large Language Model")
llm_model = model_instance.model_type_instance
model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
if not model_schema:
raise ValueError("Model schema not found")
# fetch memory
memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
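# prefer native function calling when the model supports tool calls; otherwise fall back to prompt engineering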
if set(model_schema.features or []) & set([ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL]) \
and node_data.reasoning_mode == 'function_call':
# use function call
prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
node_data, query, variable_pool, model_config, memory
)
else:
# use prompt engineering
prompt_messages = self._generate_prompt_engineering_prompt(node_data, query, variable_pool, model_config,
memory)
prompt_message_tools = []
process_data = {
@@ -202,7 +201,7 @@ class ParameterExtractorNode(LLMNode):
# handle invoke result
if not isinstance(invoke_result, LLMResult):
raise ValueError(f"Invalid invoke result: {invoke_result}")
text = invoke_result.message.content
usage = invoke_result.usage
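# function-calling models return the extracted parameters as the first tool call, when present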
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
@@ -212,21 +211,23 @@ class ParameterExtractorNode(LLMNode):
return text, usage, tool_call
def _generate_function_call_prompt(self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
"""
Generate function call prompt.
"""
query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format(content=query, structure=json.dumps(
node_data.get_parameter_json_schema()))
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
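# budget the context-window tokens left for conversation history (passed below as the truncation limit)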
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, memory,
rest_token)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
@@ -259,8 +260,8 @@ class ParameterExtractorNode(LLMNode):
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=example['assistant']['function_call']['name'],
arguments=json.dumps(example['assistant']['function_call']['parameters']
)
))
]
),
ToolPromptMessage(
@@ -273,8 +274,8 @@ class ParameterExtractorNode(LLMNode):
])
prompt_messages = prompt_messages[:last_user_message_idx] + \
example_messages + prompt_messages[last_user_message_idx:]
# generate tool
tool = PromptMessageTool(
name=FUNCTION_CALLING_EXTRACTOR_NAME,
@@ -284,13 +285,13 @@ class ParameterExtractorNode(LLMNode):
return prompt_messages, [tool]
def _generate_prompt_engineering_prompt(self,
data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
) -> list[PromptMessage]:
"""
Generate prompt engineering prompt.
"""
@@ -308,18 +309,19 @@ class ParameterExtractorNode(LLMNode):
raise ValueError(f"Invalid model mode: {model_mode}")
def _generate_prompt_engineering_completion_prompt(self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
) -> list[PromptMessage]:
"""
Generate completion prompt.
"""
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, memory,
rest_token)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={
@@ -336,23 +338,23 @@ class ParameterExtractorNode(LLMNode):
return prompt_messages
def _generate_prompt_engineering_chat_prompt(self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
) -> list[PromptMessage]:
"""
Generate chat prompt.
"""
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
prompt_template = self._get_prompt_engineering_prompt_template(
node_data,
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
structure=json.dumps(node_data.get_parameter_json_schema()),
text=query
),
variable_pool, memory, rest_token
)
@@ -387,7 +389,7 @@ class ParameterExtractorNode(LLMNode):
])
prompt_messages = prompt_messages[:last_user_message_idx] + \
example_messages + prompt_messages[last_user_message_idx:]
return prompt_messages
@@ -397,23 +399,23 @@ class ParameterExtractorNode(LLMNode):
"""
if len(data.parameters) != len(result):
raise ValueError("Invalid number of parameters")
for parameter in data.parameters:
if parameter.required and parameter.name not in result:
raise ValueError(f"Parameter {parameter.name} is required")
if parameter.type == 'select' and parameter.options and result.get(parameter.name) not in parameter.options:
raise ValueError(f"Invalid `select` value for parameter {parameter.name}")
if parameter.type == 'number' and not isinstance(result.get(parameter.name), int | float):
raise ValueError(f"Invalid `number` value for parameter {parameter.name}")
if parameter.type == 'bool' and not isinstance(result.get(parameter.name), bool):
raise ValueError(f"Invalid `bool` value for parameter {parameter.name}")
if parameter.type == 'string' and not isinstance(result.get(parameter.name), str):
raise ValueError(f"Invalid `string` value for parameter {parameter.name}")
if parameter.type.startswith('array'):
if not isinstance(result.get(parameter.name), list):
raise ValueError(f"Invalid `array` value for parameter {parameter.name}")
@@ -499,6 +501,7 @@ class ParameterExtractorNode(LLMNode):
"""
Extract complete json response.
"""
def extract_json(text):
"""
Extract the complete JSON object or array from text that starts with '{' or '['.
@@ -515,11 +518,11 @@ class ParameterExtractorNode(LLMNode):
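# stack-based matching: each '{' or '[' is pushed, and the value is complete once its opener is popped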
if (c == '}' and stack[-1] == '{') or (c == ']' and stack[-1] == '['):
stack.pop()
if not stack:
return text[:i + 1]
else:
return text[:i]
return None
# extract json from the text
for idx in range(len(result)):
if result[idx] == '{' or result[idx] == '[':
@@ -536,9 +539,9 @@ class ParameterExtractorNode(LLMNode):
"""
if not tool_call or not tool_call.function.arguments:
return None
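# tool call arguments arrive as a JSON-encoded string per the function-calling convention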
return json.loads(tool_call.function.arguments)
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
"""
Generate default result.
@@ -551,7 +554,7 @@ class ParameterExtractorNode(LLMNode):
result[parameter.name] = False
elif parameter.type in ['string', 'select']:
result[parameter.name] = ''
return result
def _render_instruction(self, instruction: str, variable_pool: VariablePool) -> str:
@@ -562,13 +565,13 @@ class ParameterExtractorNode(LLMNode):
inputs = {}
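# resolve every variable selector referenced by the instruction template from the variable pool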
for selector in variable_template_parser.extract_variable_selectors():
inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector)
return variable_template_parser.format(inputs)
def _get_function_calling_prompt_template(self, node_data: ParameterExtractorNodeData, query: str,
variable_pool: VariablePool,
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000) \
-> list[ChatModelMessage]:
model_mode = ModelMode.value_of(node_data.model.mode)
input_text = query
@@ -590,12 +593,12 @@ class ParameterExtractorNode(LLMNode):
return [system_prompt_messages, user_prompt_message]
else:
raise ValueError(f"Model mode {model_mode} not support.")
def _get_prompt_engineering_prompt_template(self, node_data: ParameterExtractorNodeData, query: str,
variable_pool: VariablePool,
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000) \
-> list[ChatModelMessage]:
model_mode = ModelMode.value_of(node_data.model.mode)
input_text = query
@@ -620,8 +623,8 @@ class ParameterExtractorNode(LLMNode):
text=COMPLETION_GENERATE_JSON_PROMPT.format(histories=memory_str,
text=input_text,
instruction=instruction)
.replace('{γγγ', '')
.replace('}γγγ', '')
)
else:
raise ValueError(f"Model mode {model_mode} not support.")
@@ -635,7 +638,7 @@ class ParameterExtractorNode(LLMNode):
model_instance, model_config = self._fetch_model_config(node_data.model)
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
raise ValueError("Model is not a Large Language Model")
llm_model = model_instance.model_type_instance
model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
if not model_schema:
@@ -667,7 +670,7 @@ class ParameterExtractorNode(LLMNode):
model_config.model,
model_config.credentials,
prompt_messages
) + 1000  # add 1000 to ensure tool call messages
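# scan the parameter rules for the configured max output tokens, which are reserved from the remaining budget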
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
@@ -680,8 +683,9 @@ class ParameterExtractorNode(LLMNode):
rest_tokens = max(rest_tokens, 0)
return rest_tokens
def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[
ModelInstance, ModelConfigWithCredentialsEntity]:
"""
Fetch model config.
"""
@@ -689,9 +693,10 @@ class ParameterExtractorNode(LLMNode):
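# fetch once via LLMNode and cache on the instance so repeated lookups reuse the result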
self._model_instance, self._model_config = super()._fetch_model_config(node_data_model)
return self._model_instance, self._model_config
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: ParameterExtractorNodeData) -> dict[
str, list[str]]:
"""
Extract variable selector to variable mapping
:param node_data: node data
@@ -708,4 +713,4 @@ class ParameterExtractorNode(LLMNode):
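# map each template variable in the instruction to the selector that resolves it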
for selector in variable_template_parser.extract_variable_selectors():
variable_mapping[selector.variable] = selector.value_selector
return variable_mapping