improve: extract method for safe loading yaml file and avoid using PyYaml's FullLoader (#4031)

This commit is contained in:
Bowen Liang
2024-05-24 12:08:12 +08:00
committed by GitHub
parent 296887754f
commit 3fda2245a4
11 changed files with 190 additions and 62 deletions

View File

@@ -3,8 +3,6 @@ import os
from abc import ABC, abstractmethod
from typing import Optional
import yaml
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from core.model_runtime.entities.model_entities import (
@@ -18,6 +16,7 @@ from core.model_runtime.entities.model_entities import (
)
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.tools.utils.yaml_utils import load_yaml_file
from core.utils.position_helper import get_position_map, sort_by_position_map
@@ -154,8 +153,7 @@ class AIModel(ABC):
# traverse all model_schema_yaml_paths
for model_schema_yaml_path in model_schema_yaml_paths:
# read yaml data from yaml file
with open(model_schema_yaml_path, encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
yaml_data = load_yaml_file(model_schema_yaml_path, ignore_error=True)
new_parameter_rules = []
for parameter_rule in yaml_data.get('parameter_rules', []):

View File

@@ -1,11 +1,10 @@
import os
from abc import ABC, abstractmethod
import yaml
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.tools.utils.yaml_utils import load_yaml_file
from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source
@@ -44,10 +43,7 @@ class ModelProvider(ABC):
# read provider schema from yaml file
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
yaml_data = {}
if os.path.exists(yaml_path):
with open(yaml_path, encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
yaml_data = load_yaml_file(yaml_path, ignore_error=True)
try:
# yaml_data to entity