Update Oracle db connection library and change connection pool to single connection (#18466)
This commit is contained in:
@@ -2,12 +2,12 @@ import array
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import jieba.posseg as pseg # type: ignore
|
||||
import numpy
|
||||
import oracledb
|
||||
from oracledb.connection import Connection
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
@@ -70,6 +70,7 @@ class OracleVector(BaseVector):
|
||||
super().__init__(collection_name)
|
||||
self.pool = self._create_connection_pool(config)
|
||||
self.table_name = f"embedding_{collection_name}"
|
||||
self.config = config
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.ORACLE
|
||||
@@ -107,16 +108,19 @@ class OracleVector(BaseVector):
|
||||
outconverter=self.numpy_converter_out,
|
||||
)
|
||||
|
||||
def _get_connection(self) -> Connection:
|
||||
connection = oracledb.connect(user=self.config.user, password=self.config.password, dsn=self.config.dsn)
|
||||
return connection
|
||||
|
||||
def _create_connection_pool(self, config: OracleVectorConfig):
|
||||
pool_params = {
|
||||
"user": config.user,
|
||||
"password": config.password,
|
||||
"dsn": config.dsn,
|
||||
"min": 1,
|
||||
"max": 50,
|
||||
"max": 5,
|
||||
"increment": 1,
|
||||
}
|
||||
|
||||
if config.is_autonomous:
|
||||
pool_params.update(
|
||||
{
|
||||
@@ -125,22 +129,8 @@ class OracleVector(BaseVector):
|
||||
"wallet_password": config.wallet_password,
|
||||
}
|
||||
)
|
||||
|
||||
return oracledb.create_pool(**pool_params)
|
||||
|
||||
@contextmanager
|
||||
def _get_cursor(self):
|
||||
conn = self.pool.acquire()
|
||||
conn.inputtypehandler = self.input_type_handler
|
||||
conn.outputtypehandler = self.output_type_handler
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
yield cur
|
||||
finally:
|
||||
cur.close()
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection(dimension)
|
||||
@@ -162,41 +152,68 @@ class OracleVector(BaseVector):
|
||||
numpy.array(embeddings[i]),
|
||||
)
|
||||
)
|
||||
# print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
|
||||
with self._get_cursor() as cur:
|
||||
cur.executemany(
|
||||
f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values
|
||||
)
|
||||
with self._get_connection() as conn:
|
||||
conn.inputtypehandler = self.input_type_handler
|
||||
conn.outputtypehandler = self.output_type_handler
|
||||
# with conn.cursor() as cur:
|
||||
# cur.executemany(
|
||||
# f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values
|
||||
# )
|
||||
# conn.commit()
|
||||
for value in values:
|
||||
with conn.cursor() as cur:
|
||||
try:
|
||||
cur.execute(
|
||||
f"""INSERT INTO {self.table_name} (id, text, meta, embedding)
|
||||
VALUES (:1, :2, :3, :4)""",
|
||||
value,
|
||||
)
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
conn.close()
|
||||
return pks
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
|
||||
return cur.fetchone() is not None
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
|
||||
return cur.fetchone() is not None
|
||||
conn.close()
|
||||
|
||||
def get_by_ids(self, ids: list[str]) -> list[Document]:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
||||
docs = []
|
||||
for record in cur:
|
||||
docs.append(Document(page_content=record[1], metadata=record[0]))
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
||||
docs = []
|
||||
for record in cur:
|
||||
docs.append(Document(page_content=record[1], metadata=record[0]))
|
||||
self.pool.release(connection=conn)
|
||||
conn.close()
|
||||
return docs
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Search the nearest neighbors to a vector.
|
||||
|
||||
:param query_vector: The input vector to search for similar items.
|
||||
:param top_k: The number of nearest neighbors to return, default is 5.
|
||||
:return: List of Documents that are nearest to the query vector.
|
||||
"""
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
@@ -205,20 +222,25 @@ class OracleVector(BaseVector):
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"
|
||||
f" {where_clause} ORDER BY distance fetch first {top_k} rows only",
|
||||
[numpy.array(query_vector)],
|
||||
)
|
||||
docs = []
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
for record in cur:
|
||||
metadata, text, distance = record
|
||||
score = 1 - distance
|
||||
metadata["score"] = score
|
||||
if score > score_threshold:
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
with self._get_connection() as conn:
|
||||
conn.inputtypehandler = self.input_type_handler
|
||||
conn.outputtypehandler = self.output_type_handler
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine)
|
||||
AS distance FROM {self.table_name}
|
||||
{where_clause} ORDER BY distance fetch first {top_k} rows only""",
|
||||
[numpy.array(query_vector)],
|
||||
)
|
||||
docs = []
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
for record in cur:
|
||||
metadata, text, distance = record
|
||||
score = 1 - distance
|
||||
metadata["score"] = score
|
||||
if score > score_threshold:
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
conn.close()
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
@@ -228,7 +250,7 @@ class OracleVector(BaseVector):
|
||||
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
# just not implement fetch by score_threshold now, may be later
|
||||
# score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if len(query) > 0:
|
||||
# Check which language the query is in
|
||||
zh_pattern = re.compile("[\u4e00-\u9fa5]+")
|
||||
@@ -239,7 +261,7 @@ class OracleVector(BaseVector):
|
||||
words = pseg.cut(query)
|
||||
current_entity = ""
|
||||
for word, pos in words:
|
||||
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名,ns: 地名,nt: 机构名
|
||||
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名
|
||||
current_entity += word
|
||||
else:
|
||||
if current_entity:
|
||||
@@ -260,30 +282,35 @@ class OracleVector(BaseVector):
|
||||
for token in all_tokens:
|
||||
if token not in stop_words:
|
||||
entities.append(token)
|
||||
with self._get_cursor() as cur:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
|
||||
cur.execute(
|
||||
f"select meta, text, embedding FROM {self.table_name}"
|
||||
f"WHERE CONTAINS(text, :1, 1) > 0 {where_clause} "
|
||||
f"order by score(1) desc fetch first {top_k} rows only",
|
||||
[" ACCUM ".join(entities)],
|
||||
)
|
||||
docs = []
|
||||
for record in cur:
|
||||
metadata, text, embedding = record
|
||||
docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
|
||||
cur.execute(
|
||||
f"""select meta, text, embedding FROM {self.table_name}
|
||||
WHERE CONTAINS(text, :kk, 1) > 0 {where_clause}
|
||||
order by score(1) desc fetch first {top_k} rows only""",
|
||||
kk=" ACCUM ".join(entities),
|
||||
)
|
||||
docs = []
|
||||
for record in cur:
|
||||
metadata, text, embedding = record
|
||||
docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
|
||||
conn.close()
|
||||
return docs
|
||||
else:
|
||||
return [Document(page_content="", metadata={})]
|
||||
return []
|
||||
|
||||
def delete(self) -> None:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints")
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def _create_collection(self, dimension: int):
|
||||
cache_key = f"vector_indexing_{self._collection_name}"
|
||||
@@ -293,11 +320,14 @@ class OracleVector(BaseVector):
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
class OracleVectorFactory(AbstractVectorFactory):
|
||||
|
Reference in New Issue
Block a user