diff --git a/api/core/embedding/openai_embedding.py b/api/core/embedding/openai_embedding.py index 0f7cb252e..d1179180f 100644 --- a/api/core/embedding/openai_embedding.py +++ b/api/core/embedding/openai_embedding.py @@ -173,6 +173,13 @@ class OpenAIEmbedding(BaseEmbedding): Can be overriden for batch queries. """ + if self.openai_api_type and self.openai_api_type == 'azure': + embeddings = [] + for text in texts: + embeddings.append(self._get_text_embedding(text)) + + return embeddings + if self.deployment_name is not None: engine = self.deployment_name else: @@ -187,6 +194,13 @@ class OpenAIEmbedding(BaseEmbedding): async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]: """Asynchronously get text embeddings.""" + if self.openai_api_type and self.openai_api_type == 'azure': + embeddings = [] + for text in texts: + embeddings.append(await self._aget_text_embedding(text)) + + return embeddings + if self.deployment_name is not None: engine = self.deployment_name else: