chore(api/core): apply ruff reformatting (#7624)
@@ -25,16 +25,11 @@ class TencentConfig(BaseModel):
     database: Optional[str]
     index_type: str = "HNSW"
     metric_type: str = "L2"
-    shard: int = 1,
-    replicas: int = 2,
+    shard: int = (1,)
+    replicas: int = (2,)
 
     def to_tencent_params(self):
-        return {
-            'url': self.url,
-            'username': self.username,
-            'key': self.api_key,
-            'timeout': self.timeout
-        }
+        return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
 
 
 class TencentVector(BaseVector):
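Note on the hunk above: in the original code, `shard: int = 1,` and `replicas: int = 2,` end in a trailing comma, so at runtime the defaults are already one-element tuples; ruff only makes that explicit by rewriting them as `(1,)` and `(2,)`. A minimal illustration of the behaviour (not part of the commit):

    # A trailing comma on the right-hand side makes the value a tuple.
    # The `int` annotation is not enforced at runtime, so no error is raised.
    shard: int = 1,
    replicas: int = 2,

    print(shard, replicas)   # (1,) (2,)
    print(shard == (1,))     # True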
@@ -61,13 +56,10 @@ class TencentVector(BaseVector):
         return self._client.create_database(database_name=self._client_config.database)
 
     def get_type(self) -> str:
-        return 'tencent'
+        return "tencent"
 
     def to_index_struct(self) -> dict:
-        return {
-            "type": self.get_type(),
-            "vector_store": {"class_prefix": self._collection_name}
-        }
+        return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
 
     def _has_collection(self) -> bool:
         collections = self._db.list_collections()
@@ -77,9 +69,9 @@ class TencentVector(BaseVector):
         return False
 
     def _create_collection(self, dimension: int) -> None:
-        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
+        lock_name = "vector_indexing_lock_{}".format(self._collection_name)
         with redis_client.lock(lock_name, timeout=20):
-            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
+            collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
             if redis_client.get(collection_exist_cache_key):
                 return
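Note on the hunk above: apart from the quote-style change, `_create_collection` keeps its guard against concurrent index creation: a Redis lock around the critical section plus a cache key that marks the collection as already created. A rough sketch of that idiom with a plain redis-py client (names here are illustrative, not Dify's):

    import redis

    redis_client = redis.Redis()

    def create_collection_once(collection_name: str) -> None:
        lock_name = "vector_indexing_lock_{}".format(collection_name)
        with redis_client.lock(lock_name, timeout=20):
            cache_key = "vector_indexing_{}".format(collection_name)
            if redis_client.get(cache_key):
                return  # another worker already created the collection
            # ... create the collection here ...
            redis_client.set(cache_key, 1, ex=3600)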
@@ -101,9 +93,7 @@ class TencentVector(BaseVector):
                 raise ValueError("unsupported metric_type")
             params = vdb_index.HNSWParams(m=16, efconstruction=200)
             index = vdb_index.Index(
-                vdb_index.FilterIndex(
-                    self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY
-                ),
+                vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
                 vdb_index.VectorIndex(
                     self.field_vector,
                     dimension,
@@ -111,12 +101,8 @@ class TencentVector(BaseVector):
                     metric_type,
                     params,
                 ),
-                vdb_index.FilterIndex(
-                    self.field_text, enum.FieldType.String, enum.IndexType.FILTER
-                ),
-                vdb_index.FilterIndex(
-                    self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
-                ),
+                vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER),
+                vdb_index.FilterIndex(self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER),
             )
 
             self._db.create_collection(
@@ -163,15 +149,14 @@ class TencentVector(BaseVector):
         self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value])))
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-
-        res = self._db.collection(self._collection_name).search(vectors=[query_vector],
-                                                                 params=document.HNSWSearchParams(
-                                                                     ef=kwargs.get("ef", 10)),
-                                                                 retrieve_vector=False,
-                                                                 limit=kwargs.get('top_k', 4),
-                                                                 timeout=self._client_config.timeout,
-                                                                 )
-        score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
+        res = self._db.collection(self._collection_name).search(
+            vectors=[query_vector],
+            params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)),
+            retrieve_vector=False,
+            limit=kwargs.get("top_k", 4),
+            timeout=self._client_config.timeout,
+        )
+        score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
         return self._get_search_res(res, score_threshold)
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
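Note on the hunk above: besides reflowing the `search()` call, the change rewrites `.0` as `0.0` in the threshold line. The expression `kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0` collapses any missing or falsy value to 0.0, i.e. it behaves like `kwargs.get("score_threshold") or 0.0`. A quick check (not part of the commit):

    def threshold_as_written(**kwargs):
        return kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0

    def threshold_with_or(**kwargs):
        return kwargs.get("score_threshold") or 0.0

    for case in ({}, {"score_threshold": 0}, {"score_threshold": 0.35}):
        assert threshold_as_written(**case) == threshold_with_or(**case)
    print("all cases agree")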
@@ -200,15 +185,13 @@ class TencentVectorFactory(AbstractVectorFactory):
 
 class TencentVectorFactory(AbstractVectorFactory):
     def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TencentVector:
-
         if dataset.index_struct_dict:
-            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
+            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
             collection_name = class_prefix.lower()
         else:
             dataset_id = dataset.id
             collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
-            dataset.index_struct = json.dumps(
-                self.gen_index_struct_dict(VectorType.TENCENT, collection_name))
+            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TENCENT, collection_name))
 
         return TencentVector(
             collection_name=collection_name,
@@ -220,5 +203,5 @@ class TencentVectorFactory(AbstractVectorFactory):
                 database=dify_config.TENCENT_VECTOR_DB_DATABASE,
                 shard=dify_config.TENCENT_VECTOR_DB_SHARD,
                 replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS,
-            )
+            ),
         )