normalize embedding (#974)

Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
Jyong
2023-08-23 19:10:11 +08:00
committed by GitHub
parent 916d8be0ae
commit 1fc57d7358

View File

@@ -1,6 +1,7 @@
import logging import logging
from typing import List from typing import List
import numpy as np
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
@@ -32,14 +33,17 @@ class CacheEmbedding(Embeddings):
embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts) embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts)
except Exception as ex: except Exception as ex:
raise self._embeddings.handle_exceptions(ex) raise self._embeddings.handle_exceptions(ex)
i = 0 i = 0
normalized_embedding_results = []
for text in embedding_queue_texts: for text in embedding_queue_texts:
hash = helper.generate_text_hash(text) hash = helper.generate_text_hash(text)
try: try:
embedding = Embedding(model_name=self._embeddings.name, hash=hash) embedding = Embedding(model_name=self._embeddings.name, hash=hash)
embedding.set_embedding(embedding_results[i]) vector = embedding_results[i]
normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
normalized_embedding_results.append(normalized_embedding)
embedding.set_embedding(normalized_embedding)
db.session.add(embedding) db.session.add(embedding)
db.session.commit() db.session.commit()
except IntegrityError: except IntegrityError:
@@ -51,7 +55,7 @@ class CacheEmbedding(Embeddings):
finally: finally:
i += 1 i += 1
text_embeddings.extend(embedding_results) text_embeddings.extend(normalized_embedding_results)
return text_embeddings return text_embeddings
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
@@ -64,6 +68,7 @@ class CacheEmbedding(Embeddings):
try: try:
embedding_results = self._embeddings.client.embed_query(text) embedding_results = self._embeddings.client.embed_query(text)
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
except Exception as ex: except Exception as ex:
raise self._embeddings.handle_exceptions(ex) raise self._embeddings.handle_exceptions(ex)
@@ -79,4 +84,3 @@ class CacheEmbedding(Embeddings):
return embedding_results return embedding_results