Update Oracle db connection library and change connection pool to single connection (#18466)

This commit is contained in:
tmuife
2025-04-21 17:56:57 +08:00
committed by GitHub
parent 30c051d485
commit 7b6523e54d
3 changed files with 117 additions and 87 deletions

View File

@@ -2,12 +2,12 @@ import array
import json
import re
import uuid
from contextlib import contextmanager
from typing import Any
import jieba.posseg as pseg # type: ignore
import numpy
import oracledb
from oracledb.connection import Connection
from pydantic import BaseModel, model_validator
from configs import dify_config
@@ -70,6 +70,7 @@ class OracleVector(BaseVector):
super().__init__(collection_name)
self.pool = self._create_connection_pool(config)
self.table_name = f"embedding_{collection_name}"
self.config = config
def get_type(self) -> str:
return VectorType.ORACLE
@@ -107,16 +108,19 @@ class OracleVector(BaseVector):
outconverter=self.numpy_converter_out,
)
def _get_connection(self) -> Connection:
connection = oracledb.connect(user=self.config.user, password=self.config.password, dsn=self.config.dsn)
return connection
def _create_connection_pool(self, config: OracleVectorConfig):
pool_params = {
"user": config.user,
"password": config.password,
"dsn": config.dsn,
"min": 1,
"max": 50,
"max": 5,
"increment": 1,
}
if config.is_autonomous:
pool_params.update(
{
@@ -125,22 +129,8 @@ class OracleVector(BaseVector):
"wallet_password": config.wallet_password,
}
)
return oracledb.create_pool(**pool_params)
@contextmanager
def _get_cursor(self):
conn = self.pool.acquire()
conn.inputtypehandler = self.input_type_handler
conn.outputtypehandler = self.output_type_handler
cur = conn.cursor()
try:
yield cur
finally:
cur.close()
conn.commit()
conn.close()
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
@@ -162,41 +152,68 @@ class OracleVector(BaseVector):
numpy.array(embeddings[i]),
)
)
# print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
with self._get_cursor() as cur:
cur.executemany(
f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values
)
with self._get_connection() as conn:
conn.inputtypehandler = self.input_type_handler
conn.outputtypehandler = self.output_type_handler
# with conn.cursor() as cur:
# cur.executemany(
# f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values
# )
# conn.commit()
for value in values:
with conn.cursor() as cur:
try:
cur.execute(
f"""INSERT INTO {self.table_name} (id, text, meta, embedding)
VALUES (:1, :2, :3, :4)""",
value,
)
conn.commit()
except Exception as e:
print(e)
conn.close()
return pks
def text_exists(self, id: str) -> bool:
with self._get_cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
return cur.fetchone() is not None
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
return cur.fetchone() is not None
conn.close()
def get_by_ids(self, ids: list[str]) -> list[Document]:
with self._get_cursor() as cur:
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
docs = []
for record in cur:
docs.append(Document(page_content=record[1], metadata=record[0]))
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
docs = []
for record in cur:
docs.append(Document(page_content=record[1], metadata=record[0]))
self.pool.release(connection=conn)
conn.close()
return docs
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
conn.commit()
conn.close()
def delete_by_metadata_field(self, key: str, value: str) -> None:
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
conn.commit()
conn.close()
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""
Search the nearest neighbors to a vector.
:param query_vector: The input vector to search for similar items.
:param top_k: The number of nearest neighbors to return, default is 5.
:return: List of Documents that are nearest to the query vector.
"""
top_k = kwargs.get("top_k", 4)
@@ -205,20 +222,25 @@ class OracleVector(BaseVector):
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
with self._get_cursor() as cur:
cur.execute(
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"
f" {where_clause} ORDER BY distance fetch first {top_k} rows only",
[numpy.array(query_vector)],
)
docs = []
score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in cur:
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
with self._get_connection() as conn:
conn.inputtypehandler = self.input_type_handler
conn.outputtypehandler = self.output_type_handler
with conn.cursor() as cur:
cur.execute(
f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine)
AS distance FROM {self.table_name}
{where_clause} ORDER BY distance fetch first {top_k} rows only""",
[numpy.array(query_vector)],
)
docs = []
score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in cur:
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
conn.close()
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -228,7 +250,7 @@ class OracleVector(BaseVector):
top_k = kwargs.get("top_k", 5)
# just not implement fetch by score_threshold now, may be later
# score_threshold = float(kwargs.get("score_threshold") or 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if len(query) > 0:
# Check which language the query is in
zh_pattern = re.compile("[\u4e00-\u9fa5]+")
@@ -239,7 +261,7 @@ class OracleVector(BaseVector):
words = pseg.cut(query)
current_entity = ""
for word, pos in words:
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名ns: 地名nt: 机构名
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名
current_entity += word
else:
if current_entity:
@@ -260,30 +282,35 @@ class OracleVector(BaseVector):
for token in all_tokens:
if token not in stop_words:
entities.append(token)
with self._get_cursor() as cur:
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
cur.execute(
f"select meta, text, embedding FROM {self.table_name}"
f"WHERE CONTAINS(text, :1, 1) > 0 {where_clause} "
f"order by score(1) desc fetch first {top_k} rows only",
[" ACCUM ".join(entities)],
)
docs = []
for record in cur:
metadata, text, embedding = record
docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
with self._get_connection() as conn:
with conn.cursor() as cur:
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
cur.execute(
f"""select meta, text, embedding FROM {self.table_name}
WHERE CONTAINS(text, :kk, 1) > 0 {where_clause}
order by score(1) desc fetch first {top_k} rows only""",
kk=" ACCUM ".join(entities),
)
docs = []
for record in cur:
metadata, text, embedding = record
docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
conn.close()
return docs
else:
return [Document(page_content="", metadata={})]
return []
def delete(self) -> None:
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints")
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints")
conn.commit()
conn.close()
def _create_collection(self, dimension: int):
cache_key = f"vector_indexing_{self._collection_name}"
@@ -293,11 +320,14 @@ class OracleVector(BaseVector):
if redis_client.get(collection_exist_cache_key):
return
with self._get_cursor() as cur:
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
redis_client.set(collection_exist_cache_key, 1, ex=3600)
with self._get_cursor() as cur:
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
redis_client.set(collection_exist_cache_key, 1, ex=3600)
with conn.cursor() as cur:
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
conn.commit()
conn.close()
class OracleVectorFactory(AbstractVectorFactory):