more assert (#24996)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
Asuka Minato
2025-09-08 10:59:43 +09:00
committed by GitHub
parent 98204d78fb
commit 16a3e21410
17 changed files with 235 additions and 90 deletions

View File

@@ -1,9 +1,9 @@
from flask_login import current_user
from flask_restx import Resource, reqparse
from controllers.console import api
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import login_required
from libs.login import current_user, login_required
from models.model import Account
from services.billing_service import BillingService
@@ -17,9 +17,10 @@ class Subscription(Resource):
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
args = parser.parse_args()
assert isinstance(current_user, Account)
BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None
return BillingService.get_subscription(
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
)
@@ -31,7 +32,9 @@ class Invoices(Resource):
@account_initialization_required
@only_edition_cloud
def get(self):
assert isinstance(current_user, Account)
BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)

View File

@@ -2,7 +2,6 @@ import threading
from typing import Any, Optional
import pytz
from flask_login import current_user
import contexts
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
@@ -10,6 +9,7 @@ from core.plugin.impl.agent import PluginAgentClient
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from libs.login import current_user
from models.account import Account
from models.model import App, Conversation, EndUser, Message, MessageAgentThought
@@ -61,7 +61,8 @@ class AgentService:
executor = executor.name
else:
executor = "Unknown"
assert isinstance(current_user, Account)
assert current_user.timezone is not None
timezone = pytz.timezone(current_user.timezone)
app_model_config = app_model.app_model_config

View File

