test: add comprehensive OAuth authentication unit tests (#22528)
This commit is contained in:
496
api/tests/unit_tests/controllers/console/auth/test_oauth.py
Normal file
496
api/tests/unit_tests/controllers/console/auth/test_oauth.py
Normal file
@@ -0,0 +1,496 @@
|
|||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from flask import Flask
|
||||||
|
|
||||||
|
from controllers.console.auth.oauth import (
|
||||||
|
OAuthCallback,
|
||||||
|
OAuthLogin,
|
||||||
|
_generate_account,
|
||||||
|
_get_account_by_openid_or_email,
|
||||||
|
get_oauth_providers,
|
||||||
|
)
|
||||||
|
from libs.oauth import OAuthUserInfo
|
||||||
|
from models.account import AccountStatus
|
||||||
|
from services.errors.account import AccountNotFoundError
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetOAuthProviders:
|
||||||
|
@pytest.fixture
|
||||||
|
def app(self):
|
||||||
|
app = Flask(__name__)
|
||||||
|
app.config["TESTING"] = True
|
||||||
|
return app
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("github_config", "google_config", "expected_github", "expected_google"),
|
||||||
|
[
|
||||||
|
# Both providers configured
|
||||||
|
(
|
||||||
|
{"id": "github_id", "secret": "github_secret"},
|
||||||
|
{"id": "google_id", "secret": "google_secret"},
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
# Only GitHub configured
|
||||||
|
({"id": "github_id", "secret": "github_secret"}, {"id": None, "secret": None}, True, False),
|
||||||
|
# Only Google configured
|
||||||
|
({"id": None, "secret": None}, {"id": "google_id", "secret": "google_secret"}, False, True),
|
||||||
|
# No providers configured
|
||||||
|
({"id": None, "secret": None}, {"id": None, "secret": None}, False, False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@patch("controllers.console.auth.oauth.dify_config")
|
||||||
|
def test_should_configure_oauth_providers_correctly(
|
||||||
|
self, mock_config, app, github_config, google_config, expected_github, expected_google
|
||||||
|
):
|
||||||
|
mock_config.GITHUB_CLIENT_ID = github_config["id"]
|
||||||
|
mock_config.GITHUB_CLIENT_SECRET = github_config["secret"]
|
||||||
|
mock_config.GOOGLE_CLIENT_ID = google_config["id"]
|
||||||
|
mock_config.GOOGLE_CLIENT_SECRET = google_config["secret"]
|
||||||
|
mock_config.CONSOLE_API_URL = "http://localhost"
|
||||||
|
|
||||||
|
with app.app_context():
|
||||||
|
providers = get_oauth_providers()
|
||||||
|
|
||||||
|
assert (providers["github"] is not None) == expected_github
|
||||||
|
assert (providers["google"] is not None) == expected_google
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuthLogin:
|
||||||
|
@pytest.fixture
|
||||||
|
def resource(self):
|
||||||
|
return OAuthLogin()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app(self):
|
||||||
|
app = Flask(__name__)
|
||||||
|
app.config["TESTING"] = True
|
||||||
|
return app
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_oauth_provider(self):
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_authorization_url.return_value = "https://github.com/login/oauth/authorize?..."
|
||||||
|
return provider
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("invite_token", "expected_token"),
|
||||||
|
[
|
||||||
|
(None, None),
|
||||||
|
("test_invite_token", "test_invite_token"),
|
||||||
|
("", None),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||||
|
@patch("controllers.console.auth.oauth.redirect")
|
||||||
|
def test_should_handle_oauth_login_with_various_tokens(
|
||||||
|
self,
|
||||||
|
mock_redirect,
|
||||||
|
mock_get_providers,
|
||||||
|
resource,
|
||||||
|
app,
|
||||||
|
mock_oauth_provider,
|
||||||
|
invite_token,
|
||||||
|
expected_token,
|
||||||
|
):
|
||||||
|
mock_get_providers.return_value = {"github": mock_oauth_provider, "google": None}
|
||||||
|
|
||||||
|
query_string = f"invite_token={invite_token}" if invite_token else ""
|
||||||
|
with app.test_request_context(f"/auth/oauth/github?{query_string}"):
|
||||||
|
resource.get("github")
|
||||||
|
|
||||||
|
mock_oauth_provider.get_authorization_url.assert_called_once_with(invite_token=expected_token)
|
||||||
|
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...")
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("provider", "expected_error"),
|
||||||
|
[
|
||||||
|
("invalid_provider", "Invalid provider"),
|
||||||
|
("github", "Invalid provider"), # When GitHub is not configured
|
||||||
|
("google", "Invalid provider"), # When Google is not configured
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||||
|
def test_should_return_error_for_invalid_providers(
|
||||||
|
self, mock_get_providers, resource, app, provider, expected_error
|
||||||
|
):
|
||||||
|
mock_get_providers.return_value = {"github": None, "google": None}
|
||||||
|
|
||||||
|
with app.test_request_context(f"/auth/oauth/{provider}"):
|
||||||
|
response, status_code = resource.get(provider)
|
||||||
|
|
||||||
|
assert status_code == 400
|
||||||
|
assert response["error"] == expected_error
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuthCallback:
|
||||||
|
@pytest.fixture
|
||||||
|
def resource(self):
|
||||||
|
return OAuthCallback()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app(self):
|
||||||
|
app = Flask(__name__)
|
||||||
|
app.config["TESTING"] = True
|
||||||
|
return app
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def oauth_setup(self):
|
||||||
|
"""Common OAuth setup for callback tests"""
|
||||||
|
oauth_provider = MagicMock()
|
||||||
|
oauth_provider.get_access_token.return_value = "access_token"
|
||||||
|
oauth_provider.get_user_info.return_value = OAuthUserInfo(id="123", name="Test User", email="test@example.com")
|
||||||
|
|
||||||
|
account = MagicMock()
|
||||||
|
account.status = AccountStatus.ACTIVE.value
|
||||||
|
|
||||||
|
token_pair = MagicMock()
|
||||||
|
token_pair.access_token = "jwt_access_token"
|
||||||
|
token_pair.refresh_token = "jwt_refresh_token"
|
||||||
|
|
||||||
|
return {"provider": oauth_provider, "account": account, "token_pair": token_pair}
|
||||||
|
|
||||||
|
@patch("controllers.console.auth.oauth.dify_config")
|
||||||
|
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||||
|
@patch("controllers.console.auth.oauth._generate_account")
|
||||||
|
@patch("controllers.console.auth.oauth.AccountService")
|
||||||
|
@patch("controllers.console.auth.oauth.TenantService")
|
||||||
|
@patch("controllers.console.auth.oauth.redirect")
|
||||||
|
def test_should_handle_successful_oauth_callback(
|
||||||
|
self,
|
||||||
|
mock_redirect,
|
||||||
|
mock_tenant_service,
|
||||||
|
mock_account_service,
|
||||||
|
mock_generate_account,
|
||||||
|
mock_get_providers,
|
||||||
|
mock_config,
|
||||||
|
resource,
|
||||||
|
app,
|
||||||
|
oauth_setup,
|
||||||
|
):
|
||||||
|
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
|
||||||
|
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||||
|
mock_generate_account.return_value = oauth_setup["account"]
|
||||||
|
mock_account_service.login.return_value = oauth_setup["token_pair"]
|
||||||
|
|
||||||
|
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
|
||||||
|
resource.get("github")
|
||||||
|
|
||||||
|
oauth_setup["provider"].get_access_token.assert_called_once_with("test_code")
|
||||||
|
oauth_setup["provider"].get_user_info.assert_called_once_with("access_token")
|
||||||
|
mock_redirect.assert_called_once_with(
|
||||||
|
"http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("exception", "expected_error"),
|
||||||
|
[
|
||||||
|
(Exception("OAuth error"), "OAuth process failed"),
|
||||||
|
(ValueError("Invalid token"), "OAuth process failed"),
|
||||||
|
(KeyError("Missing key"), "OAuth process failed"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@patch("controllers.console.auth.oauth.db")
|
||||||
|
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||||
|
def test_should_handle_oauth_exceptions(
|
||||||
|
self, mock_get_providers, mock_db, resource, app, exception, expected_error
|
||||||
|
):
|
||||||
|
# Mock database session
|
||||||
|
mock_db.session = MagicMock()
|
||||||
|
mock_db.session.rollback = MagicMock()
|
||||||
|
|
||||||
|
# Import the real requests module to create a proper exception
|
||||||
|
import requests
|
||||||
|
|
||||||
|
request_exception = requests.exceptions.RequestException("OAuth error")
|
||||||
|
request_exception.response = MagicMock()
|
||||||
|
request_exception.response.text = str(exception)
|
||||||
|
|
||||||
|
mock_oauth_provider = MagicMock()
|
||||||
|
mock_oauth_provider.get_access_token.side_effect = request_exception
|
||||||
|
mock_get_providers.return_value = {"github": mock_oauth_provider}
|
||||||
|
|
||||||
|
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
|
||||||
|
response, status_code = resource.get("github")
|
||||||
|
|
||||||
|
assert status_code == 400
|
||||||
|
assert response["error"] == expected_error
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("account_status", "expected_redirect"),
|
||||||
|
[
|
||||||
|
(AccountStatus.BANNED.value, "http://localhost:3000/signin?message=Account is banned."),
|
||||||
|
# CLOSED status: Currently NOT handled, will proceed to login (security issue)
|
||||||
|
# This documents actual behavior. See test_defensive_check_for_closed_account_status for details
|
||||||
|
(
|
||||||
|
AccountStatus.CLOSED.value,
|
||||||
|
"http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@patch("controllers.console.auth.oauth.AccountService")
|
||||||
|
@patch("controllers.console.auth.oauth.TenantService")
|
||||||
|
@patch("controllers.console.auth.oauth.db")
|
||||||
|
@patch("controllers.console.auth.oauth.dify_config")
|
||||||
|
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||||
|
@patch("controllers.console.auth.oauth._generate_account")
|
||||||
|
@patch("controllers.console.auth.oauth.redirect")
|
||||||
|
def test_should_redirect_based_on_account_status(
|
||||||
|
self,
|
||||||
|
mock_redirect,
|
||||||
|
mock_generate_account,
|
||||||
|
mock_get_providers,
|
||||||
|
mock_config,
|
||||||
|
mock_db,
|
||||||
|
mock_tenant_service,
|
||||||
|
mock_account_service,
|
||||||
|
resource,
|
||||||
|
app,
|
||||||
|
oauth_setup,
|
||||||
|
account_status,
|
||||||
|
expected_redirect,
|
||||||
|
):
|
||||||
|
# Mock database session
|
||||||
|
mock_db.session = MagicMock()
|
||||||
|
mock_db.session.rollback = MagicMock()
|
||||||
|
mock_db.session.commit = MagicMock()
|
||||||
|
|
||||||
|
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
|
||||||
|
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||||
|
|
||||||
|
account = MagicMock()
|
||||||
|
account.status = account_status
|
||||||
|
account.id = "123"
|
||||||
|
mock_generate_account.return_value = account
|
||||||
|
|
||||||
|
# Mock login for CLOSED status
|
||||||
|
mock_token_pair = MagicMock()
|
||||||
|
mock_token_pair.access_token = "jwt_access_token"
|
||||||
|
mock_token_pair.refresh_token = "jwt_refresh_token"
|
||||||
|
mock_account_service.login.return_value = mock_token_pair
|
||||||
|
|
||||||
|
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
|
||||||
|
resource.get("github")
|
||||||
|
|
||||||
|
mock_redirect.assert_called_once_with(expected_redirect)
|
||||||
|
|
||||||
|
@patch("controllers.console.auth.oauth.dify_config")
|
||||||
|
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||||
|
@patch("controllers.console.auth.oauth._generate_account")
|
||||||
|
@patch("controllers.console.auth.oauth.db")
|
||||||
|
@patch("controllers.console.auth.oauth.TenantService")
|
||||||
|
@patch("controllers.console.auth.oauth.AccountService")
|
||||||
|
def test_should_activate_pending_account(
|
||||||
|
self,
|
||||||
|
mock_account_service,
|
||||||
|
mock_tenant_service,
|
||||||
|
mock_db,
|
||||||
|
mock_generate_account,
|
||||||
|
mock_get_providers,
|
||||||
|
mock_config,
|
||||||
|
resource,
|
||||||
|
app,
|
||||||
|
oauth_setup,
|
||||||
|
):
|
||||||
|
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||||
|
|
||||||
|
mock_account = MagicMock()
|
||||||
|
mock_account.status = AccountStatus.PENDING.value
|
||||||
|
mock_generate_account.return_value = mock_account
|
||||||
|
|
||||||
|
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
|
||||||
|
resource.get("github")
|
||||||
|
|
||||||
|
assert mock_account.status == AccountStatus.ACTIVE.value
|
||||||
|
assert mock_account.initialized_at is not None
|
||||||
|
mock_db.session.commit.assert_called_once()
|
||||||
|
|
||||||
|
@patch("controllers.console.auth.oauth.dify_config")
|
||||||
|
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||||
|
@patch("controllers.console.auth.oauth._generate_account")
|
||||||
|
@patch("controllers.console.auth.oauth.db")
|
||||||
|
@patch("controllers.console.auth.oauth.TenantService")
|
||||||
|
@patch("controllers.console.auth.oauth.AccountService")
|
||||||
|
@patch("controllers.console.auth.oauth.redirect")
|
||||||
|
def test_defensive_check_for_closed_account_status(
|
||||||
|
self,
|
||||||
|
mock_redirect,
|
||||||
|
mock_account_service,
|
||||||
|
mock_tenant_service,
|
||||||
|
mock_db,
|
||||||
|
mock_generate_account,
|
||||||
|
mock_get_providers,
|
||||||
|
mock_config,
|
||||||
|
resource,
|
||||||
|
app,
|
||||||
|
oauth_setup,
|
||||||
|
):
|
||||||
|
"""Defensive test for CLOSED account status handling in OAuth callback.
|
||||||
|
|
||||||
|
This is a defensive test documenting expected security behavior for CLOSED accounts.
|
||||||
|
|
||||||
|
Current behavior: CLOSED status is NOT checked, allowing closed accounts to login.
|
||||||
|
Expected behavior: CLOSED accounts should be rejected like BANNED accounts.
|
||||||
|
|
||||||
|
Context:
|
||||||
|
- AccountStatus.CLOSED is defined in the enum but never used in production
|
||||||
|
- The close_account() method exists but is never called
|
||||||
|
- Account deletion uses external service instead of status change
|
||||||
|
- All authentication services (OAuth, password, email) don't check CLOSED status
|
||||||
|
|
||||||
|
TODO: If CLOSED status is implemented in the future:
|
||||||
|
1. Update OAuth callback to check for CLOSED status
|
||||||
|
2. Add similar checks to all authentication services for consistency
|
||||||
|
3. Update this test to verify the rejection behavior
|
||||||
|
|
||||||
|
Security consideration: Until properly implemented, CLOSED status provides no protection.
|
||||||
|
"""
|
||||||
|
# Setup
|
||||||
|
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
|
||||||
|
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||||
|
|
||||||
|
# Create account with CLOSED status
|
||||||
|
closed_account = MagicMock()
|
||||||
|
closed_account.status = AccountStatus.CLOSED.value
|
||||||
|
closed_account.id = "123"
|
||||||
|
closed_account.name = "Closed Account"
|
||||||
|
mock_generate_account.return_value = closed_account
|
||||||
|
|
||||||
|
# Mock successful login (current behavior)
|
||||||
|
mock_token_pair = MagicMock()
|
||||||
|
mock_token_pair.access_token = "jwt_access_token"
|
||||||
|
mock_token_pair.refresh_token = "jwt_refresh_token"
|
||||||
|
mock_account_service.login.return_value = mock_token_pair
|
||||||
|
|
||||||
|
# Execute OAuth callback
|
||||||
|
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
|
||||||
|
resource.get("github")
|
||||||
|
|
||||||
|
# Verify current behavior: login succeeds (this is NOT ideal)
|
||||||
|
mock_redirect.assert_called_once_with(
|
||||||
|
"http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token"
|
||||||
|
)
|
||||||
|
mock_account_service.login.assert_called_once()
|
||||||
|
|
||||||
|
# Document expected behavior in comments:
|
||||||
|
# Expected: mock_redirect.assert_called_once_with(
|
||||||
|
# "http://localhost:3000/signin?message=Account is closed."
|
||||||
|
# )
|
||||||
|
# Expected: mock_account_service.login.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAccountGeneration:
|
||||||
|
@pytest.fixture
|
||||||
|
def user_info(self):
|
||||||
|
return OAuthUserInfo(id="123", name="Test User", email="test@example.com")
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_account(self):
|
||||||
|
account = MagicMock()
|
||||||
|
account.name = "Test User"
|
||||||
|
return account
|
||||||
|
|
||||||
|
@patch("controllers.console.auth.oauth.db")
|
||||||
|
@patch("controllers.console.auth.oauth.Account")
|
||||||
|
@patch("controllers.console.auth.oauth.Session")
|
||||||
|
@patch("controllers.console.auth.oauth.select")
|
||||||
|
def test_should_get_account_by_openid_or_email(
|
||||||
|
self, mock_select, mock_session, mock_account_model, mock_db, user_info, mock_account
|
||||||
|
):
|
||||||
|
# Mock db.engine for Session creation
|
||||||
|
mock_db.engine = MagicMock()
|
||||||
|
|
||||||
|
# Test OpenID found
|
||||||
|
mock_account_model.get_by_openid.return_value = mock_account
|
||||||
|
result = _get_account_by_openid_or_email("github", user_info)
|
||||||
|
assert result == mock_account
|
||||||
|
mock_account_model.get_by_openid.assert_called_once_with("github", "123")
|
||||||
|
|
||||||
|
# Test fallback to email
|
||||||
|
mock_account_model.get_by_openid.return_value = None
|
||||||
|
mock_session_instance = MagicMock()
|
||||||
|
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||||
|
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||||
|
|
||||||
|
result = _get_account_by_openid_or_email("github", user_info)
|
||||||
|
assert result == mock_account
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("allow_register", "existing_account", "should_create"),
|
||||||
|
[
|
||||||
|
(True, None, True), # New account creation allowed
|
||||||
|
(True, "existing", False), # Existing account
|
||||||
|
(False, None, False), # Registration not allowed
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
|
||||||
|
@patch("controllers.console.auth.oauth.FeatureService")
|
||||||
|
@patch("controllers.console.auth.oauth.RegisterService")
|
||||||
|
@patch("controllers.console.auth.oauth.AccountService")
|
||||||
|
@patch("controllers.console.auth.oauth.TenantService")
|
||||||
|
@patch("controllers.console.auth.oauth.db")
|
||||||
|
def test_should_handle_account_generation_scenarios(
|
||||||
|
self,
|
||||||
|
mock_db,
|
||||||
|
mock_tenant_service,
|
||||||
|
mock_account_service,
|
||||||
|
mock_register_service,
|
||||||
|
mock_feature_service,
|
||||||
|
mock_get_account,
|
||||||
|
app,
|
||||||
|
user_info,
|
||||||
|
mock_account,
|
||||||
|
allow_register,
|
||||||
|
existing_account,
|
||||||
|
should_create,
|
||||||
|
):
|
||||||
|
mock_get_account.return_value = mock_account if existing_account else None
|
||||||
|
mock_feature_service.get_system_features.return_value.is_allow_register = allow_register
|
||||||
|
mock_register_service.register.return_value = mock_account
|
||||||
|
|
||||||
|
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
|
||||||
|
if not allow_register and not existing_account:
|
||||||
|
with pytest.raises(AccountNotFoundError):
|
||||||
|
_generate_account("github", user_info)
|
||||||
|
else:
|
||||||
|
result = _generate_account("github", user_info)
|
||||||
|
assert result == mock_account
|
||||||
|
|
||||||
|
if should_create:
|
||||||
|
mock_register_service.register.assert_called_once_with(
|
||||||
|
email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
|
||||||
|
@patch("controllers.console.auth.oauth.TenantService")
|
||||||
|
@patch("controllers.console.auth.oauth.FeatureService")
|
||||||
|
@patch("controllers.console.auth.oauth.AccountService")
|
||||||
|
@patch("controllers.console.auth.oauth.tenant_was_created")
|
||||||
|
def test_should_create_workspace_for_account_without_tenant(
|
||||||
|
self,
|
||||||
|
mock_event,
|
||||||
|
mock_account_service,
|
||||||
|
mock_feature_service,
|
||||||
|
mock_tenant_service,
|
||||||
|
mock_get_account,
|
||||||
|
app,
|
||||||
|
user_info,
|
||||||
|
mock_account,
|
||||||
|
):
|
||||||
|
mock_get_account.return_value = mock_account
|
||||||
|
mock_tenant_service.get_join_tenants.return_value = []
|
||||||
|
mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True
|
||||||
|
|
||||||
|
mock_new_tenant = MagicMock()
|
||||||
|
mock_tenant_service.create_tenant.return_value = mock_new_tenant
|
||||||
|
|
||||||
|
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
|
||||||
|
result = _generate_account("github", user_info)
|
||||||
|
|
||||||
|
assert result == mock_account
|
||||||
|
mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace")
|
||||||
|
mock_tenant_service.create_tenant_member.assert_called_once_with(
|
||||||
|
mock_new_tenant, mock_account, role="owner"
|
||||||
|
)
|
||||||
|
mock_event.send.assert_called_once_with(mock_new_tenant)
|
249
api/tests/unit_tests/libs/test_oauth_clients.py
Normal file
249
api/tests/unit_tests/libs/test_oauth_clients.py
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
import urllib.parse
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||||
|
|
||||||
|
|
||||||
|
class BaseOAuthTest:
|
||||||
|
"""Base class for OAuth provider tests with common fixtures"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def oauth_config(self):
|
||||||
|
return {
|
||||||
|
"client_id": "test_client_id",
|
||||||
|
"client_secret": "test_client_secret",
|
||||||
|
"redirect_uri": "http://localhost/callback",
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_response(self):
|
||||||
|
response = MagicMock()
|
||||||
|
response.json.return_value = {}
|
||||||
|
return response
|
||||||
|
|
||||||
|
def parse_auth_url(self, url):
|
||||||
|
"""Helper to parse authorization URL"""
|
||||||
|
parsed = urllib.parse.urlparse(url)
|
||||||
|
params = urllib.parse.parse_qs(parsed.query)
|
||||||
|
return parsed, params
|
||||||
|
|
||||||
|
|
||||||
|
class TestGitHubOAuth(BaseOAuthTest):
|
||||||
|
@pytest.fixture
|
||||||
|
def oauth(self, oauth_config):
|
||||||
|
return GitHubOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("invite_token", "expected_state"),
|
||||||
|
[
|
||||||
|
(None, None),
|
||||||
|
("test_invite_token", "test_invite_token"),
|
||||||
|
("", None),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
|
||||||
|
url = oauth.get_authorization_url(invite_token)
|
||||||
|
parsed, params = self.parse_auth_url(url)
|
||||||
|
|
||||||
|
assert parsed.scheme == "https"
|
||||||
|
assert parsed.netloc == "github.com"
|
||||||
|
assert parsed.path == "/login/oauth/authorize"
|
||||||
|
assert params["client_id"][0] == oauth_config["client_id"]
|
||||||
|
assert params["redirect_uri"][0] == oauth_config["redirect_uri"]
|
||||||
|
assert params["scope"][0] == "user:email"
|
||||||
|
|
||||||
|
if expected_state:
|
||||||
|
assert params["state"][0] == expected_state
|
||||||
|
else:
|
||||||
|
assert "state" not in params
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("response_data", "expected_token", "should_raise"),
|
||||||
|
[
|
||||||
|
({"access_token": "test_token"}, "test_token", False),
|
||||||
|
({"error": "invalid_grant"}, None, True),
|
||||||
|
({}, None, True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@patch("requests.post")
|
||||||
|
def test_should_retrieve_access_token(
|
||||||
|
self, mock_post, oauth, mock_response, response_data, expected_token, should_raise
|
||||||
|
):
|
||||||
|
mock_response.json.return_value = response_data
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
if should_raise:
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
oauth.get_access_token("test_code")
|
||||||
|
assert "Error in GitHub OAuth" in str(exc_info.value)
|
||||||
|
else:
|
||||||
|
token = oauth.get_access_token("test_code")
|
||||||
|
assert token == expected_token
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("user_data", "email_data", "expected_email"),
|
||||||
|
[
|
||||||
|
# User with primary email
|
||||||
|
(
|
||||||
|
{"id": 12345, "login": "testuser", "name": "Test User"},
|
||||||
|
[
|
||||||
|
{"email": "secondary@example.com", "primary": False},
|
||||||
|
{"email": "primary@example.com", "primary": True},
|
||||||
|
],
|
||||||
|
"primary@example.com",
|
||||||
|
),
|
||||||
|
# User with no emails - fallback to noreply
|
||||||
|
({"id": 12345, "login": "testuser", "name": "Test User"}, [], "12345+testuser@users.noreply.github.com"),
|
||||||
|
# User with only secondary email - fallback to noreply
|
||||||
|
(
|
||||||
|
{"id": 12345, "login": "testuser", "name": "Test User"},
|
||||||
|
[{"email": "secondary@example.com", "primary": False}],
|
||||||
|
"12345+testuser@users.noreply.github.com",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@patch("requests.get")
|
||||||
|
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email):
|
||||||
|
user_response = MagicMock()
|
||||||
|
user_response.json.return_value = user_data
|
||||||
|
|
||||||
|
email_response = MagicMock()
|
||||||
|
email_response.json.return_value = email_data
|
||||||
|
|
||||||
|
mock_get.side_effect = [user_response, email_response]
|
||||||
|
|
||||||
|
user_info = oauth.get_user_info("test_token")
|
||||||
|
|
||||||
|
assert user_info.id == str(user_data["id"])
|
||||||
|
assert user_info.name == user_data["name"]
|
||||||
|
assert user_info.email == expected_email
|
||||||
|
|
||||||
|
@patch("requests.get")
|
||||||
|
def test_should_handle_network_errors(self, mock_get, oauth):
|
||||||
|
mock_get.side_effect = requests.exceptions.RequestException("Network error")
|
||||||
|
|
||||||
|
with pytest.raises(requests.exceptions.RequestException):
|
||||||
|
oauth.get_raw_user_info("test_token")
|
||||||
|
|
||||||
|
|
||||||
|
class TestGoogleOAuth(BaseOAuthTest):
|
||||||
|
@pytest.fixture
|
||||||
|
def oauth(self, oauth_config):
|
||||||
|
return GoogleOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("invite_token", "expected_state"),
|
||||||
|
[
|
||||||
|
(None, None),
|
||||||
|
("test_invite_token", "test_invite_token"),
|
||||||
|
("", None),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
|
||||||
|
url = oauth.get_authorization_url(invite_token)
|
||||||
|
parsed, params = self.parse_auth_url(url)
|
||||||
|
|
||||||
|
assert parsed.scheme == "https"
|
||||||
|
assert parsed.netloc == "accounts.google.com"
|
||||||
|
assert parsed.path == "/o/oauth2/v2/auth"
|
||||||
|
assert params["client_id"][0] == oauth_config["client_id"]
|
||||||
|
assert params["redirect_uri"][0] == oauth_config["redirect_uri"]
|
||||||
|
assert params["response_type"][0] == "code"
|
||||||
|
assert params["scope"][0] == "openid email"
|
||||||
|
|
||||||
|
if expected_state:
|
||||||
|
assert params["state"][0] == expected_state
|
||||||
|
else:
|
||||||
|
assert "state" not in params
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("response_data", "expected_token", "should_raise"),
|
||||||
|
[
|
||||||
|
({"access_token": "test_token"}, "test_token", False),
|
||||||
|
({"error": "invalid_grant"}, None, True),
|
||||||
|
({}, None, True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@patch("requests.post")
|
||||||
|
def test_should_retrieve_access_token(
|
||||||
|
self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise
|
||||||
|
):
|
||||||
|
mock_response.json.return_value = response_data
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
if should_raise:
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
oauth.get_access_token("test_code")
|
||||||
|
assert "Error in Google OAuth" in str(exc_info.value)
|
||||||
|
else:
|
||||||
|
token = oauth.get_access_token("test_code")
|
||||||
|
assert token == expected_token
|
||||||
|
|
||||||
|
mock_post.assert_called_once_with(
|
||||||
|
oauth._TOKEN_URL,
|
||||||
|
data={
|
||||||
|
"client_id": oauth_config["client_id"],
|
||||||
|
"client_secret": oauth_config["client_secret"],
|
||||||
|
"code": "test_code",
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"redirect_uri": oauth_config["redirect_uri"],
|
||||||
|
},
|
||||||
|
headers={"Accept": "application/json"},
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("user_data", "expected_name"),
|
||||||
|
[
|
||||||
|
({"sub": "123", "email": "test@example.com", "email_verified": True}, ""),
|
||||||
|
({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@patch("requests.get")
|
||||||
|
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name):
|
||||||
|
mock_response.json.return_value = user_data
|
||||||
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
|
user_info = oauth.get_user_info("test_token")
|
||||||
|
|
||||||
|
assert user_info.id == user_data["sub"]
|
||||||
|
assert user_info.name == expected_name
|
||||||
|
assert user_info.email == user_data["email"]
|
||||||
|
|
||||||
|
mock_get.assert_called_once_with(oauth._USER_INFO_URL, headers={"Authorization": "Bearer test_token"})
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"exception_type",
|
||||||
|
[
|
||||||
|
requests.exceptions.HTTPError,
|
||||||
|
requests.exceptions.ConnectionError,
|
||||||
|
requests.exceptions.Timeout,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@patch("requests.get")
|
||||||
|
def test_should_handle_http_errors(self, mock_get, oauth, exception_type):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.raise_for_status.side_effect = exception_type("Error")
|
||||||
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
|
with pytest.raises(exception_type):
|
||||||
|
oauth.get_raw_user_info("invalid_token")
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuthUserInfo:
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"user_data",
|
||||||
|
[
|
||||||
|
{"id": "123", "name": "Test User", "email": "test@example.com"},
|
||||||
|
{"id": "456", "name": "", "email": "user@domain.com"},
|
||||||
|
{"id": "789", "name": "Another User", "email": "another@test.org"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_should_create_user_info_dataclass(self, user_data):
|
||||||
|
user_info = OAuthUserInfo(**user_data)
|
||||||
|
|
||||||
|
assert user_info.id == user_data["id"]
|
||||||
|
assert user_info.name == user_data["name"]
|
||||||
|
assert user_info.email == user_data["email"]
|
Reference in New Issue
Block a user