improve: extract method for safe loading yaml file and avoid using PyYaml's FullLoader (#4031)
This commit is contained in:
@@ -3,8 +3,6 @@ import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
@@ -18,6 +16,7 @@ from core.model_runtime.entities.model_entities import (
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
from core.utils.position_helper import get_position_map, sort_by_position_map
|
||||
|
||||
|
||||
@@ -154,8 +153,7 @@ class AIModel(ABC):
|
||||
# traverse all model_schema_yaml_paths
|
||||
for model_schema_yaml_path in model_schema_yaml_paths:
|
||||
# read yaml data from yaml file
|
||||
with open(model_schema_yaml_path, encoding='utf-8') as f:
|
||||
yaml_data = yaml.safe_load(f)
|
||||
yaml_data = load_yaml_file(model_schema_yaml_path, ignore_error=True)
|
||||
|
||||
new_parameter_rules = []
|
||||
for parameter_rule in yaml_data.get('parameter_rules', []):
|
||||
|
@@ -1,11 +1,10 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import yaml
|
||||
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source
|
||||
|
||||
|
||||
@@ -44,10 +43,7 @@ class ModelProvider(ABC):
|
||||
|
||||
# read provider schema from yaml file
|
||||
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
|
||||
yaml_data = {}
|
||||
if os.path.exists(yaml_path):
|
||||
with open(yaml_path, encoding='utf-8') as f:
|
||||
yaml_data = yaml.safe_load(f)
|
||||
yaml_data = load_yaml_file(yaml_path, ignore_error=True)
|
||||
|
||||
try:
|
||||
# yaml_data to entity
|
||||
|
Reference in New Issue
Block a user