@@ -2,7 +2,6 @@ import uuid
from typing import Optional
import pandas as pd
from flask_login import current_user
from sqlalchemy import or_, select
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
@@ -10,6 +9,8 @@ from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models.account import Account
from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation
from services.feature_service import FeatureService
from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task
@@ -24,6 +25,7 @@ class AppAnnotationService:
@classmethod
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info
assert isinstance(current_user, Account)
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@@ -62,6 +64,7 @@ class AppAnnotationService:
db.session.commit()
# if annotation reply is enabled , add annotation to index
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
assert current_user.current_tenant_id is not None
if annotation_setting:
add_annotation_to_index_task.delay(
annotation.id,
@@ -84,6 +87,8 @@ class AppAnnotationService:
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
# send batch add segments task
redis_client.setnx(enable_app_annotation_job_key, "waiting")
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
enable_annotation_reply_task.delay(
str(job_id),
app_id,
@@ -97,6 +102,8 @@ class AppAnnotationService:
@classmethod
def disable_app_annotation(cls, app_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(disable_app_annotation_key)
if cache_result is not None:
@@ -113,6 +120,8 @@ class AppAnnotationService:
@classmethod
def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@@ -145,6 +154,8 @@ class AppAnnotationService:
@classmethod
def export_annotation_list_by_app_id(cls, app_id: str):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@@ -164,6 +175,8 @@ class AppAnnotationService:
@classmethod
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@@ -193,6 +206,8 @@ class AppAnnotationService:
@classmethod
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@@ -230,6 +245,8 @@ class AppAnnotationService:
@classmethod
def delete_app_annotation(cls, app_id: str, annotation_id: str):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@@ -269,6 +286,8 @@ class AppAnnotationService:
@classmethod
def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@@ -317,6 +336,8 @@ class AppAnnotationService:
@classmethod
def batch_import_app_annotations(cls, app_id, file: FileStorage):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@@ -355,6 +376,8 @@ class AppAnnotationService:
@classmethod
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get app info
app = (
db.session.query(App)
@@ -425,6 +448,8 @@ class AppAnnotationService:
@classmethod
def get_app_annotation_setting_by_app_id(cls, app_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get app info
app = (
db.session.query(App)
@@ -451,6 +476,8 @@ class AppAnnotationService:
@classmethod
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get app info
app = (
db.session.query(App)
@@ -491,6 +518,8 @@ class AppAnnotationService:
@classmethod
def clear_all_annotations(cls, app_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")

View File

@@ -2,7 +2,6 @@ import json
import logging
from typing import Optional, TypedDict, cast
from flask_login import current_user
from flask_sqlalchemy.pagination import Pagination
from configs import dify_config
@@ -17,6 +16,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models.account import Account
from models.model import App, AppMode, AppModelConfig, Site
from models.tools import ApiToolProvider
@@ -168,6 +168,8 @@ class AppService:
"""
Get App
"""
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get original app model config
if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
model_config = app.app_model_config
@@ -242,6 +244,7 @@ class AppService:
:param args: request args
:return: App instance
"""
assert current_user is not None
app.name = args["name"]
app.description = args["description"]
app.icon_type = args["icon_type"]
@@ -262,6 +265,7 @@ class AppService:
:param name: new name
:return: App instance
"""
assert current_user is not None
app.name = name
app.updated_by = current_user.id
app.updated_at = naive_utc_now()
@@ -277,6 +281,7 @@ class AppService:
:param icon_background: new icon_background
:return: App instance
"""
assert current_user is not None
app.icon = icon
app.icon_background = icon_background
app.updated_by = current_user.id
@@ -294,7 +299,7 @@ class AppService:
"""
if enable_site == app.enable_site:
return app
assert current_user is not None
app.enable_site = enable_site
app.updated_by = current_user.id
app.updated_at = naive_utc_now()
@@ -311,6 +316,7 @@ class AppService:
"""
if enable_api == app.enable_api:
return app
assert current_user is not None
app.enable_api = enable_api
app.updated_by = current_user.id

View File

@@ -70,7 +70,7 @@ class BillingService:
return response.json()
@staticmethod
def is_tenant_owner_or_admin(current_user):
def is_tenant_owner_or_admin(current_user: Account):
tenant_id = current_user.current_tenant_id
join: Optional[TenantAccountJoin] = (

View File

@@ -8,7 +8,7 @@ import uuid
from collections import Counter
from typing import Any, Literal, Optional
from flask_login import current_user
import sqlalchemy as sa
from sqlalchemy import exists, func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@@ -27,6 +27,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models.account import Account, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
@@ -498,8 +499,11 @@ class DatasetService:
data: Update data dictionary
filtered_data: Filtered update data to modify
"""
# assert isinstance(current_user, Account) and current_user.current_tenant_id is not None
try:
model_manager = ModelManager()
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=data["embedding_model_provider"],
@@ -611,8 +615,12 @@ class DatasetService:
data: Update data dictionary
filtered_data: Filtered update data to modify
"""
# assert isinstance(current_user, Account) and current_user.current_tenant_id is not None
model_manager = ModelManager()
try:
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=data["embedding_model_provider"],
@@ -720,6 +728,8 @@ class DatasetService:
@staticmethod
def get_dataset_auto_disable_logs(dataset_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled or features.billing.subscription.plan == "sandbox":
return {
@@ -924,6 +934,8 @@ class DocumentService:
@staticmethod
def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
assert isinstance(current_user, Account)
documents = (
db.session.query(Document)
.where(
@@ -983,6 +995,8 @@ class DocumentService:
@staticmethod
def rename_document(dataset_id: str, document_id: str, name: str) -> Document:
assert isinstance(current_user, Account)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise ValueError("Dataset not found.")
@@ -1012,6 +1026,7 @@ class DocumentService:
if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}:
raise DocumentIndexingError()
# update document to be paused
assert current_user is not None
document.is_paused = True
document.paused_by = current_user.id
document.paused_at = naive_utc_now()
@@ -1098,6 +1113,9 @@ class DocumentService:
# check doc_form
DatasetService.check_doc_form(dataset, knowledge_config.doc_form)
# check document limit
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled:
@@ -1434,6 +1452,8 @@ class DocumentService:
@staticmethod
def get_tenant_documents_count():
assert isinstance(current_user, Account)
documents_count = (
db.session.query(Document)
.where(
@@ -1454,6 +1474,8 @@ class DocumentService:
dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = "web",
):
assert isinstance(current_user, Account)
DatasetService.check_dataset_model_setting(dataset)
document = DocumentService.get_document(dataset.id, document_data.original_document_id)
if document is None:
@@ -1513,7 +1535,7 @@ class DocumentService:
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.where(
db.and_(
sa.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
@@ -1574,6 +1596,9 @@ class DocumentService:
@staticmethod
def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled:
@@ -2013,6 +2038,9 @@ class SegmentService:
@classmethod
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
content = args["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
@@ -2075,6 +2103,9 @@ class SegmentService:
@classmethod
def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
lock_name = f"multi_add_segment_lock_document_id_{document.id}"
increment_word_count = 0
with redis_client.lock(lock_name, timeout=600):
@@ -2158,6 +2189,9 @@ class SegmentService:
@classmethod
def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
indexing_cache_key = f"segment_{segment.id}_indexing"
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
@@ -2349,6 +2383,7 @@ class SegmentService:
@classmethod
def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
segments = (
db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count)
.where(
@@ -2379,6 +2414,8 @@ class SegmentService:
def update_segments_status(
cls, segment_ids: list, action: Literal["enable", "disable"], dataset: Dataset, document: Document
):
assert current_user is not None
# Check if segment_ids is not empty to avoid WHERE false condition
if not segment_ids or len(segment_ids) == 0:
return
@@ -2441,6 +2478,8 @@ class SegmentService:
def create_child_chunk(
cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset
) -> ChildChunk:
assert isinstance(current_user, Account)
lock_name = f"add_child_lock_{segment.id}"
with redis_client.lock(lock_name, timeout=20):
index_node_id = str(uuid.uuid4())
@@ -2488,6 +2527,8 @@ class SegmentService:
document: Document,
dataset: Dataset,
) -> list[ChildChunk]:
assert isinstance(current_user, Account)
child_chunks = (
db.session.query(ChildChunk)
.where(
@@ -2562,6 +2603,8 @@ class SegmentService:
document: Document,
dataset: Dataset,
) -> ChildChunk:
assert current_user is not None
try:
child_chunk.content = content
child_chunk.word_count = len(content)
@@ -2592,6 +2635,8 @@ class SegmentService:
def get_child_chunks(
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
):
assert isinstance(current_user, Account)
query = (
select(ChildChunk)
.filter_by(

View File

@@ -3,7 +3,6 @@ import os
import uuid
from typing import Any, Literal, Union
from flask_login import current_user
from werkzeug.exceptions import NotFound
from configs import dify_config
@@ -19,6 +18,7 @@ from extensions.ext_database import db
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id
from libs.login import current_user
from models.account import Account
from models.enums import CreatorUserRole
from models.model import EndUser, UploadFile
@@ -111,6 +111,9 @@ class FileService:
@staticmethod
def upload_text(text: str, text_name: str) -> UploadFile:
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
if len(text_name) > 200:
text_name = text_name[:200]
# user uuid as file name

View File

@@ -1,10 +1,11 @@
import json
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, create_autospec, patch
import pytest
from faker import Faker
from core.plugin.impl.exc import PluginDaemonClientSideError
from models.account import Account
from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought
from services.account_service import AccountService, TenantService
from services.agent_service import AgentService
@@ -21,7 +22,7 @@ class TestAgentService:
patch("services.agent_service.PluginAgentClient") as mock_plugin_agent_client,
patch("services.agent_service.ToolManager") as mock_tool_manager,
patch("services.agent_service.AgentConfigManager") as mock_agent_config_manager,
patch("services.agent_service.current_user") as mock_current_user,
patch("services.agent_service.current_user", create_autospec(Account, instance=True)) as mock_current_user,
patch("services.app_service.FeatureService") as mock_feature_service,
patch("services.app_service.EnterpriseService") as mock_enterprise_service,
patch("services.app_service.ModelManager") as mock_model_manager,

View File

@@ -1,9 +1,10 @@
from unittest.mock import patch
from unittest.mock import create_autospec, patch
import pytest
from faker import Faker
from werkzeug.exceptions import NotFound
from models.account import Account
from models.model import MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.app_service import AppService
@@ -24,7 +25,9 @@ class TestAnnotationService:
patch("services.annotation_service.enable_annotation_reply_task") as mock_enable_task,
patch("services.annotation_service.disable_annotation_reply_task") as mock_disable_task,
patch("services.annotation_service.batch_import_annotations_task") as mock_batch_import_task,
patch("services.annotation_service.current_user") as mock_current_user,
patch(
"services.annotation_service.current_user", create_autospec(Account, instance=True)
) as mock_current_user,
):
# Setup default mock returns
mock_account_feature_service.get_features.return_value.billing.enabled = False

View File

@@ -1,9 +1,10 @@
from unittest.mock import patch
from unittest.mock import create_autospec, patch
import pytest
from faker import Faker
from constants.model_template import default_app_templates
from models.account import Account
from models.model import App, Site
from services.account_service import AccountService, TenantService
from services.app_service import AppService
@@ -161,8 +162,13 @@ class TestAppService:
app_service = AppService()
created_app = app_service.create_app(tenant.id, app_args, account)
# Get app using the service
retrieved_app = app_service.get_app(created_app)
# Get app using the service - needs current_user mock
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
retrieved_app = app_service.get_app(created_app)
# Verify retrieved app matches created app
assert retrieved_app.id == created_app.id
@@ -406,7 +412,11 @@ class TestAppService:
"use_icon_as_answer_icon": True,
}
with patch("flask_login.utils._get_user", return_value=account):
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app(app, update_args)
# Verify updated fields
@@ -456,7 +466,11 @@ class TestAppService:
# Update app name
new_name = "New App Name"
with patch("flask_login.utils._get_user", return_value=account):
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app_name(app, new_name)
assert updated_app.name == new_name
@@ -504,7 +518,11 @@ class TestAppService:
# Update app icon
new_icon = "🌟"
new_icon_background = "#FFD93D"
with patch("flask_login.utils._get_user", return_value=account):
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app_icon(app, new_icon, new_icon_background)
assert updated_app.icon == new_icon
@@ -551,13 +569,17 @@ class TestAppService:
original_site_status = app.enable_site
# Update site status to disabled
with patch("flask_login.utils._get_user", return_value=account):
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app_site_status(app, False)
assert updated_app.enable_site is False
assert updated_app.updated_by == account.id
# Update site status back to enabled
with patch("flask_login.utils._get_user", return_value=account):
with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app_site_status(updated_app, True)
assert updated_app.enable_site is True
assert updated_app.updated_by == account.id
@@ -602,13 +624,17 @@ class TestAppService:
original_api_status = app.enable_api
# Update API status to disabled
with patch("flask_login.utils._get_user", return_value=account):
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app_api_status(app, False)
assert updated_app.enable_api is False
assert updated_app.updated_by == account.id
# Update API status back to enabled
with patch("flask_login.utils._get_user", return_value=account):
with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app_api_status(updated_app, True)
assert updated_app.enable_api is True
assert updated_app.updated_by == account.id

View File

@@ -1,6 +1,6 @@
import hashlib
from io import BytesIO
from unittest.mock import patch
from unittest.mock import create_autospec, patch
import pytest
from faker import Faker
@@ -417,11 +417,12 @@ class TestFileService:
text = "This is a test text content"
text_name = "test_text.txt"
# Mock current_user
with patch("services.file_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = str(fake.uuid4())
mock_current_user.id = str(fake.uuid4())
# Mock current_user using create_autospec
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = str(fake.uuid4())
mock_current_user.id = str(fake.uuid4())
with patch("services.file_service.current_user", mock_current_user):
upload_file = FileService.upload_text(text=text, text_name=text_name)
assert upload_file is not None
@@ -443,11 +444,12 @@ class TestFileService:
text = "test content"
long_name = "a" * 250 # Longer than 200 characters
# Mock current_user
with patch("services.file_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = str(fake.uuid4())
mock_current_user.id = str(fake.uuid4())
# Mock current_user using create_autospec
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = str(fake.uuid4())
mock_current_user.id = str(fake.uuid4())
with patch("services.file_service.current_user", mock_current_user):
upload_file = FileService.upload_text(text=text, text_name=long_name)
# Verify name was truncated
@@ -846,11 +848,12 @@ class TestFileService:
text = ""
text_name = "empty.txt"
# Mock current_user
with patch("services.file_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = str(fake.uuid4())
mock_current_user.id = str(fake.uuid4())
# Mock current_user using create_autospec
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = str(fake.uuid4())
mock_current_user.id = str(fake.uuid4())
with patch("services.file_service.current_user", mock_current_user):
upload_file = FileService.upload_text(text=text, text_name=text_name)
assert upload_file is not None

View File

@@ -1,4 +1,4 @@
from unittest.mock import patch
from unittest.mock import create_autospec, patch
import pytest
from faker import Faker
@@ -17,7 +17,9 @@ class TestMetadataService:
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.metadata_service.current_user") as mock_current_user,
patch(
"services.metadata_service.current_user", create_autospec(Account, instance=True)
) as mock_current_user,
patch("services.metadata_service.redis_client") as mock_redis_client,
patch("services.dataset_service.DocumentService") as mock_document_service,
):

View File

@@ -1,4 +1,4 @@
from unittest.mock import patch
from unittest.mock import create_autospec, patch
import pytest
from faker import Faker
@@ -17,7 +17,7 @@ class TestTagService:
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.tag_service.current_user") as mock_current_user,
patch("services.tag_service.current_user", create_autospec(Account, instance=True)) as mock_current_user,
):
# Setup default mock returns
mock_current_user.current_tenant_id = "test-tenant-id"

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, create_autospec, patch
import pytest
from faker import Faker
@@ -231,9 +231,10 @@ class TestWebsiteService:
fake = Faker()
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request
api_request = WebsiteCrawlApiRequest(
provider="firecrawl",
@@ -285,9 +286,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request
api_request = WebsiteCrawlApiRequest(
provider="watercrawl",
@@ -336,9 +338,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request for single page crawling
api_request = WebsiteCrawlApiRequest(
provider="jinareader",
@@ -389,9 +392,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request with invalid provider
api_request = WebsiteCrawlApiRequest(
provider="invalid_provider",
@@ -419,9 +423,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request
api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="test_job_id_123")
@@ -463,9 +468,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request
api_request = WebsiteCrawlStatusApiRequest(provider="watercrawl", job_id="watercrawl_job_123")
@@ -502,9 +508,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request
api_request = WebsiteCrawlStatusApiRequest(provider="jinareader", job_id="jina_job_123")
@@ -544,9 +551,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request with invalid provider
api_request = WebsiteCrawlStatusApiRequest(provider="invalid_provider", job_id="test_job_id_123")
@@ -569,9 +577,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Mock missing credentials
mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = None
@@ -597,9 +606,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Mock missing API key in config
mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = {
"config": {"base_url": "https://api.example.com"}
@@ -995,9 +1005,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request for sub-page crawling
api_request = WebsiteCrawlApiRequest(
provider="jinareader",
@@ -1054,9 +1065,10 @@ class TestWebsiteService:
mock_external_service_dependencies["requests"].get.return_value = mock_failed_response
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request
api_request = WebsiteCrawlApiRequest(
provider="jinareader",
@@ -1096,9 +1108,10 @@ class TestWebsiteService:
mock_external_service_dependencies["firecrawl_app"].return_value = mock_firecrawl_instance
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request
api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="active_job_123")

View File

@@ -2,11 +2,12 @@ import datetime
from typing import Any, Optional
# Mock redis_client before importing dataset_service
from unittest.mock import Mock, patch
from unittest.mock import Mock, create_autospec, patch
import pytest
from core.model_runtime.entities.model_entities import ModelType
from models.account import Account
from models.dataset import Dataset, ExternalKnowledgeBindings
from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
@@ -78,7 +79,7 @@ class DatasetUpdateTestDataFactory:
@staticmethod
def create_current_user_mock(tenant_id: str = "tenant-123") -> Mock:
"""Create a mock current user."""
current_user = Mock()
current_user = create_autospec(Account, instance=True)
current_user.current_tenant_id = tenant_id
return current_user
@@ -135,7 +136,9 @@ class TestDatasetServiceUpdateDataset:
"services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding"
) as mock_get_binding,
patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task,
patch("services.dataset_service.current_user") as mock_current_user,
patch(
"services.dataset_service.current_user", create_autospec(Account, instance=True)
) as mock_current_user,
):
mock_current_user.current_tenant_id = "tenant-123"
yield {

View File

@@ -1,9 +1,10 @@
from unittest.mock import Mock, patch
from unittest.mock import Mock, create_autospec, patch
import pytest
from flask_restx import reqparse
from werkzeug.exceptions import BadRequest
from models.account import Account
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
from services.metadata_service import MetadataService
@@ -35,19 +36,21 @@ class TestMetadataBugCompleteValidation:
mock_metadata_args.name = None
mock_metadata_args.type = "string"
with patch("services.metadata_service.current_user") as mock_user:
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with patch("services.metadata_service.current_user", mock_user):
# Should crash with TypeError
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args)
# Test update method as well
with patch("services.metadata_service.current_user") as mock_user:
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with patch("services.metadata_service.current_user", mock_user):
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.update_metadata_name("dataset-123", "metadata-456", None)

View File

@@ -1,8 +1,9 @@
from unittest.mock import Mock, patch
from unittest.mock import Mock, create_autospec, patch
import pytest
from flask_restx import reqparse
from models.account import Account
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
from services.metadata_service import MetadataService
@@ -24,20 +25,22 @@ class TestMetadataNullableBug:
mock_metadata_args.name = None # This will cause len() to crash
mock_metadata_args.type = "string"
with patch("services.metadata_service.current_user") as mock_user:
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with patch("services.metadata_service.current_user", mock_user):
# This should crash with TypeError when calling len(None)
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args)
def test_metadata_service_update_with_none_name_crashes(self):
"""Test that MetadataService.update_metadata_name crashes when name is None."""
with patch("services.metadata_service.current_user") as mock_user:
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with patch("services.metadata_service.current_user", mock_user):
# This should crash with TypeError when calling len(None)
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.update_metadata_name("dataset-123", "metadata-456", None)
@@ -81,10 +84,11 @@ class TestMetadataNullableBug:
mock_metadata_args.name = None # From args["name"]
mock_metadata_args.type = None # From args["type"]
with patch("services.metadata_service.current_user") as mock_user:
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with patch("services.metadata_service.current_user", mock_user):
# Step 4: Service layer crashes on len(None)
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args)