feat: mypy for all type check (#10921)

Author: yihong
Date: 2024-12-24 18:38:51 +08:00
Committed by: GitHub
Parent: c91e8b1737
Commit: 56e15d09a9
584 changed files with 3975 additions and 2826 deletions
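
The pattern repeated throughout the hunks below: `_round_robin_invoke` dispatches to whatever bound method it is handed and is typed as returning `Any`, so each typed wrapper (`invoke_llm`, `get_llm_num_tokens`, `invoke_rerank`, and so on) would otherwise be returning `Any` from a function with a concrete return annotation. Wrapping the call in `typing.cast(...)` pins the result to the declared type and is a no-op at runtime. A minimal sketch of the idea, using a hypothetical `Dispatcher` class rather than the real `ModelInstance`:

from collections.abc import Callable
from typing import Any, cast


class Dispatcher:
    """Toy stand-in for ModelInstance: every call funnels through an untyped helper."""

    def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs) -> Any:
        # The real method retries across load-balanced credentials; here it just
        # forwards the call, so mypy can only see `Any` coming back.
        return function(*args, **kwargs)

    def get_num_tokens(self, text: str) -> int:
        # Without the cast, mypy (with warn_return_any / strict settings) flags
        # "Returning Any from function declared to return int".
        return cast(int, self._round_robin_invoke(len, text))
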


@@ -124,17 +124,20 @@ class ModelInstance:
             raise Exception("Model type instance is not LargeLanguageModel")

         self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
-        return self._round_robin_invoke(
-            function=self.model_type_instance.invoke,
-            model=self.model,
-            credentials=self.credentials,
-            prompt_messages=prompt_messages,
-            model_parameters=model_parameters,
-            tools=tools,
-            stop=stop,
-            stream=stream,
-            user=user,
-            callbacks=callbacks,
+        return cast(
+            Union[LLMResult, Generator],
+            self._round_robin_invoke(
+                function=self.model_type_instance.invoke,
+                model=self.model,
+                credentials=self.credentials,
+                prompt_messages=prompt_messages,
+                model_parameters=model_parameters,
+                tools=tools,
+                stop=stop,
+                stream=stream,
+                user=user,
+                callbacks=callbacks,
+            ),
         )

     def get_llm_num_tokens(
@@ -151,12 +154,15 @@ class ModelInstance:
             raise Exception("Model type instance is not LargeLanguageModel")

         self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
-        return self._round_robin_invoke(
-            function=self.model_type_instance.get_num_tokens,
-            model=self.model,
-            credentials=self.credentials,
-            prompt_messages=prompt_messages,
-            tools=tools,
+        return cast(
+            int,
+            self._round_robin_invoke(
+                function=self.model_type_instance.get_num_tokens,
+                model=self.model,
+                credentials=self.credentials,
+                prompt_messages=prompt_messages,
+                tools=tools,
+            ),
         )

     def invoke_text_embedding(
@@ -174,13 +180,16 @@ class ModelInstance:
             raise Exception("Model type instance is not TextEmbeddingModel")

         self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
-        return self._round_robin_invoke(
-            function=self.model_type_instance.invoke,
-            model=self.model,
-            credentials=self.credentials,
-            texts=texts,
-            user=user,
-            input_type=input_type,
+        return cast(
+            TextEmbeddingResult,
+            self._round_robin_invoke(
+                function=self.model_type_instance.invoke,
+                model=self.model,
+                credentials=self.credentials,
+                texts=texts,
+                user=user,
+                input_type=input_type,
+            ),
         )

     def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
@@ -194,11 +203,14 @@ class ModelInstance:
             raise Exception("Model type instance is not TextEmbeddingModel")

         self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
-        return self._round_robin_invoke(
-            function=self.model_type_instance.get_num_tokens,
-            model=self.model,
-            credentials=self.credentials,
-            texts=texts,
+        return cast(
+            int,
+            self._round_robin_invoke(
+                function=self.model_type_instance.get_num_tokens,
+                model=self.model,
+                credentials=self.credentials,
+                texts=texts,
+            ),
         )

     def invoke_rerank(
@@ -223,15 +235,18 @@ class ModelInstance:
             raise Exception("Model type instance is not RerankModel")

         self.model_type_instance = cast(RerankModel, self.model_type_instance)
-        return self._round_robin_invoke(
-            function=self.model_type_instance.invoke,
-            model=self.model,
-            credentials=self.credentials,
-            query=query,
-            docs=docs,
-            score_threshold=score_threshold,
-            top_n=top_n,
-            user=user,
+        return cast(
+            RerankResult,
+            self._round_robin_invoke(
+                function=self.model_type_instance.invoke,
+                model=self.model,
+                credentials=self.credentials,
+                query=query,
+                docs=docs,
+                score_threshold=score_threshold,
+                top_n=top_n,
+                user=user,
+            ),
         )

     def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool:
@@ -246,12 +261,15 @@ class ModelInstance:
             raise Exception("Model type instance is not ModerationModel")

         self.model_type_instance = cast(ModerationModel, self.model_type_instance)
-        return self._round_robin_invoke(
-            function=self.model_type_instance.invoke,
-            model=self.model,
-            credentials=self.credentials,
-            text=text,
-            user=user,
+        return cast(
+            bool,
+            self._round_robin_invoke(
+                function=self.model_type_instance.invoke,
+                model=self.model,
+                credentials=self.credentials,
+                text=text,
+                user=user,
+            ),
         )

     def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str:
@@ -266,12 +284,15 @@ class ModelInstance:
             raise Exception("Model type instance is not Speech2TextModel")

         self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
-        return self._round_robin_invoke(
-            function=self.model_type_instance.invoke,
-            model=self.model,
-            credentials=self.credentials,
-            file=file,
-            user=user,
+        return cast(
+            str,
+            self._round_robin_invoke(
+                function=self.model_type_instance.invoke,
+                model=self.model,
+                credentials=self.credentials,
+                file=file,
+                user=user,
+            ),
         )

     def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]:
@@ -288,17 +309,20 @@ class ModelInstance:
             raise Exception("Model type instance is not TTSModel")

         self.model_type_instance = cast(TTSModel, self.model_type_instance)
-        return self._round_robin_invoke(
-            function=self.model_type_instance.invoke,
-            model=self.model,
-            credentials=self.credentials,
-            content_text=content_text,
-            user=user,
-            tenant_id=tenant_id,
-            voice=voice,
+        return cast(
+            Iterable[bytes],
+            self._round_robin_invoke(
+                function=self.model_type_instance.invoke,
+                model=self.model,
+                credentials=self.credentials,
+                content_text=content_text,
+                user=user,
+                tenant_id=tenant_id,
+                voice=voice,
+            ),
         )

-    def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
+    def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs) -> Any:
         """
         Round-robin invoke
         :param function: function to invoke
@@ -309,7 +333,7 @@ class ModelInstance:
         if not self.load_balancing_manager:
             return function(*args, **kwargs)

-        last_exception = None
+        last_exception: Union[InvokeRateLimitError, InvokeAuthorizationError, InvokeConnectionError, None] = None
         while True:
             lb_config = self.load_balancing_manager.fetch_next()
             if not lb_config:
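
The `last_exception` change above gives the variable an explicit union type rather than leaving mypy to infer one from the bare `None` initializer; with the annotation, every assignment in the retry loop is checked against the listed Invoke* error types, and the later use of the variable sees a precise type. A small illustrative sketch, with a made-up `TransientError` standing in for those classes:

from collections.abc import Callable
from typing import Any, Optional


class TransientError(Exception):
    """Stand-in for the InvokeRateLimitError / InvokeConnectionError family."""


def call_with_retries(fn: Callable[[], Any], attempts: int = 3) -> Any:
    # Explicit annotation: the variable starts as None and is only ever
    # reassigned to a TransientError; mypy checks each assignment against
    # this union and knows exactly what can be re-raised afterwards.
    last_exception: Optional[TransientError] = None
    for _ in range(attempts):
        try:
            return fn()
        except TransientError as e:
            last_exception = e
    assert last_exception is not None
    raise last_exception
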
@@ -463,7 +487,7 @@ class LBModelManager
             if real_index > max_index:
                 real_index = 0

-            config = self._load_balancing_configs[real_index]
+            config: ModelLoadBalancingConfiguration = self._load_balancing_configs[real_index]

             if self.in_cooldown(config):
                 cooldown_load_balancing_configs.append(config)
@@ -507,8 +531,7 @@ class LBModelManager
             self._tenant_id, self._provider, self._model_type.value, self._model, config.id
         )

-        res = redis_client.exists(cooldown_cache_key)
-        res = cast(bool, res)
+        res: bool = redis_client.exists(cooldown_cache_key)
         return res

     @staticmethod
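
The last hunk replaces a runtime `cast(bool, ...)` call with a variable annotation, which narrows the loosely typed value returned by the Redis client in a single statement and without the extra no-op call. A hedged sketch of the same idea (the `exists` stub below is an assumption for illustration, not the real redis-py signature):

from typing import Any


def exists(key: str) -> Any:
    """Stand-in for a client call whose type stub effectively returns Any."""
    return 1  # e.g. a count of matching keys


def in_cooldown(key: str) -> bool:
    # The annotated assignment pins the Any result to bool for mypy; it has the
    # same type-checking effect as `res = cast(bool, exists(key))` in one step.
    res: bool = exists(key)
    return res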