feat: mypy for all type check (#10921)

Author: yihong
Date: 2024-12-24 18:38:51 +08:00
Committed by: GitHub
Parent: c91e8b1737
Commit: 56e15d09a9
584 changed files with 3975 additions and 2826 deletions


@@ -1,6 +1,6 @@
import base64
import logging
-from typing import Optional, cast
+from typing import Any, Optional, cast
import numpy as np
from sqlalchemy.exc import IntegrityError
@@ -27,7 +27,7 @@ class CacheEmbedding(Embeddings):
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs in batches of 10."""
# use doc embedding cache or store if not exists
-text_embeddings = [None for _ in range(len(texts))]
+text_embeddings: list[Any] = [None for _ in range(len(texts))]
embedding_queue_indices = []
for i, text in enumerate(texts):
hash = helper.generate_text_hash(text)
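Why the new `list[Any]` annotation above: the list is pre-filled with `None` placeholders, so mypy infers `list[None]` and then rejects the later assignment of embedding vectors into it. A minimal sketch of the same pattern (names here are illustrative, not the project's code):

```python
from typing import Any


def fill_embeddings(texts: list[str]) -> list[list[float]]:
    # Without the annotation, mypy infers list[None] from the initializer
    # and flags the assignment inside the loop as an incompatible type.
    results: list[Any] = [None for _ in texts]
    for i, _text in enumerate(texts):
        results[i] = [0.0, 1.0]  # stand-in vector for illustration
    return results
```

A `cast` would also silence the error, but `list[Any]` keeps the placeholder-then-fill pattern unchanged at runtime.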
@@ -64,7 +64,8 @@ class CacheEmbedding(Embeddings):
for vector in embedding_result.embeddings:
try:
-normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
+# FIXME: type ignore for numpy here
+normalized_embedding = (vector / np.linalg.norm(vector)).tolist() # type: ignore
# stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan
if np.isnan(normalized_embedding).any():
# for issue #11827 float values are not json compliant
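The `# type: ignore` added here works around numpy's loose typing: the result of the division and of `.tolist()` comes back effectively untyped, so mypy cannot verify the assignment, and the FIXME marks the suppression for a cleaner fix later. A rough standalone sketch of the normalize-then-screen step (the function name and dtype are my assumptions, not the project's API):

```python
import numpy as np


def normalize(vector: list[float]) -> list[float]:
    arr = np.asarray(vector, dtype=np.float64)
    # Mirrors the diff's suppression: numpy's stubs make the result hard
    # for mypy to pin down, so the line is ignored rather than refactored.
    normalized: list[float] = (arr / np.linalg.norm(arr)).tolist()  # type: ignore
    # A zero vector divides to NaN, and NaN is not JSON compliant.
    if np.isnan(normalized).any():
        raise ValueError("normalized embedding contains NaN")
    return normalized
```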
@@ -77,8 +78,8 @@ class CacheEmbedding(Embeddings):
logging.exception("Failed transform embedding")
cache_embeddings = []
try:
-for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
-text_embeddings[i] = embedding
+for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
+text_embeddings[i] = n_embedding
hash = helper.generate_text_hash(texts[i])
if hash not in cache_embeddings:
embedding_cache = Embedding(
@@ -86,7 +87,7 @@ class CacheEmbedding(Embeddings):
hash=hash,
provider_name=self._model_instance.provider,
)
-embedding_cache.set_embedding(embedding)
+embedding_cache.set_embedding(n_embedding)
db.session.add(embedding_cache)
cache_embeddings.append(hash)
db.session.commit()
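The rename from `embedding` to `n_embedding` in this hunk most likely targets mypy's redefinition rule: if the same name is bound earlier in the method to a value of a different type (plausibly a cached `Embedding` row), rebinding it to a `list[float]` in the loop is reported as an incompatible assignment. A minimal, hypothetical reduction of that error and of the rename that avoids it:

```python
def before_rename(vectors: list[list[float]]) -> None:
    # Stand-in for an earlier, differently typed use of the same name.
    embedding: str | None = None
    for i, embedding in enumerate(vectors):  # mypy: incompatible assignment
        print(i, embedding)


def after_rename(vectors: list[list[float]]) -> None:
    embedding: str | None = None
    for i, n_embedding in enumerate(vectors):  # distinct name, no conflict
        print(i, embedding, n_embedding)
```

Both functions run; only the first fails type checking.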
@@ -115,7 +116,8 @@ class CacheEmbedding(Embeddings):
)
embedding_results = embedding_result.embeddings[0]
-embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
+# FIXME: type ignore for numpy here
+embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() # type: ignore
if np.isnan(embedding_results).any():
raise ValueError("Normalized embedding is nan please try again")
except Exception as ex: