chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -26,15 +26,15 @@ class ElasticSearchConfig(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
@model_validator(mode='before')
|
||||
@model_validator(mode="before")
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['host']:
|
||||
if not values["host"]:
|
||||
raise ValueError("config HOST is required")
|
||||
if not values['port']:
|
||||
if not values["port"]:
|
||||
raise ValueError("config PORT is required")
|
||||
if not values['username']:
|
||||
if not values["username"]:
|
||||
raise ValueError("config USERNAME is required")
|
||||
if not values['password']:
|
||||
if not values["password"]:
|
||||
raise ValueError("config PASSWORD is required")
|
||||
return values
|
||||
|
||||
@@ -50,10 +50,10 @@ class ElasticSearchVector(BaseVector):
|
||||
def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
|
||||
try:
|
||||
parsed_url = urlparse(config.host)
|
||||
if parsed_url.scheme in ['http', 'https']:
|
||||
hosts = f'{config.host}:{config.port}'
|
||||
if parsed_url.scheme in ["http", "https"]:
|
||||
hosts = f"{config.host}:{config.port}"
|
||||
else:
|
||||
hosts = f'http://{config.host}:{config.port}'
|
||||
hosts = f"http://{config.host}:{config.port}"
|
||||
client = Elasticsearch(
|
||||
hosts=hosts,
|
||||
basic_auth=(config.username, config.password),
|
||||
@@ -68,25 +68,27 @@ class ElasticSearchVector(BaseVector):
|
||||
|
||||
def _get_version(self) -> str:
|
||||
info = self._client.info()
|
||||
return info['version']['number']
|
||||
return info["version"]["number"]
|
||||
|
||||
def _check_version(self):
|
||||
if self._version < '8.0.0':
|
||||
if self._version < "8.0.0":
|
||||
raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
|
||||
|
||||
def get_type(self) -> str:
|
||||
return 'elasticsearch'
|
||||
return "elasticsearch"
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
uuids = self._get_uuids(documents)
|
||||
for i in range(len(documents)):
|
||||
self._client.index(index=self._collection_name,
|
||||
id=uuids[i],
|
||||
document={
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
|
||||
Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {}
|
||||
})
|
||||
self._client.index(
|
||||
index=self._collection_name,
|
||||
id=uuids[i],
|
||||
document={
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
|
||||
Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {},
|
||||
},
|
||||
)
|
||||
self._client.indices.refresh(index=self._collection_name)
|
||||
return uuids
|
||||
|
||||
@@ -98,15 +100,9 @@ class ElasticSearchVector(BaseVector):
|
||||
self._client.delete(index=self._collection_name, id=id)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
query_str = {
|
||||
'query': {
|
||||
'match': {
|
||||
f'metadata.{key}': f'{value}'
|
||||
}
|
||||
}
|
||||
}
|
||||
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
|
||||
results = self._client.search(index=self._collection_name, body=query_str)
|
||||
ids = [hit['_id'] for hit in results['hits']['hits']]
|
||||
ids = [hit["_id"] for hit in results["hits"]["hits"]]
|
||||
if ids:
|
||||
self.delete_by_ids(ids)
|
||||
|
||||
@@ -115,44 +111,44 @@ class ElasticSearchVector(BaseVector):
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 10)
|
||||
knn = {
|
||||
"field": Field.VECTOR.value,
|
||||
"query_vector": query_vector,
|
||||
"k": top_k
|
||||
}
|
||||
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k}
|
||||
|
||||
results = self._client.search(index=self._collection_name, knn=knn, size=top_k)
|
||||
|
||||
docs_and_scores = []
|
||||
for hit in results['hits']['hits']:
|
||||
for hit in results["hits"]["hits"]:
|
||||
docs_and_scores.append(
|
||||
(Document(page_content=hit['_source'][Field.CONTENT_KEY.value],
|
||||
vector=hit['_source'][Field.VECTOR.value],
|
||||
metadata=hit['_source'][Field.METADATA_KEY.value]), hit['_score']))
|
||||
(
|
||||
Document(
|
||||
page_content=hit["_source"][Field.CONTENT_KEY.value],
|
||||
vector=hit["_source"][Field.VECTOR.value],
|
||||
metadata=hit["_source"][Field.METADATA_KEY.value],
|
||||
),
|
||||
hit["_score"],
|
||||
)
|
||||
)
|
||||
|
||||
docs = []
|
||||
for doc, score in docs_and_scores:
|
||||
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
|
||||
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
|
||||
if score > score_threshold:
|
||||
doc.metadata['score'] = score
|
||||
doc.metadata["score"] = score
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
query_str = {
|
||||
"match": {
|
||||
Field.CONTENT_KEY.value: query
|
||||
}
|
||||
}
|
||||
query_str = {"match": {Field.CONTENT_KEY.value: query}}
|
||||
results = self._client.search(index=self._collection_name, query=query_str)
|
||||
docs = []
|
||||
for hit in results['hits']['hits']:
|
||||
docs.append(Document(
|
||||
page_content=hit['_source'][Field.CONTENT_KEY.value],
|
||||
vector=hit['_source'][Field.VECTOR.value],
|
||||
metadata=hit['_source'][Field.METADATA_KEY.value],
|
||||
))
|
||||
for hit in results["hits"]["hits"]:
|
||||
docs.append(
|
||||
Document(
|
||||
page_content=hit["_source"][Field.CONTENT_KEY.value],
|
||||
vector=hit["_source"][Field.VECTOR.value],
|
||||
metadata=hit["_source"][Field.METADATA_KEY.value],
|
||||
)
|
||||
)
|
||||
|
||||
return docs
|
||||
|
||||
@@ -162,11 +158,11 @@ class ElasticSearchVector(BaseVector):
|
||||
self.add_texts(texts, embeddings, **kwargs)
|
||||
|
||||
def create_collection(
|
||||
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
||||
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
||||
):
|
||||
lock_name = f'vector_indexing_lock_{self._collection_name}'
|
||||
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f'vector_indexing_{self._collection_name}'
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
logger.info(f"Collection {self._collection_name} already exists.")
|
||||
return
|
||||
@@ -179,14 +175,14 @@ class ElasticSearchVector(BaseVector):
|
||||
Field.VECTOR.value: { # Make sure the dimension is correct here
|
||||
"type": "dense_vector",
|
||||
"dims": dim,
|
||||
"similarity": "cosine"
|
||||
"similarity": "cosine",
|
||||
},
|
||||
Field.METADATA_KEY.value: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
self._client.indices.create(index=self._collection_name, mappings=mappings)
|
||||
@@ -197,22 +193,21 @@ class ElasticSearchVector(BaseVector):
|
||||
class ElasticSearchVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
return ElasticSearchVector(
|
||||
index_name=collection_name,
|
||||
config=ElasticSearchConfig(
|
||||
host=config.get('ELASTICSEARCH_HOST'),
|
||||
port=config.get('ELASTICSEARCH_PORT'),
|
||||
username=config.get('ELASTICSEARCH_USERNAME'),
|
||||
password=config.get('ELASTICSEARCH_PASSWORD'),
|
||||
host=config.get("ELASTICSEARCH_HOST"),
|
||||
port=config.get("ELASTICSEARCH_PORT"),
|
||||
username=config.get("ELASTICSEARCH_USERNAME"),
|
||||
password=config.get("ELASTICSEARCH_PASSWORD"),
|
||||
),
|
||||
attributes=[]
|
||||
attributes=[],
|
||||
)
|
||||
|
Reference in New Issue
Block a user