improve: mordernizing validation by migrating pydantic from 1.x to 2.x (#4592)
This commit is contained in:
@@ -14,7 +14,7 @@ class CodeNodeData(BaseNodeData):
|
||||
"""
|
||||
class Output(BaseModel):
|
||||
type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]']
|
||||
children: Optional[dict[str, 'Output']]
|
||||
children: Optional[dict[str, 'Output']] = None
|
||||
|
||||
variables: list[VariableSelector]
|
||||
code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
@@ -14,15 +14,18 @@ class HttpRequestNodeData(BaseNodeData):
|
||||
Code Node Data.
|
||||
"""
|
||||
class Authorization(BaseModel):
|
||||
# TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually.
|
||||
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
|
||||
class Config(BaseModel):
|
||||
type: Literal[None, 'basic', 'bearer', 'custom']
|
||||
api_key: Union[None, str]
|
||||
header: Union[None, str]
|
||||
api_key: Union[None, str] = None
|
||||
header: Union[None, str] = None
|
||||
|
||||
type: Literal['no-auth', 'api-key']
|
||||
config: Optional[Config]
|
||||
|
||||
@validator('config', always=True, pre=True)
|
||||
@classmethod
|
||||
@field_validator('config', mode='before')
|
||||
def check_config(cls, v, values):
|
||||
"""
|
||||
Check config, if type is no-auth, config should be None, otherwise it should be a dict.
|
||||
@@ -37,7 +40,7 @@ class HttpRequestNodeData(BaseNodeData):
|
||||
|
||||
class Body(BaseModel):
|
||||
type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw-text', 'json']
|
||||
data: Union[None, str]
|
||||
data: Union[None, str] = None
|
||||
|
||||
class Timeout(BaseModel):
|
||||
connect: Optional[int] = MAX_CONNECT_TIMEOUT
|
||||
@@ -50,5 +53,5 @@ class HttpRequestNodeData(BaseNodeData):
|
||||
headers: str
|
||||
params: str
|
||||
body: Optional[Body]
|
||||
timeout: Optional[Timeout]
|
||||
timeout: Optional[Timeout] = None
|
||||
mask_authorization_header: Optional[bool] = True
|
||||
|
@@ -39,7 +39,7 @@ class HttpRequestNode(BaseNode):
|
||||
"type": "none"
|
||||
},
|
||||
"timeout": {
|
||||
**HTTP_REQUEST_DEFAULT_TIMEOUT.dict(),
|
||||
**HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(),
|
||||
"max_connect_timeout": MAX_CONNECT_TIMEOUT,
|
||||
"max_read_timeout": MAX_READ_TIMEOUT,
|
||||
"max_write_timeout": MAX_WRITE_TIMEOUT,
|
||||
|
@@ -7,7 +7,7 @@ class IterationNodeData(BaseIterationNodeData):
|
||||
"""
|
||||
Iteration Node Data.
|
||||
"""
|
||||
parent_loop_id: Optional[str] # redundant field, not used currently
|
||||
parent_loop_id: Optional[str] = None # redundant field, not used currently
|
||||
iterator_selector: list[str] # variable selector
|
||||
output_selector: list[str] # output selector
|
||||
|
||||
|
@@ -18,7 +18,7 @@ class MultipleRetrievalConfig(BaseModel):
|
||||
Multiple Retrieval Config.
|
||||
"""
|
||||
top_k: int
|
||||
score_threshold: Optional[float]
|
||||
score_threshold: Optional[float] = None
|
||||
reranking_model: RerankingModelConfig
|
||||
|
||||
|
||||
@@ -47,5 +47,5 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
|
||||
query_variable_selector: list[str]
|
||||
dataset_ids: list[str]
|
||||
retrieval_mode: Literal['single', 'multiple']
|
||||
multiple_retrieval_config: Optional[MultipleRetrievalConfig]
|
||||
single_retrieval_config: Optional[SingleRetrievalConfig]
|
||||
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
|
||||
single_retrieval_config: Optional[SingleRetrievalConfig] = None
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
@@ -21,12 +21,13 @@ class ParameterConfig(BaseModel):
|
||||
"""
|
||||
name: str
|
||||
type: Literal['string', 'number', 'bool', 'select', 'array[string]', 'array[number]', 'array[object]']
|
||||
options: Optional[list[str]]
|
||||
options: Optional[list[str]] = None
|
||||
description: str
|
||||
required: bool
|
||||
|
||||
@validator('name', pre=True, always=True)
|
||||
def validate_name(cls, value):
|
||||
@classmethod
|
||||
@field_validator('name', mode='before')
|
||||
def validate_name(cls, value) -> str:
|
||||
if not value:
|
||||
raise ValueError('Parameter name is required')
|
||||
if value in ['__reason', '__is_success']:
|
||||
@@ -40,12 +41,13 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
model: ModelConfig
|
||||
query: list[str]
|
||||
parameters: list[ParameterConfig]
|
||||
instruction: Optional[str]
|
||||
memory: Optional[MemoryConfig]
|
||||
instruction: Optional[str] = None
|
||||
memory: Optional[MemoryConfig] = None
|
||||
reasoning_mode: Literal['function_call', 'prompt']
|
||||
|
||||
@validator('reasoning_mode', pre=True, always=True)
|
||||
def set_reasoning_mode(cls, v):
|
||||
@classmethod
|
||||
@field_validator('reasoning_mode', mode='before')
|
||||
def set_reasoning_mode(cls, v) -> str:
|
||||
return v or 'function_call'
|
||||
|
||||
def get_parameter_json_schema(self) -> dict:
|
||||
|
@@ -32,5 +32,5 @@ class QuestionClassifierNodeData(BaseNodeData):
|
||||
type: str = 'question-classifier'
|
||||
model: ModelConfig
|
||||
classes: list[ClassConfig]
|
||||
instruction: Optional[str]
|
||||
memory: Optional[MemoryConfig]
|
||||
instruction: Optional[str] = None
|
||||
memory: Optional[MemoryConfig] = None
|
||||
|
@@ -1,6 +1,7 @@
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
@@ -13,13 +14,14 @@ class ToolEntity(BaseModel):
|
||||
tool_label: str # redundancy
|
||||
tool_configurations: dict[str, Any]
|
||||
|
||||
@validator('tool_configurations', pre=True, always=True)
|
||||
def validate_tool_configurations(cls, value, values):
|
||||
@classmethod
|
||||
@field_validator('tool_configurations', mode='before')
|
||||
def validate_tool_configurations(cls, value, values: ValidationInfo) -> dict[str, Any]:
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError('tool_configurations must be a dictionary')
|
||||
|
||||
for key in values.get('tool_configurations', {}).keys():
|
||||
value = values.get('tool_configurations', {}).get(key)
|
||||
for key in values.data.get('tool_configurations', {}).keys():
|
||||
value = values.data.get('tool_configurations', {}).get(key)
|
||||
if not isinstance(value, str | int | float | bool):
|
||||
raise ValueError(f'{key} must be a string')
|
||||
|
||||
@@ -30,10 +32,11 @@ class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
value: Union[Any, list[str]]
|
||||
type: Literal['mixed', 'variable', 'constant']
|
||||
|
||||
@validator('type', pre=True, always=True)
|
||||
def check_type(cls, value, values):
|
||||
@classmethod
|
||||
@field_validator('type', mode='before')
|
||||
def check_type(cls, value, validation_info: ValidationInfo):
|
||||
typ = value
|
||||
value = values.get('value')
|
||||
value = validation_info.data.get('value')
|
||||
if typ == 'mixed' and not isinstance(value, str):
|
||||
raise ValueError('value must be a string')
|
||||
elif typ == 'variable':
|
||||
@@ -45,7 +48,7 @@ class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
elif typ == 'constant' and not isinstance(value, str | int | float | bool):
|
||||
raise ValueError('value must be a string, int, float, or bool')
|
||||
return typ
|
||||
|
||||
|
||||
"""
|
||||
Tool Node Schema
|
||||
"""
|
||||
|
@@ -30,4 +30,4 @@ class VariableAssignerNodeData(BaseNodeData):
|
||||
type: str = 'variable-assigner'
|
||||
output_type: str
|
||||
variables: list[list[str]]
|
||||
advanced_settings: Optional[AdvancedSettings]
|
||||
advanced_settings: Optional[AdvancedSettings] = None
|
@@ -592,7 +592,7 @@ class WorkflowEngineManager:
|
||||
node_data=current_iteration_node.node_data,
|
||||
inputs=workflow_run_state.current_iteration_state.inputs,
|
||||
predecessor_node_id=predecessor_node_id,
|
||||
metadata=workflow_run_state.current_iteration_state.metadata.dict()
|
||||
metadata=workflow_run_state.current_iteration_state.metadata.model_dump()
|
||||
)
|
||||
|
||||
# add steps
|
||||
|
Reference in New Issue
Block a user