feat: mypy for all type check (#10921)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import base64
|
||||
import logging
|
||||
from typing import Optional, cast
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import numpy as np
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
@@ -27,7 +27,7 @@ class CacheEmbedding(Embeddings):
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed search docs in batches of 10."""
|
||||
# use doc embedding cache or store if not exists
|
||||
text_embeddings = [None for _ in range(len(texts))]
|
||||
text_embeddings: list[Any] = [None for _ in range(len(texts))]
|
||||
embedding_queue_indices = []
|
||||
for i, text in enumerate(texts):
|
||||
hash = helper.generate_text_hash(text)
|
||||
@@ -64,7 +64,8 @@ class CacheEmbedding(Embeddings):
|
||||
|
||||
for vector in embedding_result.embeddings:
|
||||
try:
|
||||
normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
|
||||
# FIXME: type ignore for numpy here
|
||||
normalized_embedding = (vector / np.linalg.norm(vector)).tolist() # type: ignore
|
||||
# stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan
|
||||
if np.isnan(normalized_embedding).any():
|
||||
# for issue #11827 float values are not json compliant
|
||||
@@ -77,8 +78,8 @@ class CacheEmbedding(Embeddings):
|
||||
logging.exception("Failed transform embedding")
|
||||
cache_embeddings = []
|
||||
try:
|
||||
for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
|
||||
text_embeddings[i] = embedding
|
||||
for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
|
||||
text_embeddings[i] = n_embedding
|
||||
hash = helper.generate_text_hash(texts[i])
|
||||
if hash not in cache_embeddings:
|
||||
embedding_cache = Embedding(
|
||||
@@ -86,7 +87,7 @@ class CacheEmbedding(Embeddings):
|
||||
hash=hash,
|
||||
provider_name=self._model_instance.provider,
|
||||
)
|
||||
embedding_cache.set_embedding(embedding)
|
||||
embedding_cache.set_embedding(n_embedding)
|
||||
db.session.add(embedding_cache)
|
||||
cache_embeddings.append(hash)
|
||||
db.session.commit()
|
||||
@@ -115,7 +116,8 @@ class CacheEmbedding(Embeddings):
|
||||
)
|
||||
|
||||
embedding_results = embedding_result.embeddings[0]
|
||||
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
|
||||
# FIXME: type ignore for numpy here
|
||||
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() # type: ignore
|
||||
if np.isnan(embedding_results).any():
|
||||
raise ValueError("Normalized embedding is nan please try again")
|
||||
except Exception as ex:
|
||||
|
Reference in New Issue
Block a user