improve: mordernizing validation by migrating pydantic from 1.x to 2.x (#4592)

This commit is contained in:
Bowen Liang
2024-06-14 01:05:37 +08:00
committed by GitHub
parent e8afc416dd
commit f976740b57
87 changed files with 697 additions and 300 deletions

View File

@@ -14,7 +14,7 @@ class CodeNodeData(BaseNodeData):
"""
class Output(BaseModel):
type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]']
children: Optional[dict[str, 'Output']]
children: Optional[dict[str, 'Output']] = None
variables: list[VariableSelector]
code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]

View File

@@ -1,7 +1,7 @@
import os
from typing import Literal, Optional, Union
from pydantic import BaseModel, validator
from pydantic import BaseModel, field_validator
from core.workflow.entities.base_node_data_entities import BaseNodeData
@@ -14,15 +14,18 @@ class HttpRequestNodeData(BaseNodeData):
Code Node Data.
"""
class Authorization(BaseModel):
# TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
class Config(BaseModel):
type: Literal[None, 'basic', 'bearer', 'custom']
api_key: Union[None, str]
header: Union[None, str]
api_key: Union[None, str] = None
header: Union[None, str] = None
type: Literal['no-auth', 'api-key']
config: Optional[Config]
@validator('config', always=True, pre=True)
@classmethod
@field_validator('config', mode='before')
def check_config(cls, v, values):
"""
Check config, if type is no-auth, config should be None, otherwise it should be a dict.
@@ -37,7 +40,7 @@ class HttpRequestNodeData(BaseNodeData):
class Body(BaseModel):
type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw-text', 'json']
data: Union[None, str]
data: Union[None, str] = None
class Timeout(BaseModel):
connect: Optional[int] = MAX_CONNECT_TIMEOUT
@@ -50,5 +53,5 @@ class HttpRequestNodeData(BaseNodeData):
headers: str
params: str
body: Optional[Body]
timeout: Optional[Timeout]
timeout: Optional[Timeout] = None
mask_authorization_header: Optional[bool] = True

View File

@@ -39,7 +39,7 @@ class HttpRequestNode(BaseNode):
"type": "none"
},
"timeout": {
**HTTP_REQUEST_DEFAULT_TIMEOUT.dict(),
**HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(),
"max_connect_timeout": MAX_CONNECT_TIMEOUT,
"max_read_timeout": MAX_READ_TIMEOUT,
"max_write_timeout": MAX_WRITE_TIMEOUT,

View File

@@ -7,7 +7,7 @@ class IterationNodeData(BaseIterationNodeData):
"""
Iteration Node Data.
"""
parent_loop_id: Optional[str] # redundant field, not used currently
parent_loop_id: Optional[str] = None # redundant field, not used currently
iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector

View File

@@ -18,7 +18,7 @@ class MultipleRetrievalConfig(BaseModel):
Multiple Retrieval Config.
"""
top_k: int
score_threshold: Optional[float]
score_threshold: Optional[float] = None
reranking_model: RerankingModelConfig
@@ -47,5 +47,5 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
query_variable_selector: list[str]
dataset_ids: list[str]
retrieval_mode: Literal['single', 'multiple']
multiple_retrieval_config: Optional[MultipleRetrievalConfig]
single_retrieval_config: Optional[SingleRetrievalConfig]
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
single_retrieval_config: Optional[SingleRetrievalConfig] = None

View File

@@ -1,6 +1,6 @@
from typing import Any, Literal, Optional
from pydantic import BaseModel, validator
from pydantic import BaseModel, field_validator
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
@@ -21,12 +21,13 @@ class ParameterConfig(BaseModel):
"""
name: str
type: Literal['string', 'number', 'bool', 'select', 'array[string]', 'array[number]', 'array[object]']
options: Optional[list[str]]
options: Optional[list[str]] = None
description: str
required: bool
@validator('name', pre=True, always=True)
def validate_name(cls, value):
@classmethod
@field_validator('name', mode='before')
def validate_name(cls, value) -> str:
if not value:
raise ValueError('Parameter name is required')
if value in ['__reason', '__is_success']:
@@ -40,12 +41,13 @@ class ParameterExtractorNodeData(BaseNodeData):
model: ModelConfig
query: list[str]
parameters: list[ParameterConfig]
instruction: Optional[str]
memory: Optional[MemoryConfig]
instruction: Optional[str] = None
memory: Optional[MemoryConfig] = None
reasoning_mode: Literal['function_call', 'prompt']
@validator('reasoning_mode', pre=True, always=True)
def set_reasoning_mode(cls, v):
@classmethod
@field_validator('reasoning_mode', mode='before')
def set_reasoning_mode(cls, v) -> str:
return v or 'function_call'
def get_parameter_json_schema(self) -> dict:

View File

@@ -32,5 +32,5 @@ class QuestionClassifierNodeData(BaseNodeData):
type: str = 'question-classifier'
model: ModelConfig
classes: list[ClassConfig]
instruction: Optional[str]
memory: Optional[MemoryConfig]
instruction: Optional[str] = None
memory: Optional[MemoryConfig] = None

View File

@@ -1,6 +1,7 @@
from typing import Any, Literal, Union
from pydantic import BaseModel, validator
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.workflow.entities.base_node_data_entities import BaseNodeData
@@ -13,13 +14,14 @@ class ToolEntity(BaseModel):
tool_label: str # redundancy
tool_configurations: dict[str, Any]
@validator('tool_configurations', pre=True, always=True)
def validate_tool_configurations(cls, value, values):
@classmethod
@field_validator('tool_configurations', mode='before')
def validate_tool_configurations(cls, value, values: ValidationInfo) -> dict[str, Any]:
if not isinstance(value, dict):
raise ValueError('tool_configurations must be a dictionary')
for key in values.get('tool_configurations', {}).keys():
value = values.get('tool_configurations', {}).get(key)
for key in values.data.get('tool_configurations', {}).keys():
value = values.data.get('tool_configurations', {}).get(key)
if not isinstance(value, str | int | float | bool):
raise ValueError(f'{key} must be a string')
@@ -30,10 +32,11 @@ class ToolNodeData(BaseNodeData, ToolEntity):
value: Union[Any, list[str]]
type: Literal['mixed', 'variable', 'constant']
@validator('type', pre=True, always=True)
def check_type(cls, value, values):
@classmethod
@field_validator('type', mode='before')
def check_type(cls, value, validation_info: ValidationInfo):
typ = value
value = values.get('value')
value = validation_info.data.get('value')
if typ == 'mixed' and not isinstance(value, str):
raise ValueError('value must be a string')
elif typ == 'variable':
@@ -45,7 +48,7 @@ class ToolNodeData(BaseNodeData, ToolEntity):
elif typ == 'constant' and not isinstance(value, str | int | float | bool):
raise ValueError('value must be a string, int, float, or bool')
return typ
"""
Tool Node Schema
"""

View File

@@ -30,4 +30,4 @@ class VariableAssignerNodeData(BaseNodeData):
type: str = 'variable-assigner'
output_type: str
variables: list[list[str]]
advanced_settings: Optional[AdvancedSettings]
advanced_settings: Optional[AdvancedSettings] = None

View File

@@ -592,7 +592,7 @@ class WorkflowEngineManager:
node_data=current_iteration_node.node_data,
inputs=workflow_run_state.current_iteration_state.inputs,
predecessor_node_id=predecessor_node_id,
metadata=workflow_run_state.current_iteration_state.metadata.dict()
metadata=workflow_run_state.current_iteration_state.metadata.model_dump()
)
# add steps