Feat/webapp verified sso main (#20494)

This commit is contained in:
Xiyuan Chen
2025-06-09 17:19:53 +09:00
committed by GitHub
parent ab62a9662c
commit 0720bc7408
13 changed files with 504 additions and 75 deletions

View File

@@ -119,9 +119,6 @@ class ForgotPasswordResetApi(Resource):
if not reset_data: if not reset_data:
raise InvalidTokenError() raise InvalidTokenError()
# Must use token in reset phase # Must use token in reset phase
if reset_data.get("phase", "") != "reset":
raise InvalidTokenError()
# Must use token in reset phase
if reset_data.get("phase", "") != "reset": if reset_data.get("phase", "") != "reset":
raise InvalidTokenError() raise InvalidTokenError()

View File

@@ -59,7 +59,14 @@ class InstalledAppsListApi(Resource):
if FeatureService.get_system_features().webapp_auth.enabled: if FeatureService.get_system_features().webapp_auth.enabled:
user_id = current_user.id user_id = current_user.id
res = [] res = []
app_ids = [installed_app["app"].id for installed_app in installed_app_list]
webapp_settings = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids)
for installed_app in installed_app_list: for installed_app in installed_app_list:
webapp_setting = webapp_settings.get(installed_app["app"].id)
if not webapp_setting:
continue
if webapp_setting.access_mode == "sso_verified":
continue
app_code = AppService.get_app_code_by_id(str(installed_app["app"].id)) app_code = AppService.get_app_code_by_id(str(installed_app["app"].id))
if EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( if EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
user_id=user_id, user_id=user_id,

View File

@@ -44,6 +44,17 @@ def only_edition_cloud(view):
return decorated return decorated
def only_edition_enterprise(view):
@wraps(view)
def decorated(*args, **kwargs):
if not dify_config.ENTERPRISE_ENABLED:
abort(404)
return view(*args, **kwargs)
return decorated
def only_edition_self_hosted(view): def only_edition_self_hosted(view):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):

View File

@@ -15,4 +15,17 @@ api.add_resource(FileApi, "/files/upload")
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>") api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
api.add_resource(RemoteFileUploadApi, "/remote-files/upload") api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
from . import app, audio, completion, conversation, feature, message, passport, saved_message, site, workflow from . import (
app,
audio,
completion,
conversation,
feature,
forgot_password,
login,
message,
passport,
saved_message,
site,
workflow,
)

View File

@@ -10,6 +10,8 @@ from libs.passport import PassportService
from models.model import App, AppMode from models.model import App, AppMode
from services.app_service import AppService from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService
class AppParameterApi(WebApiResource): class AppParameterApi(WebApiResource):
@@ -46,10 +48,22 @@ class AppMeta(WebApiResource):
class AppAccessMode(Resource): class AppAccessMode(Resource):
def get(self): def get(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("appId", type=str, required=True, location="args") parser.add_argument("appId", type=str, required=False, location="args")
parser.add_argument("appCode", type=str, required=False, location="args")
args = parser.parse_args() args = parser.parse_args()
app_id = args["appId"] features = FeatureService.get_system_features()
if not features.webapp_auth.enabled:
return {"accessMode": "public"}
app_id = args.get("appId")
if args.get("appCode"):
app_code = args["appCode"]
app_id = AppService.get_app_id_by_code(app_code)
if not app_id:
raise ValueError("appId or appCode must be provided")
res = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id) res = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
return {"accessMode": res.access_mode} return {"accessMode": res.access_mode}
@@ -75,6 +89,10 @@ class AppWebAuthPermission(Resource):
except Exception as e: except Exception as e:
pass pass
features = FeatureService.get_system_features()
if not features.webapp_auth.enabled:
return {"result": True}
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("appId", type=str, required=True, location="args") parser.add_argument("appId", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
@@ -82,7 +100,9 @@ class AppWebAuthPermission(Resource):
app_id = args["appId"] app_id = args["appId"]
app_code = AppService.get_app_code_by_id(app_id) app_code = AppService.get_app_code_by_id(app_id)
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code) res = True
if WebAppAuthService.is_app_require_permission_check(app_id=app_id):
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code)
return {"result": res} return {"result": res}

