Merge commit from fork
* fix(oraclevector): SQL Injection Signed-off-by: -LAN- <laipz8200@outlook.com> * fix(oraclevector): Remove bind variables from FETCH FIRST clause Oracle doesn't support bind variables in the FETCH FIRST clause. Fixed by using validated integers directly in the SQL string while maintaining proper input validation to prevent SQL injection. - Updated search_by_vector method to use validated top_k directly - Updated search_by_full_text method to use validated top_k directly - Adjusted parameter numbering for document_ids_filter placeholders 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> --------- Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -188,14 +188,17 @@ class OracleVector(BaseVector):
|
|||||||
def text_exists(self, id: str) -> bool:
|
def text_exists(self, id: str) -> bool:
|
||||||
with self._get_connection() as conn:
|
with self._get_connection() as conn:
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
|
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = :1", (id,))
|
||||||
return cur.fetchone() is not None
|
return cur.fetchone() is not None
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def get_by_ids(self, ids: list[str]) -> list[Document]:
|
def get_by_ids(self, ids: list[str]) -> list[Document]:
|
||||||
|
if not ids:
|
||||||
|
return []
|
||||||
with self._get_connection() as conn:
|
with self._get_connection() as conn:
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
placeholders = ", ".join(f":{i + 1}" for i in range(len(ids)))
|
||||||
|
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids)
|
||||||
docs = []
|
docs = []
|
||||||
for record in cur:
|
for record in cur:
|
||||||
docs.append(Document(page_content=record[1], metadata=record[0]))
|
docs.append(Document(page_content=record[1], metadata=record[0]))
|
||||||
@@ -208,14 +211,15 @@ class OracleVector(BaseVector):
|
|||||||
return
|
return
|
||||||
with self._get_connection() as conn:
|
with self._get_connection() as conn:
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
|
placeholders = ", ".join(f":{i + 1}" for i in range(len(ids)))
|
||||||
|
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})", ids)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||||
with self._get_connection() as conn:
|
with self._get_connection() as conn:
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
|
cur.execute(f"DELETE FROM {self.table_name} WHERE JSON_VALUE(meta, '$." + key + "') = :1", (value,))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
@@ -227,12 +231,20 @@ class OracleVector(BaseVector):
|
|||||||
:param top_k: The number of nearest neighbors to return, default is 5.
|
:param top_k: The number of nearest neighbors to return, default is 5.
|
||||||
:return: List of Documents that are nearest to the query vector.
|
:return: List of Documents that are nearest to the query vector.
|
||||||
"""
|
"""
|
||||||
|
# Validate and sanitize top_k to prevent SQL injection
|
||||||
top_k = kwargs.get("top_k", 4)
|
top_k = kwargs.get("top_k", 4)
|
||||||
|
if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000:
|
||||||
|
top_k = 4 # Use default if invalid
|
||||||
|
|
||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
where_clause = ""
|
where_clause = ""
|
||||||
|
params = [numpy.array(query_vector)]
|
||||||
|
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
placeholders = ", ".join(f":{i + 2}" for i in range(len(document_ids_filter)))
|
||||||
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
|
where_clause = f"WHERE JSON_VALUE(meta, '$.document_id') IN ({placeholders})"
|
||||||
|
params.extend(document_ids_filter)
|
||||||
|
|
||||||
with self._get_connection() as conn:
|
with self._get_connection() as conn:
|
||||||
conn.inputtypehandler = self.input_type_handler
|
conn.inputtypehandler = self.input_type_handler
|
||||||
conn.outputtypehandler = self.output_type_handler
|
conn.outputtypehandler = self.output_type_handler
|
||||||
@@ -241,7 +253,7 @@ class OracleVector(BaseVector):
|
|||||||
f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine)
|
f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine)
|
||||||
AS distance FROM {self.table_name}
|
AS distance FROM {self.table_name}
|
||||||
{where_clause} ORDER BY distance fetch first {top_k} rows only""",
|
{where_clause} ORDER BY distance fetch first {top_k} rows only""",
|
||||||
[numpy.array(query_vector)],
|
params,
|
||||||
)
|
)
|
||||||
docs = []
|
docs = []
|
||||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
@@ -259,7 +271,10 @@ class OracleVector(BaseVector):
|
|||||||
import nltk # type: ignore
|
import nltk # type: ignore
|
||||||
from nltk.corpus import stopwords # type: ignore
|
from nltk.corpus import stopwords # type: ignore
|
||||||
|
|
||||||
|
# Validate and sanitize top_k to prevent SQL injection
|
||||||
top_k = kwargs.get("top_k", 5)
|
top_k = kwargs.get("top_k", 5)
|
||||||
|
if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000:
|
||||||
|
top_k = 5 # Use default if invalid
|
||||||
# just not implement fetch by score_threshold now, may be later
|
# just not implement fetch by score_threshold now, may be later
|
||||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
if len(query) > 0:
|
if len(query) > 0:
|
||||||
@@ -297,14 +312,21 @@ class OracleVector(BaseVector):
|
|||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
where_clause = ""
|
where_clause = ""
|
||||||
|
params: dict[str, Any] = {"kk": " ACCUM ".join(entities)}
|
||||||
|
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
placeholders = []
|
||||||
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
|
for i, doc_id in enumerate(document_ids_filter):
|
||||||
|
param_name = f"doc_id_{i}"
|
||||||
|
placeholders.append(f":{param_name}")
|
||||||
|
params[param_name] = doc_id
|
||||||
|
where_clause = f" AND JSON_VALUE(meta, '$.document_id') IN ({', '.join(placeholders)}) "
|
||||||
|
|
||||||
cur.execute(
|
cur.execute(
|
||||||
f"""select meta, text, embedding FROM {self.table_name}
|
f"""select meta, text, embedding FROM {self.table_name}
|
||||||
WHERE CONTAINS(text, :kk, 1) > 0 {where_clause}
|
WHERE CONTAINS(text, :kk, 1) > 0 {where_clause}
|
||||||
order by score(1) desc fetch first {top_k} rows only""",
|
order by score(1) desc fetch first {top_k} rows only""",
|
||||||
kk=" ACCUM ".join(entities),
|
params,
|
||||||
)
|
)
|
||||||
docs = []
|
docs = []
|
||||||
for record in cur:
|
for record in cur:
|
||||||
|
Reference in New Issue
Block a user