feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -1,7 +1,7 @@
import json
import logging
import re
from typing import Optional
from typing import Optional, cast
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
@@ -13,6 +13,7 @@ from core.llm_generator.prompts import (
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
)
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
@@ -44,10 +45,13 @@ class LLMGenerator:
prompts = [UserPromptMessage(content=prompt)]
with measure_time() as timer:
response = model_instance.invoke_llm(
prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
),
)
answer = response.message.content
answer = cast(str, response.message.content)
cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
if cleaned_answer is None:
return ""
@@ -94,11 +98,16 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt)]
try:
response = model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters={"max_tokens": 256, "temperature": 0}, stream=False
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters={"max_tokens": 256, "temperature": 0},
stream=False,
),
)
questions = output_parser.parse(response.message.content)
questions = output_parser.parse(cast(str, response.message.content))
except InvokeError:
questions = []
except Exception as e:
@@ -138,11 +147,14 @@ class LLMGenerator:
)
try:
response = model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
),
)
rule_config["prompt"] = response.message.content
rule_config["prompt"] = cast(str, response.message.content)
except InvokeError as e:
error = str(e)
@@ -178,15 +190,18 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.get("provider") if model_config else None,
model=model_config.get("name") if model_config else None,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
try:
try:
# the first step to generate the task prompt
prompt_content = model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
prompt_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
),
)
except InvokeError as e:
error = str(e)
@@ -195,8 +210,10 @@ class LLMGenerator:
return rule_config
rule_config["prompt"] = prompt_content.message.content
rule_config["prompt"] = cast(str, prompt_content.message.content)
if not isinstance(prompt_content.message.content, str):
raise NotImplementedError("prompt content is not a string")
parameter_generate_prompt = parameter_template.format(
inputs={
"INPUT_TEXT": prompt_content.message.content,
@@ -216,19 +233,25 @@ class LLMGenerator:
statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
try:
parameter_content = model_instance.invoke_llm(
prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False
parameter_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False
),
)
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.content)
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
except InvokeError as e:
error = str(e)
error_step = "generate variables"
try:
statement_content = model_instance.invoke_llm(
prompt_messages=statement_messages, model_parameters=model_parameters, stream=False
statement_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=statement_messages, model_parameters=model_parameters, stream=False
),
)
rule_config["opening_statement"] = statement_content.message.content
rule_config["opening_statement"] = cast(str, statement_content.message.content)
except InvokeError as e:
error = str(e)
error_step = "generate conversation opener"
@@ -267,19 +290,22 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.get("provider") if model_config else None,
model=model_config.get("name") if model_config else None,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
prompt_messages = [UserPromptMessage(content=prompt)]
model_parameters = {"max_tokens": max_tokens, "temperature": 0.01}
try:
response = model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
),
)
generated_code = response.message.content
generated_code = cast(str, response.message.content)
return {"code": generated_code, "language": code_language, "error": ""}
except InvokeError as e:
@@ -303,9 +329,14 @@ class LLMGenerator:
prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]
response = model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters={"temperature": 0.01, "max_tokens": 2000}, stream=False
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters={"temperature": 0.01, "max_tokens": 2000},
stream=False,
),
)
answer = response.message.content
answer = cast(str, response.message.content)
return answer.strip()