View File

@@ -0,0 +1,147 @@
import base64
import secrets
from flask import request
from flask_restful import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console.auth.error import (
EmailCodeError,
EmailPasswordResetLimitError,
InvalidEmailError,
InvalidTokenError,
PasswordMismatchError,
)
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
from controllers.web import api
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
from libs.password import hash_password, valid_password
from models.account import Account
from services.account_service import AccountService
class ForgotPasswordSendEmailApi(Resource):
@only_edition_enterprise
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
token = None
if account is None:
raise AccountNotFound()
else:
token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)
return {"result": "success", "data": token}
class ForgotPasswordCheckApi(Resource):
@only_edition_enterprise
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
user_email = args["email"]
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
token_data = AccountService.get_reset_password_data(args["token"])
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args["email"])
raise EmailCodeError()
# Verified, revoke the first token
AccountService.revoke_reset_password_token(args["token"])
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
user_email, code=args["code"], additional_data={"phase": "reset"}
)
AccountService.reset_forgot_password_error_rate_limit(args["email"])
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
class ForgotPasswordResetApi(Resource):
@only_edition_enterprise
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
args = parser.parse_args()
# Validate passwords match
if args["new_password"] != args["password_confirm"]:
raise PasswordMismatchError()
# Validate token and get reset data
reset_data = AccountService.get_reset_password_data(args["token"])
if not reset_data:
raise InvalidTokenError()
# Must use token in reset phase
if reset_data.get("phase", "") != "reset":
raise InvalidTokenError()
# Revoke token to prevent reuse
AccountService.revoke_reset_password_token(args["token"])
# Generate secure salt and hash password
salt = secrets.token_bytes(16)
password_hashed = hash_password(args["new_password"], salt)
email = reset_data.get("email", "")
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account:
self._update_existing_account(account, password_hashed, salt, session)
else:
raise AccountNotFound()
return {"result": "success"}
def _update_existing_account(self, account, password_hashed, salt, session):
# Update existing account credentials
account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode()
session.commit()
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")

View File

@@ -1,12 +1,11 @@
from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from jwt import InvalidTokenError # type: ignore from jwt import InvalidTokenError # type: ignore
from werkzeug.exceptions import BadRequest
import services import services
from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError
from controllers.console.error import AccountBannedError, AccountNotFound from controllers.console.error import AccountBannedError, AccountNotFound
from controllers.console.wraps import setup_required from controllers.console.wraps import only_edition_enterprise, setup_required
from controllers.web import api
from libs.helper import email from libs.helper import email
from libs.password import valid_password from libs.password import valid_password
from services.account_service import AccountService from services.account_service import AccountService
@@ -16,6 +15,8 @@ from services.webapp_auth_service import WebAppAuthService
class LoginApi(Resource): class LoginApi(Resource):
"""Resource for web app email/password login.""" """Resource for web app email/password login."""
@setup_required
@only_edition_enterprise
def post(self): def post(self):
"""Authenticate user and login.""" """Authenticate user and login."""
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@@ -23,10 +24,6 @@ class LoginApi(Resource):
parser.add_argument("password", type=valid_password, required=True, location="json") parser.add_argument("password", type=valid_password, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
app_code = request.headers.get("X-App-Code")
if app_code is None:
raise BadRequest("X-App-Code header is missing.")
try: try:
account = WebAppAuthService.authenticate(args["email"], args["password"]) account = WebAppAuthService.authenticate(args["email"], args["password"])
except services.errors.account.AccountLoginError: except services.errors.account.AccountLoginError:
@@ -36,12 +33,8 @@ class LoginApi(Resource):
except services.errors.account.AccountNotFoundError: except services.errors.account.AccountNotFoundError:
raise AccountNotFound() raise AccountNotFound()
WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code) token = WebAppAuthService.login(account=account)
return {"result": "success", "data": {"access_token": token}}
end_user = WebAppAuthService.create_end_user(email=args["email"], app_code=app_code)
token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id)
return {"result": "success", "token": token}
# class LogoutApi(Resource): # class LogoutApi(Resource):
@@ -56,6 +49,7 @@ class LoginApi(Resource):
class EmailCodeLoginSendEmailApi(Resource): class EmailCodeLoginSendEmailApi(Resource):
@setup_required @setup_required
@only_edition_enterprise
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json") parser.add_argument("email", type=email, required=True, location="json")
@@ -78,6 +72,7 @@ class EmailCodeLoginSendEmailApi(Resource):
class EmailCodeLoginApi(Resource): class EmailCodeLoginApi(Resource):
@setup_required @setup_required
@only_edition_enterprise
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json") parser.add_argument("email", type=str, required=True, location="json")
@@ -86,9 +81,6 @@ class EmailCodeLoginApi(Resource):
args = parser.parse_args() args = parser.parse_args()
user_email = args["email"] user_email = args["email"]
app_code = request.headers.get("X-App-Code")
if app_code is None:
raise BadRequest("X-App-Code header is missing.")
token_data = WebAppAuthService.get_email_code_login_data(args["token"]) token_data = WebAppAuthService.get_email_code_login_data(args["token"])
if token_data is None: if token_data is None:
@@ -105,16 +97,12 @@ class EmailCodeLoginApi(Resource):
if not account: if not account:
raise AccountNotFound() raise AccountNotFound()
WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code) token = WebAppAuthService.login(account=account)
end_user = WebAppAuthService.create_end_user(email=user_email, app_code=app_code)
token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id)
AccountService.reset_login_error_rate_limit(args["email"]) AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "token": token} return {"result": "success", "data": {"access_token": token}}
# api.add_resource(LoginApi, "/login") api.add_resource(LoginApi, "/login")
# api.add_resource(LogoutApi, "/logout") # api.add_resource(LogoutApi, "/logout")
# api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login") api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
# api.add_resource(EmailCodeLoginApi, "/email-code-login/validity") api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")

View File

