chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -11,6 +11,7 @@ class ModelConfig(BaseModel):
|
||||
"""
|
||||
Model Config.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
name: str
|
||||
mode: str
|
||||
@@ -21,6 +22,7 @@ class ContextConfig(BaseModel):
|
||||
"""
|
||||
Context Config.
|
||||
"""
|
||||
|
||||
enabled: bool
|
||||
variable_selector: Optional[list[str]] = None
|
||||
|
||||
@@ -29,37 +31,47 @@ class VisionConfig(BaseModel):
|
||||
"""
|
||||
Vision Config.
|
||||
"""
|
||||
|
||||
class Configs(BaseModel):
|
||||
"""
|
||||
Configs.
|
||||
"""
|
||||
detail: Literal['low', 'high']
|
||||
|
||||
detail: Literal["low", "high"]
|
||||
|
||||
enabled: bool
|
||||
configs: Optional[Configs] = None
|
||||
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
"""
|
||||
Prompt Config.
|
||||
"""
|
||||
|
||||
jinja2_variables: Optional[list[VariableSelector]] = None
|
||||
|
||||
|
||||
class LLMNodeChatModelMessage(ChatModelMessage):
|
||||
"""
|
||||
LLM Node Chat Model Message.
|
||||
"""
|
||||
|
||||
jinja2_text: Optional[str] = None
|
||||
|
||||
|
||||
class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
|
||||
"""
|
||||
LLM Node Chat Model Prompt Template.
|
||||
"""
|
||||
|
||||
jinja2_text: Optional[str] = None
|
||||
|
||||
|
||||
class LLMNodeData(BaseNodeData):
|
||||
"""
|
||||
LLM Node Data.
|
||||
"""
|
||||
|
||||
model: ModelConfig
|
||||
prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate]
|
||||
prompt_config: Optional[PromptConfig] = None
|
||||
|
@@ -45,11 +45,11 @@ if TYPE_CHECKING:
|
||||
from core.file.file_obj import FileVar
|
||||
|
||||
|
||||
|
||||
class ModelInvokeCompleted(BaseModel):
|
||||
"""
|
||||
Model invoke completed
|
||||
"""
|
||||
|
||||
text: str
|
||||
usage: LLMUsage
|
||||
finish_reason: Optional[str] = None
|
||||
@@ -89,7 +89,7 @@ class LLMNode(BaseNode):
|
||||
files = self._fetch_files(node_data, variable_pool)
|
||||
|
||||
if files:
|
||||
node_inputs['#files#'] = [file.to_dict() for file in files]
|
||||
node_inputs["#files#"] = [file.to_dict() for file in files]
|
||||
|
||||
# fetch context value
|
||||
generator = self._fetch_context(node_data, variable_pool)
|
||||
@@ -100,7 +100,7 @@ class LLMNode(BaseNode):
|
||||
yield event
|
||||
|
||||
if context:
|
||||
node_inputs['#context#'] = context # type: ignore
|
||||
node_inputs["#context#"] = context # type: ignore
|
||||
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(node_data.model)
|
||||
@@ -111,24 +111,22 @@ class LLMNode(BaseNode):
|
||||
# fetch prompt messages
|
||||
prompt_messages, stop = self._fetch_prompt_messages(
|
||||
node_data=node_data,
|
||||
query=variable_pool.get_any(['sys', SystemVariableKey.QUERY.value])
|
||||
if node_data.memory else None,
|
||||
query=variable_pool.get_any(["sys", SystemVariableKey.QUERY.value]) if node_data.memory else None,
|
||||
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
process_data = {
|
||||
'model_mode': model_config.mode,
|
||||
'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
model_mode=model_config.mode,
|
||||
prompt_messages=prompt_messages
|
||||
"model_mode": model_config.mode,
|
||||
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
model_mode=model_config.mode, prompt_messages=prompt_messages
|
||||
),
|
||||
'model_provider': model_config.provider,
|
||||
'model_name': model_config.model,
|
||||
"model_provider": model_config.provider,
|
||||
"model_name": model_config.model,
|
||||
}
|
||||
|
||||
# handle invoke result
|
||||
@@ -136,10 +134,10 @@ class LLMNode(BaseNode):
|
||||
node_data_model=node_data.model,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
result_text = ''
|
||||
result_text = ""
|
||||
usage = LLMUsage.empty_usage()
|
||||
finish_reason = None
|
||||
for event in generator:
|
||||
@@ -156,16 +154,12 @@ class LLMNode(BaseNode):
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
inputs=node_inputs,
|
||||
process_data=process_data
|
||||
process_data=process_data,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
outputs = {
|
||||
'text': result_text,
|
||||
'usage': jsonable_encoder(usage),
|
||||
'finish_reason': finish_reason
|
||||
}
|
||||
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
@@ -176,17 +170,19 @@ class LLMNode(BaseNode):
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
NodeRunMetadataKey.CURRENCY: usage.currency
|
||||
NodeRunMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage
|
||||
llm_usage=usage,
|
||||
)
|
||||
)
|
||||
|
||||
def _invoke_llm(self, node_data_model: ModelConfig,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: list[PromptMessage],
|
||||
stop: Optional[list[str]] = None) \
|
||||
-> Generator[RunEvent | ModelInvokeCompleted, None, None]:
|
||||
def _invoke_llm(
|
||||
self,
|
||||
node_data_model: ModelConfig,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: list[PromptMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
) -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
|
||||
"""
|
||||
Invoke large language model
|
||||
:param node_data_model: node data model
|
||||
@@ -206,9 +202,7 @@ class LLMNode(BaseNode):
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
generator = self._handle_invoke_result(
|
||||
invoke_result=invoke_result
|
||||
)
|
||||
generator = self._handle_invoke_result(invoke_result=invoke_result)
|
||||
|
||||
usage = LLMUsage.empty_usage()
|
||||
for event in generator:
|
||||
@@ -219,8 +213,9 @@ class LLMNode(BaseNode):
|
||||
# deduct quota
|
||||
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
|
||||
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
|
||||
-> Generator[RunEvent | ModelInvokeCompleted, None, None]:
|
||||
def _handle_invoke_result(
|
||||
self, invoke_result: LLMResult | Generator
|
||||
) -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
|
||||
"""
|
||||
Handle invoke result
|
||||
:param invoke_result: invoke result
|
||||
@@ -231,17 +226,14 @@ class LLMNode(BaseNode):
|
||||
|
||||
model = None
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
full_text = ''
|
||||
full_text = ""
|
||||
usage = None
|
||||
finish_reason = None
|
||||
for result in invoke_result:
|
||||
text = result.delta.message.content
|
||||
full_text += text
|
||||
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=text,
|
||||
from_variable_selector=[self.node_id, 'text']
|
||||
)
|
||||
yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"])
|
||||
|
||||
if not model:
|
||||
model = result.model
|
||||
@@ -258,15 +250,11 @@ class LLMNode(BaseNode):
|
||||
if not usage:
|
||||
usage = LLMUsage.empty_usage()
|
||||
|
||||
yield ModelInvokeCompleted(
|
||||
text=full_text,
|
||||
usage=usage,
|
||||
finish_reason=finish_reason
|
||||
)
|
||||
yield ModelInvokeCompleted(text=full_text, usage=usage, finish_reason=finish_reason)
|
||||
|
||||
def _transform_chat_messages(self,
|
||||
messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
|
||||
def _transform_chat_messages(
|
||||
self, messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
|
||||
"""
|
||||
Transform chat messages
|
||||
|
||||
@@ -275,13 +263,13 @@ class LLMNode(BaseNode):
|
||||
"""
|
||||
|
||||
if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
|
||||
if messages.edition_type == 'jinja2' and messages.jinja2_text:
|
||||
if messages.edition_type == "jinja2" and messages.jinja2_text:
|
||||
messages.text = messages.jinja2_text
|
||||
|
||||
return messages
|
||||
|
||||
for message in messages:
|
||||
if message.edition_type == 'jinja2' and message.jinja2_text:
|
||||
if message.edition_type == "jinja2" and message.jinja2_text:
|
||||
message.text = message.jinja2_text
|
||||
|
||||
return messages
|
||||
@@ -300,17 +288,15 @@ class LLMNode(BaseNode):
|
||||
|
||||
for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
||||
variable = variable_selector.variable
|
||||
value = variable_pool.get_any(
|
||||
variable_selector.value_selector
|
||||
)
|
||||
value = variable_pool.get_any(variable_selector.value_selector)
|
||||
|
||||
def parse_dict(d: dict) -> str:
|
||||
"""
|
||||
Parse dict into string
|
||||
"""
|
||||
# check if it's a context structure
|
||||
if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
|
||||
return d['content']
|
||||
if "metadata" in d and "_source" in d["metadata"] and "content" in d:
|
||||
return d["content"]
|
||||
|
||||
# else, parse the dict
|
||||
try:
|
||||
@@ -321,7 +307,7 @@ class LLMNode(BaseNode):
|
||||
if isinstance(value, str):
|
||||
value = value
|
||||
elif isinstance(value, list):
|
||||
result = ''
|
||||
result = ""
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
result += parse_dict(item)
|
||||
@@ -331,7 +317,7 @@ class LLMNode(BaseNode):
|
||||
result += str(item)
|
||||
else:
|
||||
result += str(item)
|
||||
result += '\n'
|
||||
result += "\n"
|
||||
value = result.strip()
|
||||
elif isinstance(value, dict):
|
||||
value = parse_dict(value)
|
||||
@@ -366,18 +352,19 @@ class LLMNode(BaseNode):
|
||||
for variable_selector in variable_selectors:
|
||||
variable_value = variable_pool.get_any(variable_selector.value_selector)
|
||||
if variable_value is None:
|
||||
raise ValueError(f'Variable {variable_selector.variable} not found')
|
||||
raise ValueError(f"Variable {variable_selector.variable} not found")
|
||||
|
||||
inputs[variable_selector.variable] = variable_value
|
||||
|
||||
memory = node_data.memory
|
||||
if memory and memory.query_prompt_template:
|
||||
query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
|
||||
.extract_variable_selectors())
|
||||
query_variable_selectors = VariableTemplateParser(
|
||||
template=memory.query_prompt_template
|
||||
).extract_variable_selectors()
|
||||
for variable_selector in query_variable_selectors:
|
||||
variable_value = variable_pool.get_any(variable_selector.value_selector)
|
||||
if variable_value is None:
|
||||
raise ValueError(f'Variable {variable_selector.variable} not found')
|
||||
raise ValueError(f"Variable {variable_selector.variable} not found")
|
||||
|
||||
inputs[variable_selector.variable] = variable_value
|
||||
|
||||
@@ -393,7 +380,7 @@ class LLMNode(BaseNode):
|
||||
if not node_data.vision.enabled:
|
||||
return []
|
||||
|
||||
files = variable_pool.get_any(['sys', SystemVariableKey.FILES.value])
|
||||
files = variable_pool.get_any(["sys", SystemVariableKey.FILES.value])
|
||||
if not files:
|
||||
return []
|
||||
|
||||
@@ -415,29 +402,25 @@ class LLMNode(BaseNode):
|
||||
context_value = variable_pool.get_any(node_data.context.variable_selector)
|
||||
if context_value:
|
||||
if isinstance(context_value, str):
|
||||
yield RunRetrieverResourceEvent(
|
||||
retriever_resources=[],
|
||||
context=context_value
|
||||
)
|
||||
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value)
|
||||
elif isinstance(context_value, list):
|
||||
context_str = ''
|
||||
context_str = ""
|
||||
original_retriever_resource = []
|
||||
for item in context_value:
|
||||
if isinstance(item, str):
|
||||
context_str += item + '\n'
|
||||
context_str += item + "\n"
|
||||
else:
|
||||
if 'content' not in item:
|
||||
raise ValueError(f'Invalid context structure: {item}')
|
||||
if "content" not in item:
|
||||
raise ValueError(f"Invalid context structure: {item}")
|
||||
|
||||
context_str += item['content'] + '\n'
|
||||
context_str += item["content"] + "\n"
|
||||
|
||||
retriever_resource = self._convert_to_original_retriever_resource(item)
|
||||
if retriever_resource:
|
||||
original_retriever_resource.append(retriever_resource)
|
||||
|
||||
yield RunRetrieverResourceEvent(
|
||||
retriever_resources=original_retriever_resource,
|
||||
context=context_str.strip()
|
||||
retriever_resources=original_retriever_resource, context=context_str.strip()
|
||||
)
|
||||
|
||||
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
|
||||
@@ -446,34 +429,38 @@ class LLMNode(BaseNode):
|
||||
:param context_dict: context dict
|
||||
:return:
|
||||
"""
|
||||
if ('metadata' in context_dict and '_source' in context_dict['metadata']
|
||||
and context_dict['metadata']['_source'] == 'knowledge'):
|
||||
metadata = context_dict.get('metadata', {})
|
||||
if (
|
||||
"metadata" in context_dict
|
||||
and "_source" in context_dict["metadata"]
|
||||
and context_dict["metadata"]["_source"] == "knowledge"
|
||||
):
|
||||
metadata = context_dict.get("metadata", {})
|
||||
|
||||
source = {
|
||||
'position': metadata.get('position'),
|
||||
'dataset_id': metadata.get('dataset_id'),
|
||||
'dataset_name': metadata.get('dataset_name'),
|
||||
'document_id': metadata.get('document_id'),
|
||||
'document_name': metadata.get('document_name'),
|
||||
'data_source_type': metadata.get('document_data_source_type'),
|
||||
'segment_id': metadata.get('segment_id'),
|
||||
'retriever_from': metadata.get('retriever_from'),
|
||||
'score': metadata.get('score'),
|
||||
'hit_count': metadata.get('segment_hit_count'),
|
||||
'word_count': metadata.get('segment_word_count'),
|
||||
'segment_position': metadata.get('segment_position'),
|
||||
'index_node_hash': metadata.get('segment_index_node_hash'),
|
||||
'content': context_dict.get('content'),
|
||||
'page': metadata.get('page'),
|
||||
"position": metadata.get("position"),
|
||||
"dataset_id": metadata.get("dataset_id"),
|
||||
"dataset_name": metadata.get("dataset_name"),
|
||||
"document_id": metadata.get("document_id"),
|
||||
"document_name": metadata.get("document_name"),
|
||||
"data_source_type": metadata.get("document_data_source_type"),
|
||||
"segment_id": metadata.get("segment_id"),
|
||||
"retriever_from": metadata.get("retriever_from"),
|
||||
"score": metadata.get("score"),
|
||||
"hit_count": metadata.get("segment_hit_count"),
|
||||
"word_count": metadata.get("segment_word_count"),
|
||||
"segment_position": metadata.get("segment_position"),
|
||||
"index_node_hash": metadata.get("segment_index_node_hash"),
|
||||
"content": context_dict.get("content"),
|
||||
"page": metadata.get("page"),
|
||||
}
|
||||
|
||||
return source
|
||||
|
||||
return None
|
||||
|
||||
def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[
|
||||
ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
def _fetch_model_config(
|
||||
self, node_data_model: ModelConfig
|
||||
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
"""
|
||||
Fetch model config
|
||||
:param node_data_model: node data model
|
||||
@@ -484,10 +471,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=provider_name,
|
||||
model=model_name
|
||||
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
|
||||
)
|
||||
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
@@ -498,8 +482,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
# check model
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=model_name,
|
||||
model_type=ModelType.LLM
|
||||
model=model_name, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if provider_model is None:
|
||||
@@ -515,19 +498,16 @@ class LLMNode(BaseNode):
|
||||
# model config
|
||||
completion_params = node_data_model.completion_params
|
||||
stop = []
|
||||
if 'stop' in completion_params:
|
||||
stop = completion_params['stop']
|
||||
del completion_params['stop']
|
||||
if "stop" in completion_params:
|
||||
stop = completion_params["stop"]
|
||||
del completion_params["stop"]
|
||||
|
||||
# get model mode
|
||||
model_mode = node_data_model.mode
|
||||
if not model_mode:
|
||||
raise ValueError("LLM mode is required.")
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
model_name,
|
||||
model_credentials
|
||||
)
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
@@ -543,9 +523,9 @@ class LLMNode(BaseNode):
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
def _fetch_memory(self, node_data_memory: Optional[MemoryConfig],
|
||||
variable_pool: VariablePool,
|
||||
model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
|
||||
def _fetch_memory(
|
||||
self, node_data_memory: Optional[MemoryConfig], variable_pool: VariablePool, model_instance: ModelInstance
|
||||
) -> Optional[TokenBufferMemory]:
|
||||
"""
|
||||
Fetch memory
|
||||
:param node_data_memory: node data memory
|
||||
@@ -556,35 +536,35 @@ class LLMNode(BaseNode):
|
||||
return None
|
||||
|
||||
# get conversation id
|
||||
conversation_id = variable_pool.get_any(['sys', SystemVariableKey.CONVERSATION_ID.value])
|
||||
conversation_id = variable_pool.get_any(["sys", SystemVariableKey.CONVERSATION_ID.value])
|
||||
if conversation_id is None:
|
||||
return None
|
||||
|
||||
# get conversation
|
||||
conversation = db.session.query(Conversation).filter(
|
||||
Conversation.app_id == self.app_id,
|
||||
Conversation.id == conversation_id
|
||||
).first()
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
memory = TokenBufferMemory(
|
||||
conversation=conversation,
|
||||
model_instance=model_instance
|
||||
)
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
return memory
|
||||
|
||||
def _fetch_prompt_messages(self, node_data: LLMNodeData,
|
||||
query: Optional[str],
|
||||
query_prompt_template: Optional[str],
|
||||
inputs: dict[str, str],
|
||||
files: list["FileVar"],
|
||||
context: Optional[str],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity) \
|
||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
def _fetch_prompt_messages(
|
||||
self,
|
||||
node_data: LLMNodeData,
|
||||
query: Optional[str],
|
||||
query_prompt_template: Optional[str],
|
||||
inputs: dict[str, str],
|
||||
files: list["FileVar"],
|
||||
context: Optional[str],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
) -> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
"""
|
||||
Fetch prompt messages
|
||||
:param node_data: node data
|
||||
@@ -601,7 +581,7 @@ class LLMNode(BaseNode):
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=node_data.prompt_template,
|
||||
inputs=inputs,
|
||||
query=query if query else '',
|
||||
query=query if query else "",
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=node_data.memory,
|
||||
@@ -621,8 +601,11 @@ class LLMNode(BaseNode):
|
||||
if not isinstance(prompt_message.content, str):
|
||||
prompt_message_content = []
|
||||
for content_item in prompt_message.content:
|
||||
if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(
|
||||
content_item, ImagePromptMessageContent):
|
||||
if (
|
||||
vision_enabled
|
||||
and content_item.type == PromptMessageContentType.IMAGE
|
||||
and isinstance(content_item, ImagePromptMessageContent)
|
||||
):
|
||||
# Override vision config if LLM node has vision config
|
||||
if vision_detail:
|
||||
content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail)
|
||||
@@ -632,15 +615,18 @@ class LLMNode(BaseNode):
|
||||
|
||||
if len(prompt_message_content) > 1:
|
||||
prompt_message.content = prompt_message_content
|
||||
elif (len(prompt_message_content) == 1
|
||||
and prompt_message_content[0].type == PromptMessageContentType.TEXT):
|
||||
elif (
|
||||
len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT
|
||||
):
|
||||
prompt_message.content = prompt_message_content[0].data
|
||||
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
|
||||
if not filtered_prompt_messages:
|
||||
raise ValueError("No prompt found in the LLM configuration. "
|
||||
"Please ensure a prompt is properly configured before proceeding.")
|
||||
raise ValueError(
|
||||
"No prompt found in the LLM configuration. "
|
||||
"Please ensure a prompt is properly configured before proceeding."
|
||||
)
|
||||
|
||||
return filtered_prompt_messages, stop
|
||||
|
||||
@@ -678,7 +664,7 @@ class LLMNode(BaseNode):
|
||||
elif quota_unit == QuotaUnit.CREDITS:
|
||||
used_quota = 1
|
||||
|
||||
if 'gpt-4' in model_instance.model:
|
||||
if "gpt-4" in model_instance.model:
|
||||
used_quota = 20
|
||||
else:
|
||||
used_quota = 1
|
||||
@@ -689,16 +675,13 @@ class LLMNode(BaseNode):
|
||||
Provider.provider_name == model_instance.provider,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used
|
||||
).update({'quota_used': Provider.quota_used + used_quota})
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
).update({"quota_used": Provider.quota_used + used_quota})
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: LLMNodeData
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: LLMNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -712,11 +695,11 @@ class LLMNode(BaseNode):
|
||||
variable_selectors = []
|
||||
if isinstance(prompt_template, list):
|
||||
for prompt in prompt_template:
|
||||
if prompt.edition_type != 'jinja2':
|
||||
if prompt.edition_type != "jinja2":
|
||||
variable_template_parser = VariableTemplateParser(template=prompt.text)
|
||||
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
||||
else:
|
||||
if prompt_template.edition_type != 'jinja2':
|
||||
if prompt_template.edition_type != "jinja2":
|
||||
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
@@ -726,39 +709,38 @@ class LLMNode(BaseNode):
|
||||
|
||||
memory = node_data.memory
|
||||
if memory and memory.query_prompt_template:
|
||||
query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
|
||||
.extract_variable_selectors())
|
||||
query_variable_selectors = VariableTemplateParser(
|
||||
template=memory.query_prompt_template
|
||||
).extract_variable_selectors()
|
||||
for variable_selector in query_variable_selectors:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
if node_data.context.enabled:
|
||||
variable_mapping['#context#'] = node_data.context.variable_selector
|
||||
variable_mapping["#context#"] = node_data.context.variable_selector
|
||||
|
||||
if node_data.vision.enabled:
|
||||
variable_mapping['#files#'] = ['sys', SystemVariableKey.FILES.value]
|
||||
variable_mapping["#files#"] = ["sys", SystemVariableKey.FILES.value]
|
||||
|
||||
if node_data.memory:
|
||||
variable_mapping['#sys.query#'] = ['sys', SystemVariableKey.QUERY.value]
|
||||
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value]
|
||||
|
||||
if node_data.prompt_config:
|
||||
enable_jinja = False
|
||||
|
||||
if isinstance(prompt_template, list):
|
||||
for prompt in prompt_template:
|
||||
if prompt.edition_type == 'jinja2':
|
||||
if prompt.edition_type == "jinja2":
|
||||
enable_jinja = True
|
||||
break
|
||||
else:
|
||||
if prompt_template.edition_type == 'jinja2':
|
||||
if prompt_template.edition_type == "jinja2":
|
||||
enable_jinja = True
|
||||
|
||||
if enable_jinja:
|
||||
for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
variable_mapping = {
|
||||
node_id + '.' + key: value for key, value in variable_mapping.items()
|
||||
}
|
||||
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
|
||||
|
||||
return variable_mapping
|
||||
|
||||
@@ -775,26 +757,19 @@ class LLMNode(BaseNode):
|
||||
"prompt_templates": {
|
||||
"chat_model": {
|
||||
"prompts": [
|
||||
{
|
||||
"role": "system",
|
||||
"text": "You are a helpful AI assistant.",
|
||||
"edition_type": "basic"
|
||||
}
|
||||
{"role": "system", "text": "You are a helpful AI assistant.", "edition_type": "basic"}
|
||||
]
|
||||
},
|
||||
"completion_model": {
|
||||
"conversation_histories_role": {
|
||||
"user_prefix": "Human",
|
||||
"assistant_prefix": "Assistant"
|
||||
},
|
||||
"conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
|
||||
"prompt": {
|
||||
"text": "Here is the chat histories between human and assistant, inside "
|
||||
"<histories></histories> XML tags.\n\n<histories>\n{{"
|
||||
"#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
|
||||
"edition_type": "basic"
|
||||
"<histories></histories> XML tags.\n\n<histories>\n{{"
|
||||
"#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
|
||||
"edition_type": "basic",
|
||||
},
|
||||
"stop": ["Human:"]
|
||||
}
|
||||
"stop": ["Human:"],
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
Reference in New Issue
Block a user