feat: mypy for all type check (#10921)
@@ -1,7 +1,7 @@
 import json
 import logging
 import re
-from typing import Optional
+from typing import Optional, cast
 
 from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
 from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
@@ -13,6 +13,7 @@ from core.llm_generator.prompts import (
     WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
 )
 from core.model_manager import ModelManager
+from core.model_runtime.entities.llm_entities import LLMResult
 from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
@@ -44,10 +45,13 @@ class LLMGenerator:
         prompts = [UserPromptMessage(content=prompt)]
 
         with measure_time() as timer:
-            response = model_instance.invoke_llm(
-                prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
+            response = cast(
+                LLMResult,
+                model_instance.invoke_llm(
+                    prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
+                ),
             )
-        answer = response.message.content
+        answer = cast(str, response.message.content)
         cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
         if cleaned_answer is None:
             return ""
@@ -94,11 +98,16 @@ class LLMGenerator:
         prompt_messages = [UserPromptMessage(content=prompt)]
 
         try:
-            response = model_instance.invoke_llm(
-                prompt_messages=prompt_messages, model_parameters={"max_tokens": 256, "temperature": 0}, stream=False
+            response = cast(
+                LLMResult,
+                model_instance.invoke_llm(
+                    prompt_messages=prompt_messages,
+                    model_parameters={"max_tokens": 256, "temperature": 0},
+                    stream=False,
+                ),
             )
 
-            questions = output_parser.parse(response.message.content)
+            questions = output_parser.parse(cast(str, response.message.content))
         except InvokeError:
             questions = []
         except Exception as e:
@@ -138,11 +147,14 @@ class LLMGenerator:
         )
 
         try:
-            response = model_instance.invoke_llm(
-                prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+            response = cast(
+                LLMResult,
+                model_instance.invoke_llm(
+                    prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+                ),
             )
 
-            rule_config["prompt"] = response.message.content
+            rule_config["prompt"] = cast(str, response.message.content)
 
         except InvokeError as e:
             error = str(e)
@@ -178,15 +190,18 @@ class LLMGenerator:
         model_instance = model_manager.get_model_instance(
             tenant_id=tenant_id,
             model_type=ModelType.LLM,
-            provider=model_config.get("provider") if model_config else None,
-            model=model_config.get("name") if model_config else None,
+            provider=model_config.get("provider", ""),
+            model=model_config.get("name", ""),
         )
 
         try:
             try:
                 # the first step to generate the task prompt
-                prompt_content = model_instance.invoke_llm(
-                    prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+                prompt_content = cast(
+                    LLMResult,
+                    model_instance.invoke_llm(
+                        prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+                    ),
                 )
             except InvokeError as e:
                 error = str(e)
@@ -195,8 +210,10 @@ class LLMGenerator:
 
                 return rule_config
 
-            rule_config["prompt"] = prompt_content.message.content
+            rule_config["prompt"] = cast(str, prompt_content.message.content)
 
+            if not isinstance(prompt_content.message.content, str):
+                raise NotImplementedError("prompt content is not a string")
             parameter_generate_prompt = parameter_template.format(
                 inputs={
                     "INPUT_TEXT": prompt_content.message.content,
@@ -216,19 +233,25 @@ class LLMGenerator:
             statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
 
             try:
-                parameter_content = model_instance.invoke_llm(
-                    prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False
+                parameter_content = cast(
+                    LLMResult,
+                    model_instance.invoke_llm(
+                        prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False
+                    ),
                 )
-                rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.content)
+                rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
             except InvokeError as e:
                 error = str(e)
                 error_step = "generate variables"
 
             try:
-                statement_content = model_instance.invoke_llm(
-                    prompt_messages=statement_messages, model_parameters=model_parameters, stream=False
+                statement_content = cast(
+                    LLMResult,
+                    model_instance.invoke_llm(
+                        prompt_messages=statement_messages, model_parameters=model_parameters, stream=False
+                    ),
                 )
-                rule_config["opening_statement"] = statement_content.message.content
+                rule_config["opening_statement"] = cast(str, statement_content.message.content)
             except InvokeError as e:
                 error = str(e)
                 error_step = "generate conversation opener"
@@ -267,19 +290,22 @@ class LLMGenerator:
         model_instance = model_manager.get_model_instance(
             tenant_id=tenant_id,
             model_type=ModelType.LLM,
-            provider=model_config.get("provider") if model_config else None,
-            model=model_config.get("name") if model_config else None,
+            provider=model_config.get("provider", ""),
+            model=model_config.get("name", ""),
         )
 
         prompt_messages = [UserPromptMessage(content=prompt)]
         model_parameters = {"max_tokens": max_tokens, "temperature": 0.01}
 
         try:
-            response = model_instance.invoke_llm(
-                prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+            response = cast(
+                LLMResult,
+                model_instance.invoke_llm(
+                    prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+                ),
             )
 
-            generated_code = response.message.content
+            generated_code = cast(str, response.message.content)
             return {"code": generated_code, "language": code_language, "error": ""}
 
         except InvokeError as e:
@@ -303,9 +329,14 @@ class LLMGenerator:
 
         prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]
 
-        response = model_instance.invoke_llm(
-            prompt_messages=prompt_messages, model_parameters={"temperature": 0.01, "max_tokens": 2000}, stream=False
+        response = cast(
+            LLMResult,
+            model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters={"temperature": 0.01, "max_tokens": 2000},
+                stream=False,
+            ),
         )
 
-        answer = response.message.content
+        answer = cast(str, response.message.content)
         return answer.strip()
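
The change repeated throughout this diff wraps every non-streaming invoke_llm call in cast(LLMResult, ...). A minimal sketch of why mypy needs the hint, assuming invoke_llm is annotated to return a union of LLMResult and a streaming generator (the stub below is illustrative, not the real signature):

from collections.abc import Generator
from typing import Union, cast


class LLMResult:
    """Stand-in for core.model_runtime.entities.llm_entities.LLMResult."""


def invoke_llm(stream: bool = True) -> Union[LLMResult, Generator[str, None, None]]:
    # Illustrative stub: without overloads keyed on `stream`, mypy reports
    # the full union for every call, even when stream=False is passed.
    if stream:
        return (chunk for chunk in ("partial", "chunks"))
    return LLMResult()


# cast() records the runtime invariant that a non-streaming call yields an
# LLMResult; it is a type-checker assertion only and does nothing at runtime.
response = cast(LLMResult, invoke_llm(stream=False))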
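
Two narrowing styles sit side by side above: cast(str, ...) on message.content, and an isinstance guard that raises NotImplementedError. A short sketch of the trade-off, assuming message.content is typed as a union of plain text and structured content parts:

from typing import Union, cast

content: Union[str, list[dict]] = "generated text"  # assumed union shape

# cast() is erased at runtime: it silences mypy but performs no check, so a
# list slipping through here would only fail later, e.g. on .strip().
answer = cast(str, content)

# isinstance() narrows the type for mypy *and* fails loudly on the unexpected
# branch, mirroring the guard the diff adds before building the parameter prompt.
if not isinstance(content, str):
    raise NotImplementedError("prompt content is not a string")
print(content.strip())  # mypy treats `content` as str from here on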
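
The get_model_instance call sites also trade model_config.get("provider") if model_config else None for model_config.get("provider", ""). A brief sketch of the typing effect, with hypothetical values:

from typing import Optional

model_config: dict[str, str] = {"provider": "openai", "name": "gpt-4o"}  # hypothetical values

# Old style: Optional[str], which a parameter annotated as plain str rejects.
provider_old: Optional[str] = model_config.get("provider") if model_config else None

# New style: the "" default keeps the value a str, so mypy accepts it wherever
# str is expected; it does assume model_config itself is always a dict.
provider_new: str = model_config.get("provider", "")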