@@ -1,9 +1,11 @@
import uuid import uuid
from datetime import UTC, datetime, timedelta
from flask import request from flask import request
from flask_restful import Resource from flask_restful import Resource
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
from controllers.web import api from controllers.web import api
from controllers.web.error import WebAppAuthRequiredError from controllers.web.error import WebAppAuthRequiredError
from extensions.ext_database import db from extensions.ext_database import db
@@ -11,6 +13,7 @@ from libs.passport import PassportService
from models.model import App, EndUser, Site from models.model import App, EndUser, Site
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
class PassportResource(Resource): class PassportResource(Resource):
@@ -20,10 +23,19 @@ class PassportResource(Resource):
system_features = FeatureService.get_system_features() system_features = FeatureService.get_system_features()
app_code = request.headers.get("X-App-Code") app_code = request.headers.get("X-App-Code")
user_id = request.args.get("user_id") user_id = request.args.get("user_id")
web_app_access_token = request.args.get("web_app_access_token")
if app_code is None: if app_code is None:
raise Unauthorized("X-App-Code header is missing.") raise Unauthorized("X-App-Code header is missing.")
# exchange token for enterprise logined web user
enterprise_user_decoded = decode_enterprise_webapp_user_id(web_app_access_token)
if enterprise_user_decoded:
# a web user has already logged in, exchange a token for this app without redirecting to the login page
return exchange_token_for_existing_web_user(
app_code=app_code, enterprise_user_decoded=enterprise_user_decoded
)
if system_features.webapp_auth.enabled: if system_features.webapp_auth.enabled:
app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
if not app_settings or not app_settings.access_mode == "public": if not app_settings or not app_settings.access_mode == "public":
@@ -84,6 +96,128 @@ class PassportResource(Resource):
api.add_resource(PassportResource, "/passport") api.add_resource(PassportResource, "/passport")
def decode_enterprise_webapp_user_id(jwt_token: str | None):
"""
Decode the enterprise user session from the Authorization header.
"""
if not jwt_token:
return None
decoded = PassportService().verify(jwt_token)
source = decoded.get("token_source")
if not source or source != "webapp_login_token":
raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
return decoded
def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict):
"""
Exchange a token for an existing web user session.
"""
user_id = enterprise_user_decoded.get("user_id")
end_user_id = enterprise_user_decoded.get("end_user_id")
session_id = enterprise_user_decoded.get("session_id")
user_auth_type = enterprise_user_decoded.get("auth_type")
if not user_auth_type:
raise Unauthorized("Missing auth_type in the token.")
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
if not site:
raise NotFound()
app_model = db.session.query(App).filter(App.id == site.app_id).first()
if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound()
app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code)
if app_auth_type == WebAppAuthType.PUBLIC:
return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
elif app_auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
raise WebAppAuthRequiredError("Please login as external user.")
elif app_auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
raise WebAppAuthRequiredError("Please login as internal user.")
end_user = None
if end_user_id:
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
if session_id:
end_user = (
db.session.query(EndUser)
.filter(
EndUser.session_id == session_id,
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
)
.first()
)
if not end_user:
if not session_id:
raise NotFound("Missing session_id for existing web user.")
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="browser",
is_anonymous=True,
session_id=session_id,
)
db.session.add(end_user)
db.session.commit()
exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24)
exp = int(exp_dt.timestamp())
payload = {
"iss": site.id,
"sub": "Web API Passport",
"app_id": site.app_id,
"app_code": site.code,
"user_id": user_id,
"end_user_id": end_user.id,
"auth_type": user_auth_type,
"granted_at": int(datetime.now(UTC).timestamp()),
"token_source": "webapp",
"exp": exp,
}
token: str = PassportService().issue(payload)
return {
"access_token": token,
}
def _exchange_for_public_app_token(app_model, site, token_decoded):
user_id = token_decoded.get("user_id")
end_user = None
if user_id:
end_user = (
db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
)
if not end_user:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="browser",
is_anonymous=True,
session_id=generate_session_id(),
)
db.session.add(end_user)
db.session.commit()
payload = {
"iss": site.app_id,
"sub": "Web API Passport",
"app_id": site.app_id,
"app_code": site.code,
"end_user_id": end_user.id,
}
tk = PassportService().issue(payload)
return {
"access_token": tk,
}
def generate_session_id(): def generate_session_id():
""" """
Generate a unique session ID. Generate a unique session ID.

View File

