Merge commit from fork

* fix(oraclevector): SQL Injection

Signed-off-by: -LAN- <laipz8200@outlook.com>

* fix(oraclevector): Remove bind variables from FETCH FIRST clause

Oracle doesn't support bind variables in the FETCH FIRST clause.
Fixed by using validated integers directly in the SQL string while
maintaining proper input validation to prevent SQL injection.

- Updated search_by_vector method to use validated top_k directly
- Updated search_by_full_text method to use validated top_k directly
- Adjusted parameter numbering for document_ids_filter placeholders

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
-LAN-
2025-08-26 13:51:23 +08:00
committed by GitHub
parent eb3a031964
commit 04954918a5

View File

@@ -188,14 +188,17 @@ class OracleVector(BaseVector):
def text_exists(self, id: str) -> bool: def text_exists(self, id: str) -> bool:
with self._get_connection() as conn: with self._get_connection() as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,)) cur.execute(f"SELECT id FROM {self.table_name} WHERE id = :1", (id,))
return cur.fetchone() is not None return cur.fetchone() is not None
conn.close() conn.close()
def get_by_ids(self, ids: list[str]) -> list[Document]: def get_by_ids(self, ids: list[str]) -> list[Document]:
if not ids:
return []
with self._get_connection() as conn: with self._get_connection() as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) placeholders = ", ".join(f":{i + 1}" for i in range(len(ids)))
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids)
docs = [] docs = []
for record in cur: for record in cur:
docs.append(Document(page_content=record[1], metadata=record[0])) docs.append(Document(page_content=record[1], metadata=record[0]))
@@ -208,14 +211,15 @@ class OracleVector(BaseVector):
return return
with self._get_connection() as conn: with self._get_connection() as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),)) placeholders = ", ".join(f":{i + 1}" for i in range(len(ids)))
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})", ids)
conn.commit() conn.commit()
conn.close() conn.close()
def delete_by_metadata_field(self, key: str, value: str) -> None: def delete_by_metadata_field(self, key: str, value: str) -> None:
with self._get_connection() as conn: with self._get_connection() as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value)) cur.execute(f"DELETE FROM {self.table_name} WHERE JSON_VALUE(meta, '$." + key + "') = :1", (value,))
conn.commit() conn.commit()
conn.close() conn.close()
@@ -227,12 +231,20 @@ class OracleVector(BaseVector):
:param top_k: The number of nearest neighbors to return, default is 5. :param top_k: The number of nearest neighbors to return, default is 5.
:return: List of Documents that are nearest to the query vector. :return: List of Documents that are nearest to the query vector.
""" """
# Validate and sanitize top_k to prevent SQL injection
top_k = kwargs.get("top_k", 4) top_k = kwargs.get("top_k", 4)
if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000:
top_k = 4 # Use default if invalid
document_ids_filter = kwargs.get("document_ids_filter") document_ids_filter = kwargs.get("document_ids_filter")
where_clause = "" where_clause = ""
params = [numpy.array(query_vector)]
if document_ids_filter: if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) placeholders = ", ".join(f":{i + 2}" for i in range(len(document_ids_filter)))
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})" where_clause = f"WHERE JSON_VALUE(meta, '$.document_id') IN ({placeholders})"
params.extend(document_ids_filter)
with self._get_connection() as conn: with self._get_connection() as conn:
conn.inputtypehandler = self.input_type_handler conn.inputtypehandler = self.input_type_handler
conn.outputtypehandler = self.output_type_handler conn.outputtypehandler = self.output_type_handler
@@ -241,7 +253,7 @@ class OracleVector(BaseVector):
f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine) f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine)
AS distance FROM {self.table_name} AS distance FROM {self.table_name}
{where_clause} ORDER BY distance fetch first {top_k} rows only""", {where_clause} ORDER BY distance fetch first {top_k} rows only""",
[numpy.array(query_vector)], params,
) )
docs = [] docs = []
score_threshold = float(kwargs.get("score_threshold") or 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
@@ -259,7 +271,10 @@ class OracleVector(BaseVector):
import nltk # type: ignore import nltk # type: ignore
from nltk.corpus import stopwords # type: ignore from nltk.corpus import stopwords # type: ignore
# Validate and sanitize top_k to prevent SQL injection
top_k = kwargs.get("top_k", 5) top_k = kwargs.get("top_k", 5)
if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000:
top_k = 5 # Use default if invalid
# just not implement fetch by score_threshold now, may be later # just not implement fetch by score_threshold now, may be later
score_threshold = float(kwargs.get("score_threshold") or 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
if len(query) > 0: if len(query) > 0:
@@ -297,14 +312,21 @@ class OracleVector(BaseVector):
with conn.cursor() as cur: with conn.cursor() as cur:
document_ids_filter = kwargs.get("document_ids_filter") document_ids_filter = kwargs.get("document_ids_filter")
where_clause = "" where_clause = ""
params: dict[str, Any] = {"kk": " ACCUM ".join(entities)}
if document_ids_filter: if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) placeholders = []
where_clause = f" AND metadata->>'document_id' in ({document_ids}) " for i, doc_id in enumerate(document_ids_filter):
param_name = f"doc_id_{i}"
placeholders.append(f":{param_name}")
params[param_name] = doc_id
where_clause = f" AND JSON_VALUE(meta, '$.document_id') IN ({', '.join(placeholders)}) "
cur.execute( cur.execute(
f"""select meta, text, embedding FROM {self.table_name} f"""select meta, text, embedding FROM {self.table_name}
WHERE CONTAINS(text, :kk, 1) > 0 {where_clause} WHERE CONTAINS(text, :kk, 1) > 0 {where_clause}
order by score(1) desc fetch first {top_k} rows only""", order by score(1) desc fetch first {top_k} rows only""",
kk=" ACCUM ".join(entities), params,
) )
docs = [] docs = []
for record in cur: for record in cur: