Migrate SQLAlchemy from 1.x to 2.0 with automated and manual adjustments (#23224)

Co-authored-by: Yongtao Huang <99629139+hyongtao-db@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Yongtao Huang
2025-09-02 10:30:19 +08:00
committed by GitHub
parent 2e89d29c87
commit be3af1e234
33 changed files with 226 additions and 260 deletions

View File

@@ -3,6 +3,7 @@ from typing import Any
from flask import Flask, current_app
from pydantic import BaseModel, Field
from sqlalchemy import select
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelManager
@@ -85,17 +86,14 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
)
.all()
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
)
segments = db.session.scalars(document_segment_stmt).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
@@ -112,15 +110,12 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
resource_number = 1
for segment in sorted_segments:
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(Document)
.where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
.first()
document_stmt = select(Document).where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
document = db.session.scalar(document_stmt)
if dataset and document:
source = RetrievalSourceMetadata(
position=resource_number,
@@ -162,9 +157,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
hit_callbacks: list[DatasetIndexToolCallbackHandler],
):
with flask_app.app_context():
dataset = (
db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
)
stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
return []

View File

@@ -1,6 +1,7 @@
from typing import Any, Optional, cast
from pydantic import BaseModel, Field
from sqlalchemy import select
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.datasource.retrieval_service import RetrievalService
@@ -56,9 +57,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
)
def _run(self, query: str) -> str:
dataset = (
db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
)
dataset_stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id)
dataset = db.session.scalar(dataset_stmt)
if not dataset:
return ""
@@ -188,15 +188,12 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(DatasetDocument) # type: ignore
.where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.first()
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
document = db.session.scalar(dataset_document_stmt) # type: ignore
if dataset and document:
source = RetrievalSourceMetadata(
dataset_id=dataset.id,