feat: mypy for all type check (#10921)
This commit is contained in:
@@ -124,17 +124,20 @@ class ModelInstance:
|
||||
raise Exception("Model type instance is not LargeLanguageModel")
|
||||
|
||||
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
return cast(
|
||||
Union[LLMResult, Generator],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
),
|
||||
)
|
||||
|
||||
def get_llm_num_tokens(
|
||||
@@ -151,12 +154,15 @@ class ModelInstance:
|
||||
raise Exception("Model type instance is not LargeLanguageModel")
|
||||
|
||||
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
return cast(
|
||||
int,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
),
|
||||
)
|
||||
|
||||
def invoke_text_embedding(
|
||||
@@ -174,13 +180,16 @@ class ModelInstance:
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
|
||||
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
user=user,
|
||||
input_type=input_type,
|
||||
return cast(
|
||||
TextEmbeddingResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
user=user,
|
||||
input_type=input_type,
|
||||
),
|
||||
)
|
||||
|
||||
def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
|
||||
@@ -194,11 +203,14 @@ class ModelInstance:
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
|
||||
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
return cast(
|
||||
int,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
),
|
||||
)
|
||||
|
||||
def invoke_rerank(
|
||||
@@ -223,15 +235,18 @@ class ModelInstance:
|
||||
raise Exception("Model type instance is not RerankModel")
|
||||
|
||||
self.model_type_instance = cast(RerankModel, self.model_type_instance)
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
user=user,
|
||||
return cast(
|
||||
RerankResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
user=user,
|
||||
),
|
||||
)
|
||||
|
||||
def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool:
|
||||
@@ -246,12 +261,15 @@ class ModelInstance:
|
||||
raise Exception("Model type instance is not ModerationModel")
|
||||
|
||||
self.model_type_instance = cast(ModerationModel, self.model_type_instance)
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
text=text,
|
||||
user=user,
|
||||
return cast(
|
||||
bool,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
text=text,
|
||||
user=user,
|
||||
),
|
||||
)
|
||||
|
||||
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str:
|
||||
@@ -266,12 +284,15 @@ class ModelInstance:
|
||||
raise Exception("Model type instance is not Speech2TextModel")
|
||||
|
||||
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
file=file,
|
||||
user=user,
|
||||
return cast(
|
||||
str,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
file=file,
|
||||
user=user,
|
||||
),
|
||||
)
|
||||
|
||||
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]:
|
||||
@@ -288,17 +309,20 @@ class ModelInstance:
|
||||
raise Exception("Model type instance is not TTSModel")
|
||||
|
||||
self.model_type_instance = cast(TTSModel, self.model_type_instance)
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
content_text=content_text,
|
||||
user=user,
|
||||
tenant_id=tenant_id,
|
||||
voice=voice,
|
||||
return cast(
|
||||
Iterable[bytes],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
content_text=content_text,
|
||||
user=user,
|
||||
tenant_id=tenant_id,
|
||||
voice=voice,
|
||||
),
|
||||
)
|
||||
|
||||
def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
|
||||
def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs) -> Any:
|
||||
"""
|
||||
Round-robin invoke
|
||||
:param function: function to invoke
|
||||
@@ -309,7 +333,7 @@ class ModelInstance:
|
||||
if not self.load_balancing_manager:
|
||||
return function(*args, **kwargs)
|
||||
|
||||
last_exception = None
|
||||
last_exception: Union[InvokeRateLimitError, InvokeAuthorizationError, InvokeConnectionError, None] = None
|
||||
while True:
|
||||
lb_config = self.load_balancing_manager.fetch_next()
|
||||
if not lb_config:
|
||||
@@ -463,7 +487,7 @@ class LBModelManager:
|
||||
if real_index > max_index:
|
||||
real_index = 0
|
||||
|
||||
config = self._load_balancing_configs[real_index]
|
||||
config: ModelLoadBalancingConfiguration = self._load_balancing_configs[real_index]
|
||||
|
||||
if self.in_cooldown(config):
|
||||
cooldown_load_balancing_configs.append(config)
|
||||
@@ -507,8 +531,7 @@ class LBModelManager:
|
||||
self._tenant_id, self._provider, self._model_type.value, self._model, config.id
|
||||
)
|
||||
|
||||
res = redis_client.exists(cooldown_cache_key)
|
||||
res = cast(bool, res)
|
||||
res: bool = redis_client.exists(cooldown_cache_key)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
|
Reference in New Issue
Block a user