@@ -1,3 +1,4 @@
from datetime import UTC, datetime
from functools import wraps from functools import wraps
from flask import request from flask import request
@@ -8,8 +9,9 @@ from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequire
from extensions.ext_database import db from extensions.ext_database import db
from libs.passport import PassportService from libs.passport import PassportService
from models.model import App, EndUser, Site from models.model import App, EndUser, Site
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService, WebAppSettings
from services.feature_service import FeatureService from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService
def validate_jwt_token(view=None): def validate_jwt_token(view=None):
@@ -45,7 +47,8 @@ def decode_jwt_token():
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
decoded = PassportService().verify(tk) decoded = PassportService().verify(tk)
app_code = decoded.get("app_code") app_code = decoded.get("app_code")
app_model = db.session.query(App).filter(App.id == decoded["app_id"]).first() app_id = decoded.get("app_id")
app_model = db.session.query(App).filter(App.id == app_id).first()
site = db.session.query(Site).filter(Site.code == app_code).first() site = db.session.query(Site).filter(Site.code == app_code).first()
if not app_model: if not app_model:
raise NotFound() raise NotFound()
@@ -53,23 +56,30 @@ def decode_jwt_token():
raise BadRequest("Site URL is no longer valid.") raise BadRequest("Site URL is no longer valid.")
if app_model.enable_site is False: if app_model.enable_site is False:
raise BadRequest("Site is disabled.") raise BadRequest("Site is disabled.")
end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first() end_user_id = decoded.get("end_user_id")
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
if not end_user: if not end_user:
raise NotFound() raise NotFound()
# for enterprise webapp auth # for enterprise webapp auth
app_web_auth_enabled = False app_web_auth_enabled = False
webapp_settings = None
if system_features.webapp_auth.enabled: if system_features.webapp_auth.enabled:
app_web_auth_enabled = ( webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code).access_mode != "public" if not webapp_settings:
) raise NotFound("Web app settings not found.")
app_web_auth_enabled = webapp_settings.access_mode != "public"
_validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled) _validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
_validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled) _validate_user_accessibility(
decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled, webapp_settings
)
return app_model, end_user return app_model, end_user
except Unauthorized as e: except Unauthorized as e:
if system_features.webapp_auth.enabled: if system_features.webapp_auth.enabled:
if not app_code:
raise Unauthorized("Please re-login to access the web app.")
app_web_auth_enabled = ( app_web_auth_enabled = (
EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=str(app_code)).access_mode != "public" EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=str(app_code)).access_mode != "public"
) )
@@ -95,15 +105,41 @@ def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_au
raise Unauthorized("webapp token expired.") raise Unauthorized("webapp token expired.")
def _validate_user_accessibility(decoded, app_code, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool): def _validate_user_accessibility(
decoded,
app_code,
app_web_auth_enabled: bool,
system_webapp_auth_enabled: bool,
webapp_settings: WebAppSettings | None,
):
if system_webapp_auth_enabled and app_web_auth_enabled: if system_webapp_auth_enabled and app_web_auth_enabled:
# Check if the user is allowed to access the web app # Check if the user is allowed to access the web app
user_id = decoded.get("user_id") user_id = decoded.get("user_id")
if not user_id: if not user_id:
raise WebAppAuthRequiredError() raise WebAppAuthRequiredError()
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code): if not webapp_settings:
raise WebAppAuthAccessDeniedError() raise WebAppAuthRequiredError("Web app settings not found.")
if WebAppAuthService.is_app_require_permission_check(access_mode=webapp_settings.access_mode):
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code):
raise WebAppAuthAccessDeniedError()
auth_type = decoded.get("auth_type")
granted_at = decoded.get("granted_at")
if not auth_type:
raise WebAppAuthAccessDeniedError("Missing auth_type in the token.")
if not granted_at:
raise WebAppAuthAccessDeniedError("Missing granted_at in the token.")
# check if sso has been updated
if auth_type == "external":
last_update_time = EnterpriseService.get_app_sso_settings_last_update_time()
if granted_at and datetime.fromtimestamp(granted_at, tz=UTC) < last_update_time:
raise WebAppAuthAccessDeniedError("SSO settings have been updated. Please re-login.")
elif auth_type == "internal":
last_update_time = EnterpriseService.get_workspace_sso_settings_last_update_time()
if granted_at and datetime.fromtimestamp(granted_at, tz=UTC) < last_update_time:
raise WebAppAuthAccessDeniedError("SSO settings have been updated. Please re-login.")
class WebApiResource(Resource): class WebApiResource(Resource):

View File

@@ -57,6 +57,9 @@ def load_user_from_request(request_from_flask_login):
raise Unauthorized("Invalid Authorization token.") raise Unauthorized("Invalid Authorization token.")
decoded = PassportService().verify(auth_token) decoded = PassportService().verify(auth_token)
user_id = decoded.get("user_id") user_id = decoded.get("user_id")
source = decoded.get("token_source")
if source:
raise Unauthorized("Invalid Authorization token.")
if not user_id: if not user_id:
raise Unauthorized("Invalid Authorization token.") raise Unauthorized("Invalid Authorization token.")

View File

@@ -395,3 +395,15 @@ class AppService:
if not site: if not site:
raise ValueError(f"App with id {app_id} not found") raise ValueError(f"App with id {app_id} not found")
return str(site.code) return str(site.code)
@staticmethod
def get_app_id_by_code(app_code: str) -> str:
"""
Get app id by app code
:param app_code: app code
:return: app id
"""
site = db.session.query(Site).filter(Site.code == app_code).first()
if not site:
raise ValueError(f"App with code {app_code} not found")
return str(site.app_id)

View File

@@ -1,3 +1,5 @@
from datetime import datetime
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from services.enterprise.base import EnterpriseRequest from services.enterprise.base import EnterpriseRequest
@@ -5,7 +7,7 @@ from services.enterprise.base import EnterpriseRequest
class WebAppSettings(BaseModel): class WebAppSettings(BaseModel):
access_mode: str = Field( access_mode: str = Field(
description="Access mode for the web app. Can be 'public' or 'private'", description="Access mode for the web app. Can be 'public', 'private', 'private_all', 'sso_verified'",
default="private", default="private",
alias="accessMode", alias="accessMode",
) )
@@ -20,6 +22,28 @@ class EnterpriseService:
def get_workspace_info(cls, tenant_id: str): def get_workspace_info(cls, tenant_id: str):
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info") return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
@classmethod
def get_app_sso_settings_last_update_time(cls) -> datetime:
data = EnterpriseRequest.send_request("GET", "/sso/app/last-update-time")
if not data:
raise ValueError("No data found.")
try:
# parse the UTC timestamp from the response
return datetime.fromisoformat(data.replace("Z", "+00:00"))
except ValueError as e:
raise ValueError(f"Invalid date format: {data}") from e
@classmethod
def get_workspace_sso_settings_last_update_time(cls) -> datetime:
data = EnterpriseRequest.send_request("GET", "/sso/workspace/last-update-time")
if not data:
raise ValueError("No data found.")
try:
# parse the UTC timestamp from the response
return datetime.fromisoformat(data.replace("Z", "+00:00"))
except ValueError as e:
raise ValueError(f"Invalid date format: {data}") from e
class WebAppAuth: class WebAppAuth:
@classmethod @classmethod
def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str): def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str):

View File

@@ -1,3 +1,4 @@
import enum
import secrets import secrets
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from typing import Any, Optional, cast from typing import Any, Optional, cast
@@ -5,27 +6,33 @@ from typing import Any, Optional, cast
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config from configs import dify_config
from controllers.web.error import WebAppAuthAccessDeniedError
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import TokenManager from libs.helper import TokenManager
from libs.passport import PassportService from libs.passport import PassportService
from libs.password import compare_password from libs.password import compare_password
from models.account import Account, AccountStatus from models.account import Account, AccountStatus
from models.model import App, EndUser, Site from models.model import App, EndUser, Site
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
from services.feature_service import FeatureService
from tasks.mail_email_code_login import send_email_code_login_mail_task from tasks.mail_email_code_login import send_email_code_login_mail_task
class WebAppAuthType(enum.StrEnum):
"""Enum for web app authentication types."""
PUBLIC = "public"
INTERNAL = "internal"
EXTERNAL = "external"
class WebAppAuthService: class WebAppAuthService:
"""Service for web app authentication.""" """Service for web app authentication."""
@staticmethod @staticmethod
def authenticate(email: str, password: str) -> Account: def authenticate(email: str, password: str) -> Account:
"""authenticate account with email and password""" """authenticate account with email and password"""
account = db.session.query(Account).filter_by(email=email).first()
account = Account.query.filter_by(email=email).first()
if not account: if not account:
raise AccountNotFoundError() raise AccountNotFoundError()
@@ -38,12 +45,8 @@ class WebAppAuthService:
return cast(Account, account) return cast(Account, account)
@classmethod @classmethod
def login(cls, account: Account, app_code: str, end_user_id: str) -> str: def login(cls, account: Account) -> str:
site = db.session.query(Site).filter(Site.code == app_code).first() access_token = cls._get_account_jwt_token(account=account)
if not site:
raise NotFound("Site not found.")
access_token = cls._get_account_jwt_token(account=account, site=site, end_user_id=end_user_id)
return access_token return access_token
@@ -68,7 +71,7 @@ class WebAppAuthService:
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)]) code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
token = TokenManager.generate_token( token = TokenManager.generate_token(
account=account, email=email, token_type="webapp_email_code_login", additional_data={"code": code} account=account, email=email, token_type="email_code_login", additional_data={"code": code}
) )
send_email_code_login_mail_task.delay( send_email_code_login_mail_task.delay(
language=language, language=language,
@@ -80,11 +83,11 @@ class WebAppAuthService:
@classmethod @classmethod
def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]: def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
return TokenManager.get_token_data(token, "webapp_email_code_login") return TokenManager.get_token_data(token, "email_code_login")
@classmethod @classmethod
def revoke_email_code_login_token(cls, token: str): def revoke_email_code_login_token(cls, token: str):
TokenManager.revoke_token(token, "webapp_email_code_login") TokenManager.revoke_token(token, "email_code_login")
@classmethod @classmethod
def create_end_user(cls, app_code, email) -> EndUser: def create_end_user(cls, app_code, email) -> EndUser:
@@ -109,33 +112,67 @@ class WebAppAuthService:
return end_user return end_user
@classmethod @classmethod
def _validate_user_accessibility(cls, account: Account, app_code: str): def _get_account_jwt_token(cls, account: Account) -> str:
"""Check if the user is allowed to access the app."""
system_features = FeatureService.get_system_features()
if system_features.webapp_auth.enabled:
app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
if (
app_settings.access_mode != "public"
and not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(account.id, app_code=app_code)
):
raise WebAppAuthAccessDeniedError()
@classmethod
def _get_account_jwt_token(cls, account: Account, site: Site, end_user_id: str) -> str:
exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24) exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24)
exp = int(exp_dt.timestamp()) exp = int(exp_dt.timestamp())
payload = { payload = {
"iss": site.id,
"sub": "Web API Passport", "sub": "Web API Passport",
"app_id": site.app_id,
"app_code": site.code,
"user_id": account.id, "user_id": account.id,
"end_user_id": end_user_id, "session_id": account.email,
"token_source": "webapp", "token_source": "webapp_login_token",
"auth_type": "internal",
"exp": exp, "exp": exp,
} }
token: str = PassportService().issue(payload) token: str = PassportService().issue(payload)
return token return token
@classmethod
def is_app_require_permission_check(
cls, app_code: Optional[str] = None, app_id: Optional[str] = None, access_mode: Optional[str] = None
) -> bool:
"""
Check if the app requires permission check based on its access mode.
"""
modes_requiring_permission_check = [
"private",
"private_all",
]
if access_mode:
return access_mode in modes_requiring_permission_check
if not app_code and not app_id:
raise ValueError("Either app_code or app_id must be provided.")
if app_code:
app_id = AppService.get_app_id_by_code(app_code)
if not app_id:
raise ValueError("App ID could not be determined from the provided app_code.")
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
if webapp_settings and webapp_settings.access_mode in modes_requiring_permission_check:
return True
return False
@classmethod
def get_app_auth_type(cls, app_code: str | None = None, access_mode: str | None = None) -> WebAppAuthType:
"""
Get the authentication type for the app based on its access mode.
"""
if not app_code and not access_mode:
raise ValueError("Either app_code or access_mode must be provided.")
if access_mode:
if access_mode == "public":
return WebAppAuthType.PUBLIC
elif access_mode in ["private", "private_all"]:
return WebAppAuthType.INTERNAL
elif access_mode == "sso_verified":
return WebAppAuthType.EXTERNAL
if app_code:
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code)
return cls.get_app_auth_type(access_mode=webapp_settings.access_mode)
raise ValueError("Could not determine app authentication type.")