Merge commit from fork

* fix(oraclevector): SQL Injection

Signed-off-by: -LAN- <laipz8200@outlook.com>

* fix(oraclevector): Remove bind variables from FETCH FIRST clause

Oracle doesn't support bind variables in the FETCH FIRST clause.
Fixed by using validated integers directly in the SQL string while
maintaining proper input validation to prevent SQL injection.

- Updated search_by_vector method to use validated top_k directly
- Updated search_by_full_text method to use validated top_k directly
- Adjusted parameter numbering for document_ids_filter placeholders

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
-LAN-
2025-08-26 13:51:23 +08:00
committed by GitHub
parent eb3a031964
commit 04954918a5

View File

@@ -188,14 +188,17 @@ class OracleVector(BaseVector):
def text_exists(self, id: str) -> bool:
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = :1", (id,))
return cur.fetchone() is not None
conn.close()
def get_by_ids(self, ids: list[str]) -> list[Document]:
if not ids:
return []
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
placeholders = ", ".join(f":{i + 1}" for i in range(len(ids)))
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids)
docs = []
for record in cur:
docs.append(Document(page_content=record[1], metadata=record[0]))
@@ -208,14 +211,15 @@ class OracleVector(BaseVector):
return
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
placeholders = ", ".join(f":{i + 1}" for i in range(len(ids)))
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})", ids)
conn.commit()
conn.close()
def delete_by_metadata_field(self, key: str, value: str) -> None:
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
cur.execute(f"DELETE FROM {self.table_name} WHERE JSON_VALUE(meta, '$." + key + "') = :1", (value,))
conn.commit()
conn.close()
@@ -227,12 +231,20 @@ class OracleVector(BaseVector):
:param top_k: The number of nearest neighbors to return, default is 5.
:return: List of Documents that are nearest to the query vector.
"""
# Validate and sanitize top_k to prevent SQL injection
top_k = kwargs.get("top_k", 4)
if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000:
top_k = 4 # Use default if invalid
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
params = [numpy.array(query_vector)]
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
placeholders = ", ".join(f":{i + 2}" for i in range(len(document_ids_filter)))
where_clause = f"WHERE JSON_VALUE(meta, '$.document_id') IN ({placeholders})"
params.extend(document_ids_filter)
with self._get_connection() as conn:
conn.inputtypehandler = self.input_type_handler
conn.outputtypehandler = self.output_type_handler
@@ -241,7 +253,7 @@ class OracleVector(BaseVector):
f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine)
AS distance FROM {self.table_name}
{where_clause} ORDER BY distance fetch first {top_k} rows only""",
[numpy.array(query_vector)],
params,
)
docs = []
score_threshold = float(kwargs.get("score_threshold") or 0.0)
@@ -259,7 +271,10 @@ class OracleVector(BaseVector):
import nltk # type: ignore
from nltk.corpus import stopwords # type: ignore
# Validate and sanitize top_k to prevent SQL injection
top_k = kwargs.get("top_k", 5)
if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000:
top_k = 5 # Use default if invalid
# just not implement fetch by score_threshold now, may be later
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if len(query) > 0:
@@ -297,14 +312,21 @@ class OracleVector(BaseVector):
with conn.cursor() as cur:
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
params: dict[str, Any] = {"kk": " ACCUM ".join(entities)}
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
placeholders = []
for i, doc_id in enumerate(document_ids_filter):
param_name = f"doc_id_{i}"
placeholders.append(f":{param_name}")
params[param_name] = doc_id
where_clause = f" AND JSON_VALUE(meta, '$.document_id') IN ({', '.join(placeholders)}) "
cur.execute(
f"""select meta, text, embedding FROM {self.table_name}
WHERE CONTAINS(text, :kk, 1) > 0 {where_clause}
order by score(1) desc fetch first {top_k} rows only""",
kk=" ACCUM ".join(entities),
params,
)
docs = []
for record in cur: