diff --git a/api/core/rag/datasource/vdb/clickzetta/README.md b/api/core/rag/datasource/vdb/clickzetta/README.md
index 40229f8d4..2ee3e657d 100644
--- a/api/core/rag/datasource/vdb/clickzetta/README.md
+++ b/api/core/rag/datasource/vdb/clickzetta/README.md
@@ -185,6 +185,6 @@ Clickzetta supports advanced full-text search with multiple analyzers:
 
 ## References
 
-- [Clickzetta Vector Search Documentation](../../../../../../../yunqidoc/cn_markdown_20250526/vector-search.md)
-- [Clickzetta Inverted Index Documentation](../../../../../../../yunqidoc/cn_markdown_20250526/inverted-index.md)
-- [Clickzetta SQL Functions](../../../../../../../yunqidoc/cn_markdown_20250526/sql_functions/)
+- [Clickzetta Vector Search Documentation](https://yunqi.tech/documents/vector-search)
+- [Clickzetta Inverted Index Documentation](https://yunqi.tech/documents/inverted-index)
+- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference)
diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
index 50a395a37..1059b855a 100644
--- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
+++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
@@ -1,7 +1,9 @@
 import json
 import logging
 import queue
+import re
 import threading
+import time
 import uuid
 from typing import TYPE_CHECKING, Any, Optional
 
@@ -67,6 +69,243 @@ class ClickzettaConfig(BaseModel):
         return values
 
 
+class ClickzettaConnectionPool:
+    """
+    Global connection pool for ClickZetta connections.
+    Manages connection reuse across ClickzettaVector instances.
+    """
+
+    _instance: Optional["ClickzettaConnectionPool"] = None
+    _lock = threading.Lock()
+
+    def __init__(self):
+        self._pools: dict[str, list[tuple[Connection, float]]] = {}  # config_key -> [(connection, last_used_time)]
+        self._pool_locks: dict[str, threading.Lock] = {}
+        self._max_pool_size = 5  # Maximum connections per configuration
+        self._connection_timeout = 300  # 5 minutes timeout
+        self._cleanup_thread: Optional[threading.Thread] = None
+        self._shutdown = False
+        self._start_cleanup_thread()
+
+    @classmethod
+    def get_instance(cls) -> "ClickzettaConnectionPool":
+        """Get singleton instance of connection pool."""
+        if cls._instance is None:
+            with cls._lock:
+                if cls._instance is None:
+                    cls._instance = cls()
+        return cls._instance
+
+    def _get_config_key(self, config: ClickzettaConfig) -> str:
+        """Generate unique key for connection configuration."""
+        return (
+            f"{config.username}:{config.instance}:{config.service}:"
+            f"{config.workspace}:{config.vcluster}:{config.schema_name}"
+        )
+
+    def _create_connection(self, config: ClickzettaConfig) -> "Connection":
+        """Create a new ClickZetta connection."""
+        max_retries = 3
+        retry_delay = 1.0
+
+        for attempt in range(max_retries):
+            try:
+                connection = clickzetta.connect(
+                    username=config.username,
+                    password=config.password,
+                    instance=config.instance,
+                    service=config.service,
+                    workspace=config.workspace,
+                    vcluster=config.vcluster,
+                    schema=config.schema_name,
+                )
+
+                # Configure connection session settings
+                self._configure_connection(connection)
+                logger.debug("Created new ClickZetta connection (attempt %d/%d)", attempt + 1, max_retries)
+                return connection
+            except Exception:
+                logger.exception("ClickZetta connection attempt %d/%d failed", attempt + 1, max_retries)
+                if attempt < max_retries - 1:
+                    time.sleep(retry_delay * (2**attempt))
+                else:
+                    raise
+
+        raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts")
attempts") + + def _configure_connection(self, connection: "Connection") -> None: + """Configure connection session settings.""" + try: + with connection.cursor() as cursor: + # Temporarily suppress ClickZetta client logging to reduce noise + clickzetta_logger = logging.getLogger("clickzetta") + original_level = clickzetta_logger.level + clickzetta_logger.setLevel(logging.WARNING) + + try: + # Use quote mode for string literal escaping + cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'") + + # Apply performance optimization hints + performance_hints = [ + # Vector index optimization + "SET cz.storage.parquet.vector.index.read.memory.cache = true", + "SET cz.storage.parquet.vector.index.read.local.cache = false", + # Query optimization + "SET cz.sql.table.scan.push.down.filter = true", + "SET cz.sql.table.scan.enable.ensure.filter = true", + "SET cz.storage.always.prefetch.internal = true", + "SET cz.optimizer.generate.columns.always.valid = true", + "SET cz.sql.index.prewhere.enabled = true", + # Storage optimization + "SET cz.storage.parquet.enable.io.prefetch = false", + "SET cz.optimizer.enable.mv.rewrite = false", + "SET cz.sql.dump.as.lz4 = true", + "SET cz.optimizer.limited.optimization.naive.query = true", + "SET cz.sql.table.scan.enable.push.down.log = false", + "SET cz.storage.use.file.format.local.stats = false", + "SET cz.storage.local.file.object.cache.level = all", + # Job execution optimization + "SET cz.sql.job.fast.mode = true", + "SET cz.storage.parquet.non.contiguous.read = true", + "SET cz.sql.compaction.after.commit = true", + ] + + for hint in performance_hints: + cursor.execute(hint) + finally: + # Restore original logging level + clickzetta_logger.setLevel(original_level) + + except Exception: + logger.exception("Failed to configure connection, continuing with defaults") + + def _is_connection_valid(self, connection: "Connection") -> bool: + """Check if connection is still valid.""" + try: + with connection.cursor() as cursor: + cursor.execute("SELECT 1") + return True + except Exception: + return False + + def get_connection(self, config: ClickzettaConfig) -> "Connection": + """Get a connection from the pool or create a new one.""" + config_key = self._get_config_key(config) + + # Ensure pool lock exists + if config_key not in self._pool_locks: + with self._lock: + if config_key not in self._pool_locks: + self._pool_locks[config_key] = threading.Lock() + self._pools[config_key] = [] + + with self._pool_locks[config_key]: + pool = self._pools[config_key] + current_time = time.time() + + # Try to reuse existing connection + while pool: + connection, last_used = pool.pop(0) + + # Check if connection is not expired and still valid + if current_time - last_used < self._connection_timeout and self._is_connection_valid(connection): + logger.debug("Reusing ClickZetta connection from pool") + return connection + else: + # Connection expired or invalid, close it + try: + connection.close() + except Exception: + pass + + # No valid connection found, create new one + return self._create_connection(config) + + def return_connection(self, config: ClickzettaConfig, connection: "Connection") -> None: + """Return a connection to the pool.""" + config_key = self._get_config_key(config) + + if config_key not in self._pool_locks: + # Pool was cleaned up, just close the connection + try: + connection.close() + except Exception: + pass + return + + with self._pool_locks[config_key]: + pool = self._pools[config_key] + + # Only return to pool if not at capacity and 
+            if len(pool) < self._max_pool_size and self._is_connection_valid(connection):
+                pool.append((connection, time.time()))
+                logger.debug("Returned ClickZetta connection to pool")
+            else:
+                # Pool full or connection invalid, close it
+                try:
+                    connection.close()
+                except Exception:
+                    pass
+
+    def _cleanup_expired_connections(self) -> None:
+        """Clean up expired connections from all pools."""
+        current_time = time.time()
+
+        with self._lock:
+            for config_key in list(self._pools.keys()):
+                if config_key not in self._pool_locks:
+                    continue
+
+                with self._pool_locks[config_key]:
+                    pool = self._pools[config_key]
+                    valid_connections = []
+
+                    for connection, last_used in pool:
+                        if current_time - last_used < self._connection_timeout:
+                            valid_connections.append((connection, last_used))
+                        else:
+                            try:
+                                connection.close()
+                            except Exception:
+                                pass
+
+                    self._pools[config_key] = valid_connections
+
+    def _start_cleanup_thread(self) -> None:
+        """Start background thread for connection cleanup."""
+
+        def cleanup_worker():
+            while not self._shutdown:
+                try:
+                    time.sleep(60)  # Cleanup every minute
+                    if not self._shutdown:
+                        self._cleanup_expired_connections()
+                except Exception:
+                    logger.exception("Error in connection pool cleanup")
+
+        self._cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True)
+        self._cleanup_thread.start()
+
+    def shutdown(self) -> None:
+        """Shutdown connection pool and close all connections."""
+        self._shutdown = True
+
+        with self._lock:
+            for config_key in list(self._pools.keys()):
+                if config_key not in self._pool_locks:
+                    continue
+
+                with self._pool_locks[config_key]:
+                    pool = self._pools[config_key]
+                    for connection, _ in pool:
+                        try:
+                            connection.close()
+                        except Exception:
+                            pass
+                    pool.clear()
+
+
 class ClickzettaVector(BaseVector):
     """
     Clickzetta vector storage implementation.
@@ -82,70 +321,74 @@ class ClickzettaVector(BaseVector):
         super().__init__(collection_name)
         self._config = config
         self._table_name = collection_name.replace("-", "_").lower()  # Ensure valid table name
-        self._connection: Optional[Connection] = None
-        self._init_connection()
+        self._connection_pool = ClickzettaConnectionPool.get_instance()
         self._init_write_queue()
 
-    def _init_connection(self):
-        """Initialize Clickzetta connection."""
-        self._connection = clickzetta.connect(
-            username=self._config.username,
-            password=self._config.password,
-            instance=self._config.instance,
-            service=self._config.service,
-            workspace=self._config.workspace,
-            vcluster=self._config.vcluster,
-            schema=self._config.schema_name,
-        )
+    def _get_connection(self) -> "Connection":
+        """Get a connection from the pool."""
+        return self._connection_pool.get_connection(self._config)
 
-        # Set session parameters for better string handling and performance optimization
-        if self._connection is not None:
-            with self._connection.cursor() as cursor:
-                # Use quote mode for string literal escaping to handle quotes better
-                cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'")
-                logger.info("Set string literal escape mode to 'quote' for better quote handling")
+    def _return_connection(self, connection: "Connection") -> None:
+        """Return a connection to the pool."""
+        self._connection_pool.return_connection(self._config, connection)
 
-                # Performance optimization hints for vector operations
-                self._set_performance_hints(cursor)
+    class ConnectionContext:
+        """Context manager for borrowing and returning connections."""
 
-    def _set_performance_hints(self, cursor):
-        """Set ClickZetta performance optimization hints for vector operations."""
+        def __init__(self, vector_instance: "ClickzettaVector"):
+            self.vector = vector_instance
+            self.connection: Optional[Connection] = None
+
+        def __enter__(self) -> "Connection":
+            self.connection = self.vector._get_connection()
+            return self.connection
+
+        def __exit__(self, exc_type, exc_val, exc_tb):
+            if self.connection:
+                self.vector._return_connection(self.connection)
+
+    def get_connection_context(self) -> "ClickzettaVector.ConnectionContext":
+        """Get a connection context manager."""
+        return self.ConnectionContext(self)
+
+    def _parse_metadata(self, raw_metadata: str, row_id: str) -> dict:
+        """
+        Parse metadata from JSON string with proper error handling and fallback.
+
+        Args:
+            raw_metadata: Raw JSON string from database
+            row_id: Row ID for fallback document_id
+
+        Returns:
+            Parsed metadata dict with guaranteed required fields
+        """
         try:
-            # Performance optimization hints for vector operations and query processing
-            performance_hints = [
-                # Vector index optimization
-                "SET cz.storage.parquet.vector.index.read.memory.cache = true",
-                "SET cz.storage.parquet.vector.index.read.local.cache = false",
-                # Query optimization
-                "SET cz.sql.table.scan.push.down.filter = true",
-                "SET cz.sql.table.scan.enable.ensure.filter = true",
-                "SET cz.storage.always.prefetch.internal = true",
-                "SET cz.optimizer.generate.columns.always.valid = true",
-                "SET cz.sql.index.prewhere.enabled = true",
-                # Storage optimization
-                "SET cz.storage.parquet.enable.io.prefetch = false",
-                "SET cz.optimizer.enable.mv.rewrite = false",
-                "SET cz.sql.dump.as.lz4 = true",
-                "SET cz.optimizer.limited.optimization.naive.query = true",
-                "SET cz.sql.table.scan.enable.push.down.log = false",
-                "SET cz.storage.use.file.format.local.stats = false",
-                "SET cz.storage.local.file.object.cache.level = all",
-                # Job execution optimization
-                "SET cz.sql.job.fast.mode = true",
-                "SET cz.storage.parquet.non.contiguous.read = true",
-                "SET cz.sql.compaction.after.commit = true",
-            ]
+            if raw_metadata:
+                metadata = json.loads(raw_metadata)
 
-            for hint in performance_hints:
-                cursor.execute(hint)
+                # Handle double-encoded JSON
+                if isinstance(metadata, str):
+                    metadata = json.loads(metadata)
 
-            logger.info(
-                "Applied %d performance optimization hints for ClickZetta vector operations", len(performance_hints)
-            )
+                # Ensure we have a dict
+                if not isinstance(metadata, dict):
+                    metadata = {}
+            else:
+                metadata = {}
+        except (json.JSONDecodeError, TypeError):
+            logger.exception("JSON parsing failed for metadata")
+            # Fallback: extract document_id with regex
+            doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', raw_metadata or "")
+            metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
 
-        except Exception:
-            # Catch any errors setting performance hints but continue with defaults
-            logger.exception("Failed to set some performance hints, continuing with default settings")
+        # Ensure required fields are set
+        metadata["doc_id"] = row_id  # segment id
+
+        # Ensure document_id exists (critical for Dify's format_retrieval_documents)
+        if "document_id" not in metadata:
+            metadata["document_id"] = row_id  # fallback to segment id
+
+        return metadata
 
     @classmethod
     def _init_write_queue(cls):
@@ -204,24 +447,33 @@ class ClickzettaVector(BaseVector):
         return "clickzetta"
 
     def _ensure_connection(self) -> "Connection":
-        """Ensure connection is available and return it."""
-        if self._connection is None:
-            raise RuntimeError("Database connection not initialized")
-        return self._connection
+        """Get a connection from the pool."""
+        return self._get_connection()
 
     def _table_exists(self) -> bool:
         """Check if the table exists."""
         try:
-            connection = self._ensure_connection()
-            with connection.cursor() as cursor:
-                cursor.execute(f"DESC {self._config.schema_name}.{self._table_name}")
-                return True
-        except (RuntimeError, ValueError) as e:
-            if "table or view not found" in str(e).lower():
+            with self.get_connection_context() as connection:
+                with connection.cursor() as cursor:
+                    cursor.execute(f"DESC {self._config.schema_name}.{self._table_name}")
+                    return True
+        except Exception as e:
+            error_message = str(e).lower()
+            # Handle ClickZetta specific "table or view not found" errors
+            if any(
+                phrase in error_message
+                for phrase in ["table or view not found", "czlh-42000", "semantic analysis exception"]
+            ):
+                logger.debug("Table %s.%s does not exist", self._config.schema_name, self._table_name)
                 return False
             else:
-                # Re-raise if it's a different error
-                raise
+                # For other connection/permission errors, log warning but return False to avoid blocking cleanup
+                logger.exception(
+                    "Table existence check failed for %s.%s, assuming it doesn't exist",
+                    self._config.schema_name,
+                    self._table_name,
+                )
+                return False
 
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
         """Create the collection and add initial documents."""
@@ -253,17 +505,17 @@
         )
         COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content'
         """
-        connection = self._ensure_connection()
-        with connection.cursor() as cursor:
-            cursor.execute(create_table_sql)
-            logger.info("Created table %s.%s", self._config.schema_name, self._table_name)
+        with self.get_connection_context() as connection:
+            with connection.cursor() as cursor:
+                cursor.execute(create_table_sql)
+                logger.info("Created table %s.%s", self._config.schema_name, self._table_name)
 
-            # Create vector index
-            self._create_vector_index(cursor)
+                # Create vector index
+                self._create_vector_index(cursor)
 
-            # Create inverted index for full-text search if enabled
-            if self._config.enable_inverted_index:
-                self._create_inverted_index(cursor)
+                # Create inverted index for full-text search if enabled
+                if self._config.enable_inverted_index:
+                    self._create_inverted_index(cursor)
 
     def _create_vector_index(self, cursor):
         """Create HNSW vector index for similarity search."""
@@ -432,39 +684,53 @@
             f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))"
AS VECTOR({vector_dimension})))" ) - connection = self._ensure_connection() - with connection.cursor() as cursor: - try: - # Set session-level hints for batch insert operations - # Note: executemany doesn't support hints parameter, so we set them as session variables - cursor.execute("SET cz.sql.job.fast.mode = true") - cursor.execute("SET cz.sql.compaction.after.commit = true") - cursor.execute("SET cz.storage.always.prefetch.internal = true") + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + try: + # Set session-level hints for batch insert operations + # Note: executemany doesn't support hints parameter, so we set them as session variables + # Temporarily suppress ClickZetta client logging to reduce noise + clickzetta_logger = logging.getLogger("clickzetta") + original_level = clickzetta_logger.level + clickzetta_logger.setLevel(logging.WARNING) - cursor.executemany(insert_sql, data_rows) - logger.info( - "Inserted batch %d/%d (%d valid docs using parameterized query with VECTOR(%d) cast)", - batch_index // batch_size + 1, - total_batches, - len(data_rows), - vector_dimension, - ) - except (RuntimeError, ValueError, TypeError, ConnectionError) as e: - logger.exception("Parameterized SQL execution failed for %d documents", len(data_rows)) - logger.exception("SQL template: %s", insert_sql) - logger.exception("Sample data row: %s", data_rows[0] if data_rows else "None") - raise + try: + cursor.execute("SET cz.sql.job.fast.mode = true") + cursor.execute("SET cz.sql.compaction.after.commit = true") + cursor.execute("SET cz.storage.always.prefetch.internal = true") + finally: + # Restore original logging level + clickzetta_logger.setLevel(original_level) + + cursor.executemany(insert_sql, data_rows) + logger.info( + "Inserted batch %d/%d (%d valid docs using parameterized query with VECTOR(%d) cast)", + batch_index // batch_size + 1, + total_batches, + len(data_rows), + vector_dimension, + ) + except (RuntimeError, ValueError, TypeError, ConnectionError) as e: + logger.exception("Parameterized SQL execution failed for %d documents", len(data_rows)) + logger.exception("SQL template: %s", insert_sql) + logger.exception("Sample data row: %s", data_rows[0] if data_rows else "None") + raise def text_exists(self, id: str) -> bool: """Check if a document exists by ID.""" + # Check if table exists first + if not self._table_exists(): + return False + safe_id = self._safe_doc_id(id) - connection = self._ensure_connection() - with connection.cursor() as cursor: - cursor.execute( - f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?", [safe_id] - ) - result = cursor.fetchone() - return result[0] > 0 if result else False + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + cursor.execute( + f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?", + binding_params=[safe_id], + ) + result = cursor.fetchone() + return result[0] > 0 if result else False def delete_by_ids(self, ids: list[str]) -> None: """Delete documents by IDs.""" @@ -482,13 +748,14 @@ class ClickzettaVector(BaseVector): def _delete_by_ids_impl(self, ids: list[str]) -> None: """Implementation of delete by IDs (executed in write worker thread).""" safe_ids = [self._safe_doc_id(id) for id in ids] - # Create properly escaped string literals for SQL - id_list = ",".join(f"'{id}'" for id in safe_ids) - sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({id_list})" - connection = 
-        with connection.cursor() as cursor:
-            cursor.execute(sql)
+        # Use parameterized query to prevent SQL injection
+        placeholders = ",".join("?" for _ in safe_ids)
+        sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({placeholders})"
+
+        with self.get_connection_context() as connection:
+            with connection.cursor() as cursor:
+                cursor.execute(sql, binding_params=safe_ids)
 
     def delete_by_metadata_field(self, key: str, value: str) -> None:
         """Delete documents by metadata field."""
@@ -502,19 +769,28 @@
 
     def _delete_by_metadata_field_impl(self, key: str, value: str) -> None:
        """Implementation of delete by metadata field (executed in write worker thread)."""
-        connection = self._ensure_connection()
-        with connection.cursor() as cursor:
-            # Using JSON path to filter with parameterized query
-            # Note: JSON path requires literal key name, cannot be parameterized
-            # Use json_extract_string function for ClickZetta compatibility
-            sql = (
-                f"DELETE FROM {self._config.schema_name}.{self._table_name} "
-                f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?"
-            )
-            cursor.execute(sql, [value])
+        with self.get_connection_context() as connection:
+            with connection.cursor() as cursor:
+                # Using JSON path to filter with parameterized query
+                # Note: JSON path requires literal key name, cannot be parameterized
+                # Use json_extract_string function for ClickZetta compatibility
+                sql = (
+                    f"DELETE FROM {self._config.schema_name}.{self._table_name} "
+                    f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?"
+                )
+                cursor.execute(sql, binding_params=[value])
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         """Search for documents by vector similarity."""
+        # Check if table exists first
+        if not self._table_exists():
+            logger.warning(
+                "Table %s.%s does not exist, returning empty results",
+                self._config.schema_name,
+                self._table_name,
+            )
+            return []
+
         top_k = kwargs.get("top_k", 10)
         score_threshold = kwargs.get("score_threshold", 0.0)
         document_ids_filter = kwargs.get("document_ids_filter")
@@ -565,56 +841,31 @@
         """
 
         documents = []
-        connection = self._ensure_connection()
-        with connection.cursor() as cursor:
-            # Use hints parameter for vector search optimization
-            search_hints = {
-                "hints": {
-                    "sdk.job.timeout": 60,  # Increase timeout for vector search
-                    "cz.sql.job.fast.mode": True,
-                    "cz.storage.parquet.vector.index.read.memory.cache": True,
+        with self.get_connection_context() as connection:
+            with connection.cursor() as cursor:
+                # Use hints parameter for vector search optimization
+                search_hints = {
+                    "hints": {
+                        "sdk.job.timeout": 60,  # Increase timeout for vector search
+                        "cz.sql.job.fast.mode": True,
+                        "cz.storage.parquet.vector.index.read.memory.cache": True,
+                    }
                 }
-            }
-            cursor.execute(search_sql, parameters=search_hints)
-            results = cursor.fetchall()
+                cursor.execute(search_sql, search_hints)
+                results = cursor.fetchall()
 
-            for row in results:
-                # Parse metadata from JSON string (may be double-encoded)
-                try:
-                    if row[2]:
-                        metadata = json.loads(row[2])
+                for row in results:
+                    # Parse metadata using centralized method
+                    metadata = self._parse_metadata(row[2], row[0])
 
-                        # If result is a string, it's double-encoded JSON - parse again
-                        if isinstance(metadata, str):
-                            metadata = json.loads(metadata)
-
-                        if not isinstance(metadata, dict):
-                            metadata = {}
+                    # Add score based on distance
+                    if self._config.vector_distance_function == "cosine_distance":
+                        metadata["score"] = 1 - (row[3] / 2)
                     else:
-                        metadata = {}
-                except (json.JSONDecodeError, TypeError) as e:
-                    logger.exception("JSON parsing failed")
-                    # Fallback: extract document_id with regex
-                    import re
+                        metadata["score"] = 1 / (1 + row[3])
 
-                    doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ""))
-                    metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
-
-                # Ensure required fields are set
-                metadata["doc_id"] = row[0]  # segment id
-
-                # Ensure document_id exists (critical for Dify's format_retrieval_documents)
-                if "document_id" not in metadata:
-                    metadata["document_id"] = row[0]  # fallback to segment id
-
-                # Add score based on distance
-                if self._config.vector_distance_function == "cosine_distance":
-                    metadata["score"] = 1 - (row[3] / 2)
-                else:
-                    metadata["score"] = 1 / (1 + row[3])
-
-                doc = Document(page_content=row[1], metadata=metadata)
-                documents.append(doc)
+                    doc = Document(page_content=row[1], metadata=metadata)
+                    documents.append(doc)
 
         return documents
 
@@ -624,6 +875,15 @@ class ClickzettaVector(BaseVector):
             logger.warning("Full-text search is not enabled. Enable inverted index in config.")
             return []
 
+        # Check if table exists first
+        if not self._table_exists():
+            logger.warning(
+                "Table %s.%s does not exist, returning empty results",
+                self._config.schema_name,
+                self._table_name,
+            )
+            return []
+
         top_k = kwargs.get("top_k", 10)
         document_ids_filter = kwargs.get("document_ids_filter")
@@ -659,62 +919,70 @@
         """
 
         documents = []
-        connection = self._ensure_connection()
-        with connection.cursor() as cursor:
-            try:
-                # Use hints parameter for full-text search optimization
-                fulltext_hints = {
-                    "hints": {
-                        "sdk.job.timeout": 30,  # Timeout for full-text search
-                        "cz.sql.job.fast.mode": True,
-                        "cz.sql.index.prewhere.enabled": True,
+        with self.get_connection_context() as connection:
+            with connection.cursor() as cursor:
+                try:
+                    # Use hints parameter for full-text search optimization
+                    fulltext_hints = {
+                        "hints": {
+                            "sdk.job.timeout": 30,  # Timeout for full-text search
+                            "cz.sql.job.fast.mode": True,
+                            "cz.sql.index.prewhere.enabled": True,
+                        }
                     }
-                }
-                cursor.execute(search_sql, parameters=fulltext_hints)
-                results = cursor.fetchall()
+                    cursor.execute(search_sql, fulltext_hints)
+                    results = cursor.fetchall()
 
-                for row in results:
-                    # Parse metadata from JSON string (may be double-encoded)
-                    try:
-                        if row[2]:
-                            metadata = json.loads(row[2])
+                    for row in results:
+                        # Parse metadata from JSON string (may be double-encoded)
+                        try:
+                            if row[2]:
+                                metadata = json.loads(row[2])
 
-                            # If result is a string, it's double-encoded JSON - parse again
-                            if isinstance(metadata, str):
-                                metadata = json.loads(metadata)
+                                # If result is a string, it's double-encoded JSON - parse again
+                                if isinstance(metadata, str):
+                                    metadata = json.loads(metadata)
 
-                            if not isinstance(metadata, dict):
+                                if not isinstance(metadata, dict):
+                                    metadata = {}
+                            else:
                                 metadata = {}
-                        else:
-                            metadata = {}
-                    except (json.JSONDecodeError, TypeError) as e:
-                        logger.exception("JSON parsing failed")
-                        # Fallback: extract document_id with regex
-                        import re
+                        except (json.JSONDecodeError, TypeError) as e:
+                            logger.exception("JSON parsing failed")
+                            # Fallback: extract document_id with regex
 
-                        doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ""))
-                        metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
+                            doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ""))
+                            metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
 
-                    # Ensure required fields are set
-                    metadata["doc_id"] = row[0]  # segment id
+                        # Ensure required fields are set
+                        metadata["doc_id"] = row[0]  # segment id
 
-                    # Ensure document_id exists (critical for Dify's format_retrieval_documents)
-                    if "document_id" not in metadata:
-                        metadata["document_id"] = row[0]  # fallback to segment id
+                        # Ensure document_id exists (critical for Dify's format_retrieval_documents)
+                        if "document_id" not in metadata:
+                            metadata["document_id"] = row[0]  # fallback to segment id
 
-                    # Add a relevance score for full-text search
-                    metadata["score"] = 1.0  # Clickzetta doesn't provide relevance scores
-                    doc = Document(page_content=row[1], metadata=metadata)
-                    documents.append(doc)
-            except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
-                logger.exception("Full-text search failed")
-                # Fallback to LIKE search if full-text search fails
-                return self._search_by_like(query, **kwargs)
+                        # Add a relevance score for full-text search
+                        metadata["score"] = 1.0  # Clickzetta doesn't provide relevance scores
+                        doc = Document(page_content=row[1], metadata=metadata)
+                        documents.append(doc)
+                except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
+                    logger.exception("Full-text search failed")
+                    # Fallback to LIKE search if full-text search fails
+                    return self._search_by_like(query, **kwargs)
 
         return documents
 
     def _search_by_like(self, query: str, **kwargs: Any) -> list[Document]:
         """Fallback search using LIKE operator."""
+        # Check if table exists first
+        if not self._table_exists():
+            logger.warning(
+                "Table %s.%s does not exist, returning empty results",
+                self._config.schema_name,
+                self._table_name,
+            )
+            return []
+
         top_k = kwargs.get("top_k", 10)
         document_ids_filter = kwargs.get("document_ids_filter")
@@ -746,58 +1014,33 @@
         """
 
         documents = []
-        connection = self._ensure_connection()
-        with connection.cursor() as cursor:
-            # Use hints parameter for LIKE search optimization
-            like_hints = {
-                "hints": {
-                    "sdk.job.timeout": 20,  # Timeout for LIKE search
-                    "cz.sql.job.fast.mode": True,
+        with self.get_connection_context() as connection:
+            with connection.cursor() as cursor:
+                # Use hints parameter for LIKE search optimization
+                like_hints = {
+                    "hints": {
+                        "sdk.job.timeout": 20,  # Timeout for LIKE search
+                        "cz.sql.job.fast.mode": True,
+                    }
                 }
-            }
-            cursor.execute(search_sql, parameters=like_hints)
-            results = cursor.fetchall()
+                cursor.execute(search_sql, like_hints)
+                results = cursor.fetchall()
 
-            for row in results:
-                # Parse metadata from JSON string (may be double-encoded)
-                try:
-                    if row[2]:
-                        metadata = json.loads(row[2])
+                for row in results:
+                    # Parse metadata using centralized method
+                    metadata = self._parse_metadata(row[2], row[0])
 
-                        # If result is a string, it's double-encoded JSON - parse again
-                        if isinstance(metadata, str):
-                            metadata = json.loads(metadata)
-
-                        if not isinstance(metadata, dict):
-                            metadata = {}
-                    else:
-                        metadata = {}
-                except (json.JSONDecodeError, TypeError) as e:
-                    logger.exception("JSON parsing failed")
-                    # Fallback: extract document_id with regex
-                    import re
-
-                    doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ""))
-                    metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
-
-                # Ensure required fields are set
-                metadata["doc_id"] = row[0]  # segment id
-
-                # Ensure document_id exists (critical for Dify's format_retrieval_documents)
-                if "document_id" not in metadata:
-                    metadata["document_id"] = row[0]  # fallback to segment id
-
-                metadata["score"] = 0.5  # Lower score for LIKE search
-                doc = Document(page_content=row[1], metadata=metadata)
-                documents.append(doc)
+                    metadata["score"] = 0.5  # Lower score for LIKE search
+                    doc = Document(page_content=row[1], metadata=metadata)
+                    documents.append(doc)
 
         return documents
 
     def delete(self) -> None:
         """Delete the entire collection."""
-        connection = self._ensure_connection()
-        with connection.cursor() as cursor:
-            cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}")
+        with self.get_connection_context() as connection:
+            with connection.cursor() as cursor:
+                cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}")
 
     def _format_vector_simple(self, vector: list[float]) -> str:
         """Simple vector formatting for SQL queries."""
diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py
index c769446ed..69e5df025 100644
--- a/api/tasks/clean_dataset_task.py
+++ b/api/tasks/clean_dataset_task.py
@@ -59,7 +59,14 @@ def clean_dataset_task(
         # Fix: Always clean vector database resources regardless of document existence
         # This ensures all 33 vector databases properly drop tables/collections/indices
         if doc_form is None:
-            raise ValueError("Index type must be specified.")
+            # Use default paragraph index type for empty datasets to enable vector database cleanup
+            from core.rag.index_processor.constant.index_type import IndexType
+
+            doc_form = IndexType.PARAGRAPH_INDEX
+            logging.info(
+                click.style(f"No documents found, using default index type for cleanup: {doc_form}", fg="yellow")
+            )
+
         index_processor = IndexProcessorFactory(doc_form).init_index_processor()
         index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
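
A minimal usage sketch of the pooled-connection API introduced in clickzetta_vector.py above. ClickzettaConnectionPool, get_connection, and return_connection come from the diff; the config object and the probe query are illustrative placeholders, not part of the patch:

    # Borrow a connection for one operation, then hand it back to the shared pool.
    pool = ClickzettaConnectionPool.get_instance()
    conn = pool.get_connection(config)  # reuses a pooled connection when a valid one exists
    try:
        with conn.cursor() as cursor:
            cursor.execute("SELECT 1")  # placeholder query on the borrowed connection
    finally:
        pool.return_connection(config, conn)  # re-pooled, or closed if the pool is full or the connection is invalid

ClickzettaVector methods perform the same borrow/return cycle internally through get_connection_context(), so callers of the vector store never manage connections directly.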