diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 8db19095d..3bbe3177f 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -119,9 +119,6 @@ class ForgotPasswordResetApi(Resource): if not reset_data: raise InvalidTokenError() # Must use token in reset phase - if reset_data.get("phase", "") != "reset": - raise InvalidTokenError() - # Must use token in reset phase if reset_data.get("phase", "") != "reset": raise InvalidTokenError() diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index f73a226c8..9d0c08564 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -59,7 +59,14 @@ class InstalledAppsListApi(Resource): if FeatureService.get_system_features().webapp_auth.enabled: user_id = current_user.id res = [] + app_ids = [installed_app["app"].id for installed_app in installed_app_list] + webapp_settings = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids) for installed_app in installed_app_list: + webapp_setting = webapp_settings.get(installed_app["app"].id) + if not webapp_setting: + continue + if webapp_setting.access_mode == "sso_verified": + continue app_code = AppService.get_app_code_by_id(str(installed_app["app"].id)) if EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( user_id=user_id, diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 360cbd924..ca122772d 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -44,6 +44,17 @@ def only_edition_cloud(view): return decorated +def only_edition_enterprise(view): + @wraps(view) + def decorated(*args, **kwargs): + if not dify_config.ENTERPRISE_ENABLED: + abort(404) + + return view(*args, **kwargs) + + return decorated + + def only_edition_self_hosted(view): @wraps(view) def decorated(*args, **kwargs): diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index 50a04a625..56749a0e2 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -15,4 +15,17 @@ api.add_resource(FileApi, "/files/upload") api.add_resource(RemoteFileInfoApi, "/remote-files/") api.add_resource(RemoteFileUploadApi, "/remote-files/upload") -from . import app, audio, completion, conversation, feature, message, passport, saved_message, site, workflow +from . import ( + app, + audio, + completion, + conversation, + feature, + forgot_password, + login, + message, + passport, + saved_message, + site, + workflow, +) diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index bb4486bd9..94a525a75 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -10,6 +10,8 @@ from libs.passport import PassportService from models.model import App, AppMode from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService +from services.feature_service import FeatureService +from services.webapp_auth_service import WebAppAuthService class AppParameterApi(WebApiResource): @@ -46,10 +48,22 @@ class AppMeta(WebApiResource): class AppAccessMode(Resource): def get(self): parser = reqparse.RequestParser() - parser.add_argument("appId", type=str, required=True, location="args") + parser.add_argument("appId", type=str, required=False, location="args") + parser.add_argument("appCode", type=str, required=False, location="args") args = parser.parse_args() - app_id = args["appId"] + features = FeatureService.get_system_features() + if not features.webapp_auth.enabled: + return {"accessMode": "public"} + + app_id = args.get("appId") + if args.get("appCode"): + app_code = args["appCode"] + app_id = AppService.get_app_id_by_code(app_code) + + if not app_id: + raise ValueError("appId or appCode must be provided") + res = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id) return {"accessMode": res.access_mode} @@ -75,6 +89,10 @@ class AppWebAuthPermission(Resource): except Exception as e: pass + features = FeatureService.get_system_features() + if not features.webapp_auth.enabled: + return {"result": True} + parser = reqparse.RequestParser() parser.add_argument("appId", type=str, required=True, location="args") args = parser.parse_args() @@ -82,7 +100,9 @@ class AppWebAuthPermission(Resource): app_id = args["appId"] app_code = AppService.get_app_code_by_id(app_id) - res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code) + res = True + if WebAppAuthService.is_app_require_permission_check(app_id=app_id): + res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code) return {"result": res} diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py new file mode 100644 index 000000000..0da8d65ef --- /dev/null +++ b/api/controllers/web/forgot_password.py @@ -0,0 +1,147 @@ +import base64 +import secrets + +from flask import request +from flask_restful import Resource, reqparse +from sqlalchemy import select +from sqlalchemy.orm import Session + +from controllers.console.auth.error import ( + EmailCodeError, + EmailPasswordResetLimitError, + InvalidEmailError, + InvalidTokenError, + PasswordMismatchError, +) +from controllers.console.error import AccountNotFound, EmailSendIpLimitError +from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required +from controllers.web import api +from extensions.ext_database import db +from libs.helper import email, extract_remote_ip +from libs.password import hash_password, valid_password +from models.account import Account +from services.account_service import AccountService + + +class ForgotPasswordSendEmailApi(Resource): + @only_edition_enterprise + @setup_required + @email_password_login_enabled + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("language", type=str, required=False, location="json") + args = parser.parse_args() + + ip_address = extract_remote_ip(request) + if AccountService.is_email_send_ip_limit(ip_address): + raise EmailSendIpLimitError() + + if args["language"] is not None and args["language"] == "zh-Hans": + language = "zh-Hans" + else: + language = "en-US" + + with Session(db.engine) as session: + account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() + token = None + if account is None: + raise AccountNotFound() + else: + token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language) + + return {"result": "success", "data": token} + + +class ForgotPasswordCheckApi(Resource): + @only_edition_enterprise + @setup_required + @email_password_login_enabled + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=str, required=True, location="json") + parser.add_argument("code", type=str, required=True, location="json") + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + user_email = args["email"] + + is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"]) + if is_forgot_password_error_rate_limit: + raise EmailPasswordResetLimitError() + + token_data = AccountService.get_reset_password_data(args["token"]) + if token_data is None: + raise InvalidTokenError() + + if user_email != token_data.get("email"): + raise InvalidEmailError() + + if args["code"] != token_data.get("code"): + AccountService.add_forgot_password_error_rate_limit(args["email"]) + raise EmailCodeError() + + # Verified, revoke the first token + AccountService.revoke_reset_password_token(args["token"]) + + # Refresh token data by generating a new token + _, new_token = AccountService.generate_reset_password_token( + user_email, code=args["code"], additional_data={"phase": "reset"} + ) + + AccountService.reset_forgot_password_error_rate_limit(args["email"]) + return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + + +class ForgotPasswordResetApi(Resource): + @only_edition_enterprise + @setup_required + @email_password_login_enabled + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") + parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") + args = parser.parse_args() + + # Validate passwords match + if args["new_password"] != args["password_confirm"]: + raise PasswordMismatchError() + + # Validate token and get reset data + reset_data = AccountService.get_reset_password_data(args["token"]) + if not reset_data: + raise InvalidTokenError() + # Must use token in reset phase + if reset_data.get("phase", "") != "reset": + raise InvalidTokenError() + + # Revoke token to prevent reuse + AccountService.revoke_reset_password_token(args["token"]) + + # Generate secure salt and hash password + salt = secrets.token_bytes(16) + password_hashed = hash_password(args["new_password"], salt) + + email = reset_data.get("email", "") + + with Session(db.engine) as session: + account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + + if account: + self._update_existing_account(account, password_hashed, salt, session) + else: + raise AccountNotFound() + + return {"result": "success"} + + def _update_existing_account(self, account, password_hashed, salt, session): + # Update existing account credentials + account.password = base64.b64encode(password_hashed).decode() + account.password_salt = base64.b64encode(salt).decode() + session.commit() + + +api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password") +api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity") +api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets") diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index 06c227444..01c4f4a26 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -1,12 +1,11 @@ -from flask import request from flask_restful import Resource, reqparse from jwt import InvalidTokenError # type: ignore -from werkzeug.exceptions import BadRequest import services from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError from controllers.console.error import AccountBannedError, AccountNotFound -from controllers.console.wraps import setup_required +from controllers.console.wraps import only_edition_enterprise, setup_required +from controllers.web import api from libs.helper import email from libs.password import valid_password from services.account_service import AccountService @@ -16,6 +15,8 @@ from services.webapp_auth_service import WebAppAuthService class LoginApi(Resource): """Resource for web app email/password login.""" + @setup_required + @only_edition_enterprise def post(self): """Authenticate user and login.""" parser = reqparse.RequestParser() @@ -23,10 +24,6 @@ class LoginApi(Resource): parser.add_argument("password", type=valid_password, required=True, location="json") args = parser.parse_args() - app_code = request.headers.get("X-App-Code") - if app_code is None: - raise BadRequest("X-App-Code header is missing.") - try: account = WebAppAuthService.authenticate(args["email"], args["password"]) except services.errors.account.AccountLoginError: @@ -36,12 +33,8 @@ class LoginApi(Resource): except services.errors.account.AccountNotFoundError: raise AccountNotFound() - WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code) - - end_user = WebAppAuthService.create_end_user(email=args["email"], app_code=app_code) - - token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id) - return {"result": "success", "token": token} + token = WebAppAuthService.login(account=account) + return {"result": "success", "data": {"access_token": token}} # class LogoutApi(Resource): @@ -56,6 +49,7 @@ class LoginApi(Resource): class EmailCodeLoginSendEmailApi(Resource): @setup_required + @only_edition_enterprise def post(self): parser = reqparse.RequestParser() parser.add_argument("email", type=email, required=True, location="json") @@ -78,6 +72,7 @@ class EmailCodeLoginSendEmailApi(Resource): class EmailCodeLoginApi(Resource): @setup_required + @only_edition_enterprise def post(self): parser = reqparse.RequestParser() parser.add_argument("email", type=str, required=True, location="json") @@ -86,9 +81,6 @@ class EmailCodeLoginApi(Resource): args = parser.parse_args() user_email = args["email"] - app_code = request.headers.get("X-App-Code") - if app_code is None: - raise BadRequest("X-App-Code header is missing.") token_data = WebAppAuthService.get_email_code_login_data(args["token"]) if token_data is None: @@ -105,16 +97,12 @@ class EmailCodeLoginApi(Resource): if not account: raise AccountNotFound() - WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code) - - end_user = WebAppAuthService.create_end_user(email=user_email, app_code=app_code) - - token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id) + token = WebAppAuthService.login(account=account) AccountService.reset_login_error_rate_limit(args["email"]) - return {"result": "success", "token": token} + return {"result": "success", "data": {"access_token": token}} -# api.add_resource(LoginApi, "/login") +api.add_resource(LoginApi, "/login") # api.add_resource(LogoutApi, "/logout") -# api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login") -# api.add_resource(EmailCodeLoginApi, "/email-code-login/validity") +api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login") +api.add_resource(EmailCodeLoginApi, "/email-code-login/validity") diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 154c6772b..9d229185f 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -1,9 +1,11 @@ import uuid +from datetime import UTC, datetime, timedelta from flask import request from flask_restful import Resource from werkzeug.exceptions import NotFound, Unauthorized +from configs import dify_config from controllers.web import api from controllers.web.error import WebAppAuthRequiredError from extensions.ext_database import db @@ -11,6 +13,7 @@ from libs.passport import PassportService from models.model import App, EndUser, Site from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService +from services.webapp_auth_service import WebAppAuthService, WebAppAuthType class PassportResource(Resource): @@ -20,10 +23,19 @@ class PassportResource(Resource): system_features = FeatureService.get_system_features() app_code = request.headers.get("X-App-Code") user_id = request.args.get("user_id") + web_app_access_token = request.args.get("web_app_access_token") if app_code is None: raise Unauthorized("X-App-Code header is missing.") + # exchange token for enterprise logined web user + enterprise_user_decoded = decode_enterprise_webapp_user_id(web_app_access_token) + if enterprise_user_decoded: + # a web user has already logged in, exchange a token for this app without redirecting to the login page + return exchange_token_for_existing_web_user( + app_code=app_code, enterprise_user_decoded=enterprise_user_decoded + ) + if system_features.webapp_auth.enabled: app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) if not app_settings or not app_settings.access_mode == "public": @@ -84,6 +96,128 @@ class PassportResource(Resource): api.add_resource(PassportResource, "/passport") +def decode_enterprise_webapp_user_id(jwt_token: str | None): + """ + Decode the enterprise user session from the Authorization header. + """ + if not jwt_token: + return None + + decoded = PassportService().verify(jwt_token) + source = decoded.get("token_source") + if not source or source != "webapp_login_token": + raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.") + return decoded + + +def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict): + """ + Exchange a token for an existing web user session. + """ + user_id = enterprise_user_decoded.get("user_id") + end_user_id = enterprise_user_decoded.get("end_user_id") + session_id = enterprise_user_decoded.get("session_id") + user_auth_type = enterprise_user_decoded.get("auth_type") + if not user_auth_type: + raise Unauthorized("Missing auth_type in the token.") + + site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() + if not site: + raise NotFound() + + app_model = db.session.query(App).filter(App.id == site.app_id).first() + if not app_model or app_model.status != "normal" or not app_model.enable_site: + raise NotFound() + + app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code) + + if app_auth_type == WebAppAuthType.PUBLIC: + return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded) + elif app_auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external": + raise WebAppAuthRequiredError("Please login as external user.") + elif app_auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal": + raise WebAppAuthRequiredError("Please login as internal user.") + + end_user = None + if end_user_id: + end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first() + if session_id: + end_user = ( + db.session.query(EndUser) + .filter( + EndUser.session_id == session_id, + EndUser.tenant_id == app_model.tenant_id, + EndUser.app_id == app_model.id, + ) + .first() + ) + if not end_user: + if not session_id: + raise NotFound("Missing session_id for existing web user.") + end_user = EndUser( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type="browser", + is_anonymous=True, + session_id=session_id, + ) + db.session.add(end_user) + db.session.commit() + exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24) + exp = int(exp_dt.timestamp()) + payload = { + "iss": site.id, + "sub": "Web API Passport", + "app_id": site.app_id, + "app_code": site.code, + "user_id": user_id, + "end_user_id": end_user.id, + "auth_type": user_auth_type, + "granted_at": int(datetime.now(UTC).timestamp()), + "token_source": "webapp", + "exp": exp, + } + token: str = PassportService().issue(payload) + return { + "access_token": token, + } + + +def _exchange_for_public_app_token(app_model, site, token_decoded): + user_id = token_decoded.get("user_id") + end_user = None + if user_id: + end_user = ( + db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first() + ) + + if not end_user: + end_user = EndUser( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type="browser", + is_anonymous=True, + session_id=generate_session_id(), + ) + + db.session.add(end_user) + db.session.commit() + + payload = { + "iss": site.app_id, + "sub": "Web API Passport", + "app_id": site.app_id, + "app_code": site.code, + "end_user_id": end_user.id, + } + + tk = PassportService().issue(payload) + + return { + "access_token": tk, + } + + def generate_session_id(): """ Generate a unique session ID. diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 3bb029d6e..154bddfc5 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -1,3 +1,4 @@ +from datetime import UTC, datetime from functools import wraps from flask import request @@ -8,8 +9,9 @@ from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequire from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site -from services.enterprise.enterprise_service import EnterpriseService +from services.enterprise.enterprise_service import EnterpriseService, WebAppSettings from services.feature_service import FeatureService +from services.webapp_auth_service import WebAppAuthService def validate_jwt_token(view=None): @@ -45,7 +47,8 @@ def decode_jwt_token(): raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") decoded = PassportService().verify(tk) app_code = decoded.get("app_code") - app_model = db.session.query(App).filter(App.id == decoded["app_id"]).first() + app_id = decoded.get("app_id") + app_model = db.session.query(App).filter(App.id == app_id).first() site = db.session.query(Site).filter(Site.code == app_code).first() if not app_model: raise NotFound() @@ -53,23 +56,30 @@ def decode_jwt_token(): raise BadRequest("Site URL is no longer valid.") if app_model.enable_site is False: raise BadRequest("Site is disabled.") - end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first() + end_user_id = decoded.get("end_user_id") + end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first() if not end_user: raise NotFound() # for enterprise webapp auth app_web_auth_enabled = False + webapp_settings = None if system_features.webapp_auth.enabled: - app_web_auth_enabled = ( - EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code).access_mode != "public" - ) + webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) + if not webapp_settings: + raise NotFound("Web app settings not found.") + app_web_auth_enabled = webapp_settings.access_mode != "public" _validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled) - _validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled) + _validate_user_accessibility( + decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled, webapp_settings + ) return app_model, end_user except Unauthorized as e: if system_features.webapp_auth.enabled: + if not app_code: + raise Unauthorized("Please re-login to access the web app.") app_web_auth_enabled = ( EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=str(app_code)).access_mode != "public" ) @@ -95,15 +105,41 @@ def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_au raise Unauthorized("webapp token expired.") -def _validate_user_accessibility(decoded, app_code, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool): +def _validate_user_accessibility( + decoded, + app_code, + app_web_auth_enabled: bool, + system_webapp_auth_enabled: bool, + webapp_settings: WebAppSettings | None, +): if system_webapp_auth_enabled and app_web_auth_enabled: # Check if the user is allowed to access the web app user_id = decoded.get("user_id") if not user_id: raise WebAppAuthRequiredError() - if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code): - raise WebAppAuthAccessDeniedError() + if not webapp_settings: + raise WebAppAuthRequiredError("Web app settings not found.") + + if WebAppAuthService.is_app_require_permission_check(access_mode=webapp_settings.access_mode): + if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code): + raise WebAppAuthAccessDeniedError() + + auth_type = decoded.get("auth_type") + granted_at = decoded.get("granted_at") + if not auth_type: + raise WebAppAuthAccessDeniedError("Missing auth_type in the token.") + if not granted_at: + raise WebAppAuthAccessDeniedError("Missing granted_at in the token.") + # check if sso has been updated + if auth_type == "external": + last_update_time = EnterpriseService.get_app_sso_settings_last_update_time() + if granted_at and datetime.fromtimestamp(granted_at, tz=UTC) < last_update_time: + raise WebAppAuthAccessDeniedError("SSO settings have been updated. Please re-login.") + elif auth_type == "internal": + last_update_time = EnterpriseService.get_workspace_sso_settings_last_update_time() + if granted_at and datetime.fromtimestamp(granted_at, tz=UTC) < last_update_time: + raise WebAppAuthAccessDeniedError("SSO settings have been updated. Please re-login.") class WebApiResource(Resource): diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 06f42494e..3b4d787d0 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -57,6 +57,9 @@ def load_user_from_request(request_from_flask_login): raise Unauthorized("Invalid Authorization token.") decoded = PassportService().verify(auth_token) user_id = decoded.get("user_id") + source = decoded.get("token_source") + if source: + raise Unauthorized("Invalid Authorization token.") if not user_id: raise Unauthorized("Invalid Authorization token.") diff --git a/api/services/app_service.py b/api/services/app_service.py index ebebf8fa5..d08462d00 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -395,3 +395,15 @@ class AppService: if not site: raise ValueError(f"App with id {app_id} not found") return str(site.code) + + @staticmethod + def get_app_id_by_code(app_code: str) -> str: + """ + Get app id by app code + :param app_code: app code + :return: app id + """ + site = db.session.query(Site).filter(Site.code == app_code).first() + if not site: + raise ValueError(f"App with code {app_code} not found") + return str(site.app_id) diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 1be78d2e6..8c06ee938 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -1,3 +1,5 @@ +from datetime import datetime + from pydantic import BaseModel, Field from services.enterprise.base import EnterpriseRequest @@ -5,7 +7,7 @@ from services.enterprise.base import EnterpriseRequest class WebAppSettings(BaseModel): access_mode: str = Field( - description="Access mode for the web app. Can be 'public' or 'private'", + description="Access mode for the web app. Can be 'public', 'private', 'private_all', 'sso_verified'", default="private", alias="accessMode", ) @@ -20,6 +22,28 @@ class EnterpriseService: def get_workspace_info(cls, tenant_id: str): return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info") + @classmethod + def get_app_sso_settings_last_update_time(cls) -> datetime: + data = EnterpriseRequest.send_request("GET", "/sso/app/last-update-time") + if not data: + raise ValueError("No data found.") + try: + # parse the UTC timestamp from the response + return datetime.fromisoformat(data.replace("Z", "+00:00")) + except ValueError as e: + raise ValueError(f"Invalid date format: {data}") from e + + @classmethod + def get_workspace_sso_settings_last_update_time(cls) -> datetime: + data = EnterpriseRequest.send_request("GET", "/sso/workspace/last-update-time") + if not data: + raise ValueError("No data found.") + try: + # parse the UTC timestamp from the response + return datetime.fromisoformat(data.replace("Z", "+00:00")) + except ValueError as e: + raise ValueError(f"Invalid date format: {data}") from e + class WebAppAuth: @classmethod def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str): diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index d83303056..8f92b3f07 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -1,3 +1,4 @@ +import enum import secrets from datetime import UTC, datetime, timedelta from typing import Any, Optional, cast @@ -5,27 +6,33 @@ from typing import Any, Optional, cast from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config -from controllers.web.error import WebAppAuthAccessDeniedError from extensions.ext_database import db from libs.helper import TokenManager from libs.passport import PassportService from libs.password import compare_password from models.account import Account, AccountStatus from models.model import App, EndUser, Site +from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError -from services.feature_service import FeatureService from tasks.mail_email_code_login import send_email_code_login_mail_task +class WebAppAuthType(enum.StrEnum): + """Enum for web app authentication types.""" + + PUBLIC = "public" + INTERNAL = "internal" + EXTERNAL = "external" + + class WebAppAuthService: """Service for web app authentication.""" @staticmethod def authenticate(email: str, password: str) -> Account: """authenticate account with email and password""" - - account = Account.query.filter_by(email=email).first() + account = db.session.query(Account).filter_by(email=email).first() if not account: raise AccountNotFoundError() @@ -38,12 +45,8 @@ class WebAppAuthService: return cast(Account, account) @classmethod - def login(cls, account: Account, app_code: str, end_user_id: str) -> str: - site = db.session.query(Site).filter(Site.code == app_code).first() - if not site: - raise NotFound("Site not found.") - - access_token = cls._get_account_jwt_token(account=account, site=site, end_user_id=end_user_id) + def login(cls, account: Account) -> str: + access_token = cls._get_account_jwt_token(account=account) return access_token @@ -68,7 +71,7 @@ class WebAppAuthService: code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)]) token = TokenManager.generate_token( - account=account, email=email, token_type="webapp_email_code_login", additional_data={"code": code} + account=account, email=email, token_type="email_code_login", additional_data={"code": code} ) send_email_code_login_mail_task.delay( language=language, @@ -80,11 +83,11 @@ class WebAppAuthService: @classmethod def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]: - return TokenManager.get_token_data(token, "webapp_email_code_login") + return TokenManager.get_token_data(token, "email_code_login") @classmethod def revoke_email_code_login_token(cls, token: str): - TokenManager.revoke_token(token, "webapp_email_code_login") + TokenManager.revoke_token(token, "email_code_login") @classmethod def create_end_user(cls, app_code, email) -> EndUser: @@ -109,33 +112,67 @@ class WebAppAuthService: return end_user @classmethod - def _validate_user_accessibility(cls, account: Account, app_code: str): - """Check if the user is allowed to access the app.""" - system_features = FeatureService.get_system_features() - if system_features.webapp_auth.enabled: - app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) - - if ( - app_settings.access_mode != "public" - and not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(account.id, app_code=app_code) - ): - raise WebAppAuthAccessDeniedError() - - @classmethod - def _get_account_jwt_token(cls, account: Account, site: Site, end_user_id: str) -> str: + def _get_account_jwt_token(cls, account: Account) -> str: exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24) exp = int(exp_dt.timestamp()) payload = { - "iss": site.id, "sub": "Web API Passport", - "app_id": site.app_id, - "app_code": site.code, "user_id": account.id, - "end_user_id": end_user_id, - "token_source": "webapp", + "session_id": account.email, + "token_source": "webapp_login_token", + "auth_type": "internal", "exp": exp, } token: str = PassportService().issue(payload) return token + + @classmethod + def is_app_require_permission_check( + cls, app_code: Optional[str] = None, app_id: Optional[str] = None, access_mode: Optional[str] = None + ) -> bool: + """ + Check if the app requires permission check based on its access mode. + """ + modes_requiring_permission_check = [ + "private", + "private_all", + ] + if access_mode: + return access_mode in modes_requiring_permission_check + + if not app_code and not app_id: + raise ValueError("Either app_code or app_id must be provided.") + + if app_code: + app_id = AppService.get_app_id_by_code(app_code) + if not app_id: + raise ValueError("App ID could not be determined from the provided app_code.") + + webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id) + if webapp_settings and webapp_settings.access_mode in modes_requiring_permission_check: + return True + return False + + @classmethod + def get_app_auth_type(cls, app_code: str | None = None, access_mode: str | None = None) -> WebAppAuthType: + """ + Get the authentication type for the app based on its access mode. + """ + if not app_code and not access_mode: + raise ValueError("Either app_code or access_mode must be provided.") + + if access_mode: + if access_mode == "public": + return WebAppAuthType.PUBLIC + elif access_mode in ["private", "private_all"]: + return WebAppAuthType.INTERNAL + elif access_mode == "sso_verified": + return WebAppAuthType.EXTERNAL + + if app_code: + webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code) + return cls.get_app_auth_type(access_mode=webapp_settings.access_mode) + + raise ValueError("Could not determine app authentication type.")