chore(api/core): apply ruff reformatting (#7624)

Author: Bowen Liang
Date: 2024-09-10 17:00:20 +08:00
Committed by: GitHub
Parent: 178730266d
Commit: 2cf1187b32
724 changed files with 21180 additions and 21123 deletions


@@ -26,15 +26,15 @@ class ElasticSearchConfig(BaseModel):
     username: str
     password: str

-    @model_validator(mode='before')
+    @model_validator(mode="before")
     def validate_config(cls, values: dict) -> dict:
-        if not values['host']:
+        if not values["host"]:
             raise ValueError("config HOST is required")
-        if not values['port']:
+        if not values["port"]:
             raise ValueError("config PORT is required")
-        if not values['username']:
+        if not values["username"]:
             raise ValueError("config USERNAME is required")
-        if not values['password']:
+        if not values["password"]:
             raise ValueError("config PASSWORD is required")
         return values
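
The validator runs in pydantic's "before" mode, so it sees the raw input dict and can reject empty or missing values with a readable message. A minimal, self-contained sketch of the same pattern (pydantic v2 assumed; the class and fields below are illustrative, not Dify's exact model):

from pydantic import BaseModel, model_validator


class DemoConfig(BaseModel):
    host: str
    port: int

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        # mode="before" runs on the raw input before field parsing,
        # so empty strings are caught here with a clear error message.
        if not values.get("host"):
            raise ValueError("config HOST is required")
        if not values.get("port"):
            raise ValueError("config PORT is required")
        return values


DemoConfig(host="localhost", port=9200)  # ok
# DemoConfig(host="", port=9200) fails validation with "config HOST is required"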
@@ -50,10 +50,10 @@ class ElasticSearchVector(BaseVector):
     def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
         try:
             parsed_url = urlparse(config.host)
-            if parsed_url.scheme in ['http', 'https']:
-                hosts = f'{config.host}:{config.port}'
+            if parsed_url.scheme in ["http", "https"]:
+                hosts = f"{config.host}:{config.port}"
             else:
-                hosts = f'http://{config.host}:{config.port}'
+                hosts = f"http://{config.host}:{config.port}"
             client = Elasticsearch(
                 hosts=hosts,
                 basic_auth=(config.username, config.password),
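
The scheme check decides whether config.host already carries http:// or https:// before the port is appended. A quick standard-library illustration (the host values are made up):

from urllib.parse import urlparse

for host in ["http://es.internal", "https://es.internal", "es.internal"]:
    scheme = urlparse(host).scheme
    if scheme in ["http", "https"]:
        hosts = f"{host}:9200"  # host already includes a scheme
    else:
        hosts = f"http://{host}:9200"  # bare hostname: default to plain http
    print(scheme or "<no scheme>", "->", hosts)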
@@ -68,25 +68,27 @@ class ElasticSearchVector(BaseVector):
     def _get_version(self) -> str:
         info = self._client.info()
-        return info['version']['number']
+        return info["version"]["number"]

     def _check_version(self):
-        if self._version < '8.0.0':
+        if self._version < "8.0.0":
             raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")

     def get_type(self) -> str:
-        return 'elasticsearch'
+        return "elasticsearch"

     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
         uuids = self._get_uuids(documents)
         for i in range(len(documents)):
-            self._client.index(index=self._collection_name,
-                               id=uuids[i],
-                               document={
-                                   Field.CONTENT_KEY.value: documents[i].page_content,
-                                   Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
-                                   Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {}
-                               })
+            self._client.index(
+                index=self._collection_name,
+                id=uuids[i],
+                document={
+                    Field.CONTENT_KEY.value: documents[i].page_content,
+                    Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
+                    Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {},
+                },
+            )
         self._client.indices.refresh(index=self._collection_name)
         return uuids
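
add_texts indexes one document per request and then refreshes the index so the new documents are immediately searchable. A rough standalone sketch of the same calls with elasticsearch-py 8.x (URL, credentials, index and field names are placeholders):

import uuid

from elasticsearch import Elasticsearch

client = Elasticsearch("http://localhost:9200", basic_auth=("elastic", "changeme"))

texts = ["first chunk", "second chunk"]
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

ids = [str(uuid.uuid4()) for _ in texts]
for doc_id, text, vector in zip(ids, texts, vectors):
    client.index(
        index="demo_collection",
        id=doc_id,
        document={"page_content": text, "vector": vector, "metadata": {"doc_id": doc_id}},
    )

# The explicit refresh makes the just-indexed documents visible to search right away
# instead of waiting for the next automatic refresh interval.
client.indices.refresh(index="demo_collection")

For large batches, elasticsearch.helpers.bulk would cut the per-document request overhead; the loop above mirrors the simpler one-request-per-chunk approach used here.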
@@ -98,15 +100,9 @@ class ElasticSearchVector(BaseVector):
         self._client.delete(index=self._collection_name, id=id)

     def delete_by_metadata_field(self, key: str, value: str) -> None:
-        query_str = {
-            'query': {
-                'match': {
-                    f'metadata.{key}': f'{value}'
-                }
-            }
-        }
+        query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
         results = self._client.search(index=self._collection_name, body=query_str)
-        ids = [hit['_id'] for hit in results['hits']['hits']]
+        ids = [hit["_id"] for hit in results["hits"]["hits"]]
         if ids:
             self.delete_by_ids(ids)
@@ -115,44 +111,44 @@ class ElasticSearchVector(BaseVector):
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 10)
-        knn = {
-            "field": Field.VECTOR.value,
-            "query_vector": query_vector,
-            "k": top_k
-        }
+        knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k}

         results = self._client.search(index=self._collection_name, knn=knn, size=top_k)

         docs_and_scores = []
-        for hit in results['hits']['hits']:
+        for hit in results["hits"]["hits"]:
             docs_and_scores.append(
-                (Document(page_content=hit['_source'][Field.CONTENT_KEY.value],
-                          vector=hit['_source'][Field.VECTOR.value],
-                          metadata=hit['_source'][Field.METADATA_KEY.value]), hit['_score']))
+                (
+                    Document(
+                        page_content=hit["_source"][Field.CONTENT_KEY.value],
+                        vector=hit["_source"][Field.VECTOR.value],
+                        metadata=hit["_source"][Field.METADATA_KEY.value],
+                    ),
+                    hit["_score"],
+                )
+            )

         docs = []
         for doc, score in docs_and_scores:
-            score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
             if score > score_threshold:
-                doc.metadata['score'] = score
+                doc.metadata["score"] = score
                 docs.append(doc)

         return docs

     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
-        query_str = {
-            "match": {
-                Field.CONTENT_KEY.value: query
-            }
-        }
+        query_str = {"match": {Field.CONTENT_KEY.value: query}}
         results = self._client.search(index=self._collection_name, query=query_str)
         docs = []
-        for hit in results['hits']['hits']:
-            docs.append(Document(
-                page_content=hit['_source'][Field.CONTENT_KEY.value],
-                vector=hit['_source'][Field.VECTOR.value],
-                metadata=hit['_source'][Field.METADATA_KEY.value],
-            ))
+        for hit in results["hits"]["hits"]:
+            docs.append(
+                Document(
+                    page_content=hit["_source"][Field.CONTENT_KEY.value],
+                    vector=hit["_source"][Field.VECTOR.value],
+                    metadata=hit["_source"][Field.METADATA_KEY.value],
+                )
+            )
         return docs
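
search_by_vector sends a kNN clause and then drops hits whose _score falls at or below the optional score_threshold. An offline sketch of the request shape and the filter (the response dict only mimics an Elasticsearch response; field names and scores are made up):

query_vector = [0.1, 0.2, 0.3]
top_k = 10

# Request body for the kNN search; a live call would be
# client.search(index=..., knn=knn, size=top_k).
knn = {"field": "vector", "query_vector": query_vector, "k": top_k}

results = {
    "hits": {
        "hits": [
            {"_id": "a", "_score": 0.92, "_source": {"page_content": "first chunk"}},
            {"_id": "b", "_score": 0.31, "_source": {"page_content": "second chunk"}},
        ]
    }
}

score_threshold = 0.5
docs = [
    hit["_source"]["page_content"]
    for hit in results["hits"]["hits"]
    if hit["_score"] > score_threshold
]
print(docs)  # only the hit scoring above the threshold survives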
@@ -162,11 +158,11 @@ class ElasticSearchVector(BaseVector):
         self.add_texts(texts, embeddings, **kwargs)

     def create_collection(
-            self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
+        self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
     ):
-        lock_name = f'vector_indexing_lock_{self._collection_name}'
+        lock_name = f"vector_indexing_lock_{self._collection_name}"
         with redis_client.lock(lock_name, timeout=20):
-            collection_exist_cache_key = f'vector_indexing_{self._collection_name}'
+            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
             if redis_client.get(collection_exist_cache_key):
                 logger.info(f"Collection {self._collection_name} already exists.")
                 return
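
create_collection wraps index creation in a Redis lock plus a cache key, so concurrent workers neither race on creation nor repeat the existence check. A sketch of the same guard with redis-py (a local Redis is assumed; key names, timeout and TTL are illustrative):

import logging

import redis

logger = logging.getLogger(__name__)
redis_client = redis.Redis(host="localhost", port=6379)

collection_name = "demo_collection"
lock_name = f"vector_indexing_lock_{collection_name}"

with redis_client.lock(lock_name, timeout=20):
    collection_exist_cache_key = f"vector_indexing_{collection_name}"
    if redis_client.get(collection_exist_cache_key):
        logger.info(f"Collection {collection_name} already exists.")
    else:
        # ... only one worker reaches this point and creates the index ...
        # Cache the fact for a while so later workers skip straight past the check.
        redis_client.set(collection_exist_cache_key, 1, ex=3600)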
@@ -179,14 +175,14 @@ class ElasticSearchVector(BaseVector):
                         Field.VECTOR.value: {  # Make sure the dimension is correct here
                             "type": "dense_vector",
                             "dims": dim,
-                            "similarity": "cosine"
+                            "similarity": "cosine",
                         },
                         Field.METADATA_KEY.value: {
                             "type": "object",
                             "properties": {
                                 "doc_id": {"type": "keyword"}  # Map doc_id to keyword type
-                            }
-                        }
+                            },
+                        },
                     }
                 }
                 self._client.indices.create(index=self._collection_name, mappings=mappings)
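
The hunk above only shows the tail of the mapping; roughly, the index combines a text field for content, a dense_vector field for the embeddings, and a keyword-typed doc_id under metadata. A standalone sketch with elasticsearch-py 8.x (index name, field names and dimension are placeholders; dims must match the embedding model in use):

from elasticsearch import Elasticsearch

client = Elasticsearch("http://localhost:9200", basic_auth=("elastic", "changeme"))

dim = 1536  # size of the stored embedding vectors
mappings = {
    "properties": {
        "page_content": {"type": "text"},
        "vector": {
            "type": "dense_vector",
            "dims": dim,
            "similarity": "cosine",  # cosine similarity for kNN scoring
        },
        "metadata": {
            "type": "object",
            "properties": {
                "doc_id": {"type": "keyword"}  # keyword type allows exact-match filtering
            },
        },
    }
}

if not client.indices.exists(index="demo_collection"):
    client.indices.create(index="demo_collection", mappings=mappings)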
@@ -197,22 +193,21 @@ class ElasticSearchVector(BaseVector):
 class ElasticSearchVectorFactory(AbstractVectorFactory):
     def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector:
         if dataset.index_struct_dict:
-            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
+            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
             collection_name = class_prefix
         else:
             dataset_id = dataset.id
             collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-            dataset.index_struct = json.dumps(
-                self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
+            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))

         config = current_app.config
         return ElasticSearchVector(
             index_name=collection_name,
             config=ElasticSearchConfig(
-                host=config.get('ELASTICSEARCH_HOST'),
-                port=config.get('ELASTICSEARCH_PORT'),
-                username=config.get('ELASTICSEARCH_USERNAME'),
-                password=config.get('ELASTICSEARCH_PASSWORD'),
+                host=config.get("ELASTICSEARCH_HOST"),
+                port=config.get("ELASTICSEARCH_PORT"),
+                username=config.get("ELASTICSEARCH_USERNAME"),
+                password=config.get("ELASTICSEARCH_PASSWORD"),
             ),
-            attributes=[]
+            attributes=[],
         )
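
The factory pulls the four ELASTICSEARCH_* settings from the Flask app config; any key left unset comes back as None from config.get and is then rejected by the "before" validator in the first hunk ("config HOST is required", and so on). An offline sketch of that mapping, with a plain dict standing in for current_app.config and made-up values:

flask_config = {
    "ELASTICSEARCH_HOST": "http://localhost",
    "ELASTICSEARCH_PORT": 9200,
    "ELASTICSEARCH_USERNAME": "elastic",
    "ELASTICSEARCH_PASSWORD": "changeme",
}

connection_kwargs = {
    "host": flask_config.get("ELASTICSEARCH_HOST"),
    "port": flask_config.get("ELASTICSEARCH_PORT"),
    "username": flask_config.get("ELASTICSEARCH_USERNAME"),
    "password": flask_config.get("ELASTICSEARCH_PASSWORD"),
}
print(connection_kwargs)  # these keyword arguments feed ElasticSearchConfig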