From 04954918a5dc0075e1150a81a06b6fad91b5a592 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 26 Aug 2025 13:51:23 +0800 Subject: [PATCH] Merge commit from fork MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(oraclevector): SQL Injection Signed-off-by: -LAN- * fix(oraclevector): Remove bind variables from FETCH FIRST clause Oracle doesn't support bind variables in the FETCH FIRST clause. Fixed by using validated integers directly in the SQL string while maintaining proper input validation to prevent SQL injection. - Updated search_by_vector method to use validated top_k directly - Updated search_by_full_text method to use validated top_k directly - Adjusted parameter numbering for document_ids_filter placeholders 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --------- Signed-off-by: -LAN- Co-authored-by: Claude --- .../rag/datasource/vdb/oracle/oraclevector.py | 42 ++++++++++++++----- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 303c3fe31..d66829837 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -188,14 +188,17 @@ class OracleVector(BaseVector): def text_exists(self, id: str) -> bool: with self._get_connection() as conn: with conn.cursor() as cur: - cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,)) + cur.execute(f"SELECT id FROM {self.table_name} WHERE id = :1", (id,)) return cur.fetchone() is not None conn.close() def get_by_ids(self, ids: list[str]) -> list[Document]: + if not ids: + return [] with self._get_connection() as conn: with conn.cursor() as cur: - cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) + placeholders = ", ".join(f":{i + 1}" for i in range(len(ids))) + cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids) docs = [] for record in cur: docs.append(Document(page_content=record[1], metadata=record[0])) @@ -208,14 +211,15 @@ class OracleVector(BaseVector): return with self._get_connection() as conn: with conn.cursor() as cur: - cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),)) + placeholders = ", ".join(f":{i + 1}" for i in range(len(ids))) + cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})", ids) conn.commit() conn.close() def delete_by_metadata_field(self, key: str, value: str) -> None: with self._get_connection() as conn: with conn.cursor() as cur: - cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value)) + cur.execute(f"DELETE FROM {self.table_name} WHERE JSON_VALUE(meta, '$." + key + "') = :1", (value,)) conn.commit() conn.close() @@ -227,12 +231,20 @@ class OracleVector(BaseVector): :param top_k: The number of nearest neighbors to return, default is 5. :return: List of Documents that are nearest to the query vector. """ + # Validate and sanitize top_k to prevent SQL injection top_k = kwargs.get("top_k", 4) + if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000: + top_k = 4 # Use default if invalid + document_ids_filter = kwargs.get("document_ids_filter") where_clause = "" + params = [numpy.array(query_vector)] + if document_ids_filter: - document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) - where_clause = f"WHERE metadata->>'document_id' in ({document_ids})" + placeholders = ", ".join(f":{i + 2}" for i in range(len(document_ids_filter))) + where_clause = f"WHERE JSON_VALUE(meta, '$.document_id') IN ({placeholders})" + params.extend(document_ids_filter) + with self._get_connection() as conn: conn.inputtypehandler = self.input_type_handler conn.outputtypehandler = self.output_type_handler @@ -241,7 +253,7 @@ class OracleVector(BaseVector): f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine) AS distance FROM {self.table_name} {where_clause} ORDER BY distance fetch first {top_k} rows only""", - [numpy.array(query_vector)], + params, ) docs = [] score_threshold = float(kwargs.get("score_threshold") or 0.0) @@ -259,7 +271,10 @@ class OracleVector(BaseVector): import nltk # type: ignore from nltk.corpus import stopwords # type: ignore + # Validate and sanitize top_k to prevent SQL injection top_k = kwargs.get("top_k", 5) + if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000: + top_k = 5 # Use default if invalid # just not implement fetch by score_threshold now, may be later score_threshold = float(kwargs.get("score_threshold") or 0.0) if len(query) > 0: @@ -297,14 +312,21 @@ class OracleVector(BaseVector): with conn.cursor() as cur: document_ids_filter = kwargs.get("document_ids_filter") where_clause = "" + params: dict[str, Any] = {"kk": " ACCUM ".join(entities)} + if document_ids_filter: - document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) - where_clause = f" AND metadata->>'document_id' in ({document_ids}) " + placeholders = [] + for i, doc_id in enumerate(document_ids_filter): + param_name = f"doc_id_{i}" + placeholders.append(f":{param_name}") + params[param_name] = doc_id + where_clause = f" AND JSON_VALUE(meta, '$.document_id') IN ({', '.join(placeholders)}) " + cur.execute( f"""select meta, text, embedding FROM {self.table_name} WHERE CONTAINS(text, :kk, 1) > 0 {where_clause} order by score(1) desc fetch first {top_k} rows only""", - kk=" ACCUM ".join(entities), + params, ) docs = [] for record in cur: