Initial commit

This commit is contained in:
John Wang
2023-05-15 08:51:32 +08:00
commit db896255d6
744 changed files with 56028 additions and 0 deletions

View File

@@ -0,0 +1,263 @@
# -*- coding:utf-8 -*-
from datetime import datetime
import pytz
from flask import current_app, request
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, fields, marshal_with
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.workspace.error import AccountAlreadyInitedError, InvalidInvitationCodeError, \
RepeatPasswordNotMatchError
from controllers.console.wraps import account_initialization_required
from libs.helper import TimestampField, supported_language, timezone
from extensions.ext_database import db
from models.account import InvitationCode, AccountIntegrate
from services.account_service import AccountService
account_fields = {
'id': fields.String,
'name': fields.String,
'avatar': fields.String,
'email': fields.String,
'interface_language': fields.String,
'interface_theme': fields.String,
'timezone': fields.String,
'last_login_at': TimestampField,
'last_login_ip': fields.String,
'created_at': TimestampField
}
class AccountInitApi(Resource):
@setup_required
@login_required
def post(self):
account = current_user
if account.status == 'active':
raise AccountAlreadyInitedError()
parser = reqparse.RequestParser()
if current_app.config['EDITION'] == 'CLOUD':
parser.add_argument('invitation_code', type=str, location='json')
parser.add_argument(
'interface_language', type=supported_language, required=True, location='json')
parser.add_argument('timezone', type=timezone,
required=True, location='json')
args = parser.parse_args()
if current_app.config['EDITION'] == 'CLOUD':
if not args['invitation_code']:
raise ValueError('invitation_code is required')
# check invitation code
invitation_code = db.session.query(InvitationCode).filter(
InvitationCode.code == args['invitation_code'],
InvitationCode.status == 'unused',
).first()
if not invitation_code:
raise InvalidInvitationCodeError()
invitation_code.status = 'used'
invitation_code.used_at = datetime.utcnow()
invitation_code.used_by_tenant_id = account.current_tenant_id
invitation_code.used_by_account_id = account.id
account.interface_language = args['interface_language']
account.timezone = args['timezone']
account.interface_theme = 'light'
account.status = 'active'
account.initialized_at = datetime.utcnow()
db.session.commit()
return {'result': 'success'}
class AccountProfileApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def get(self):
return current_user
class AccountNameApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json')
args = parser.parse_args()
# Validate account name length
if len(args['name']) < 3 or len(args['name']) > 30:
raise ValueError(
"Account name must be between 3 and 30 characters.")
updated_account = AccountService.update_account(current_user, name=args['name'])
return updated_account
class AccountAvatarApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('avatar', type=str, required=True, location='json')
args = parser.parse_args()
updated_account = AccountService.update_account(current_user, avatar=args['avatar'])
return updated_account
class AccountInterfaceLanguageApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
'interface_language', type=supported_language, required=True, location='json')
args = parser.parse_args()
updated_account = AccountService.update_account(current_user, interface_language=args['interface_language'])
return updated_account
class AccountInterfaceThemeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('interface_theme', type=str, choices=[
'light', 'dark'], required=True, location='json')
args = parser.parse_args()
updated_account = AccountService.update_account(current_user, interface_theme=args['interface_theme'])
return updated_account
class AccountTimezoneApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('timezone', type=str,
required=True, location='json')
args = parser.parse_args()
# Validate timezone string, e.g. America/New_York, Asia/Shanghai
if args['timezone'] not in pytz.all_timezones:
raise ValueError("Invalid timezone string.")
updated_account = AccountService.update_account(current_user, timezone=args['timezone'])
return updated_account
class AccountPasswordApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('password', type=str,
required=False, location='json')
parser.add_argument('new_password', type=str,
required=True, location='json')
parser.add_argument('repeat_new_password', type=str,
required=True, location='json')
args = parser.parse_args()
if args['new_password'] != args['repeat_new_password']:
raise RepeatPasswordNotMatchError()
AccountService.update_account_password(
current_user, args['password'], args['new_password'])
return {"result": "success"}
class AccountIntegrateApi(Resource):
integrate_fields = {
'provider': fields.String,
'created_at': TimestampField,
'is_bound': fields.Boolean,
'link': fields.String
}
integrate_list_fields = {
'data': fields.List(fields.Nested(integrate_fields)),
}
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_list_fields)
def get(self):
account = current_user
account_integrates = db.session.query(AccountIntegrate).filter(
AccountIntegrate.account_id == account.id).all()
base_url = request.url_root.rstrip('/')
oauth_base_path = "/console/api/oauth/login"
providers = ["github", "google"]
integrate_data = []
for provider in providers:
existing_integrate = next((ai for ai in account_integrates if ai.provider == provider), None)
if existing_integrate:
integrate_data.append({
'id': existing_integrate.id,
'provider': provider,
'created_at': existing_integrate.created_at,
'is_bound': True,
'link': None
})
else:
integrate_data.append({
'id': None,
'provider': provider,
'created_at': None,
'is_bound': False,
'link': f'{base_url}{oauth_base_path}/{provider}'
})
return {'data': integrate_data}
# Register API resources
api.add_resource(AccountInitApi, '/account/init')
api.add_resource(AccountProfileApi, '/account/profile')
api.add_resource(AccountNameApi, '/account/name')
api.add_resource(AccountAvatarApi, '/account/avatar')
api.add_resource(AccountInterfaceLanguageApi, '/account/interface-language')
api.add_resource(AccountInterfaceThemeApi, '/account/interface-theme')
api.add_resource(AccountTimezoneApi, '/account/timezone')
api.add_resource(AccountPasswordApi, '/account/password')
api.add_resource(AccountIntegrateApi, '/account/integrates')
# api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')

View File

@@ -0,0 +1,31 @@
from libs.exception import BaseHTTPException
class RepeatPasswordNotMatchError(BaseHTTPException):
error_code = 'repeat_password_not_match'
description = "New password and repeat password does not match."
code = 400
class ProviderRequestFailedError(BaseHTTPException):
error_code = 'provider_request_failed'
description = None
code = 400
class InvalidInvitationCodeError(BaseHTTPException):
error_code = 'invalid_invitation_code'
description = "Invalid invitation code."
code = 400
class AccountAlreadyInitedError(BaseHTTPException):
error_code = 'account_already_inited'
description = "Account already inited."
code = 400
class AccountNotInitializedError(BaseHTTPException):
error_code = 'account_not_initialized'
description = "Account not initialized."
code = 400

View File

@@ -0,0 +1,141 @@
# -*- coding:utf-8 -*-
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, marshal_with, abort, fields, marshal
import services
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from libs.helper import TimestampField
from extensions.ext_database import db
from models.account import Account, TenantAccountJoin
from services.account_service import TenantService, RegisterService
account_fields = {
'id': fields.String,
'name': fields.String,
'avatar': fields.String,
'email': fields.String,
'last_login_at': TimestampField,
'created_at': TimestampField,
'role': fields.String,
'status': fields.String,
}
account_list_fields = {
'accounts': fields.List(fields.Nested(account_fields))
}
class MemberListApi(Resource):
"""List all members of current tenant."""
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_list_fields)
def get(self):
members = TenantService.get_tenant_members(current_user.current_tenant)
return {'result': 'success', 'accounts': members}, 200
class MemberInviteEmailApi(Resource):
"""Invite a new member by email."""
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('email', type=str, required=True, location='json')
parser.add_argument('role', type=str, required=True, default='admin', location='json')
args = parser.parse_args()
invitee_email = args['email']
invitee_role = args['role']
if invitee_role not in ['admin', 'normal']:
return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
inviter = current_user
try:
RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role, inviter=inviter)
account = db.session.query(Account, TenantAccountJoin.role).join(
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
).filter(Account.email == args['email']).first()
account, role = account
account = marshal(account, account_fields)
account['role'] = role
except services.errors.account.CannotOperateSelfError as e:
return {'code': 'cannot-operate-self', 'message': str(e)}, 400
except services.errors.account.NoPermissionError as e:
return {'code': 'forbidden', 'message': str(e)}, 403
except services.errors.account.AccountAlreadyInTenantError as e:
return {'code': 'email-taken', 'message': str(e)}, 409
except Exception as e:
return {'code': 'unexpected-error', 'message': str(e)}, 500
# todo:413
return {'result': 'success', 'account': account}, 201
class MemberCancelInviteApi(Resource):
"""Cancel an invitation by member id."""
@setup_required
@login_required
@account_initialization_required
def delete(self, member_id):
member = Account.query.get(str(member_id))
if not member:
abort(404)
try:
TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user)
except services.errors.account.CannotOperateSelfError as e:
return {'code': 'cannot-operate-self', 'message': str(e)}, 400
except services.errors.account.NoPermissionError as e:
return {'code': 'forbidden', 'message': str(e)}, 403
except services.errors.account.MemberNotInTenantError as e:
return {'code': 'member-not-found', 'message': str(e)}, 404
except Exception as e:
raise ValueError(str(e))
return {'result': 'success'}, 204
class MemberUpdateRoleApi(Resource):
"""Update member role."""
@setup_required
@login_required
@account_initialization_required
def put(self, member_id):
parser = reqparse.RequestParser()
parser.add_argument('role', type=str, required=True, location='json')
args = parser.parse_args()
new_role = args['role']
if new_role not in ['admin', 'normal', 'owner']:
return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
member = Account.query.get(str(member_id))
if not member:
abort(404)
try:
TenantService.update_member_role(current_user.current_tenant, member, new_role, current_user)
except Exception as e:
raise ValueError(str(e))
# todo: 403
return {'result': 'success'}
api.add_resource(MemberListApi, '/workspaces/current/members')
api.add_resource(MemberInviteEmailApi, '/workspaces/current/members/invite-email')
api.add_resource(MemberCancelInviteApi, '/workspaces/current/members/<uuid:member_id>')
api.add_resource(MemberUpdateRoleApi, '/workspaces/current/members/<uuid:member_id>/update-role')

View File

@@ -0,0 +1,246 @@
# -*- coding:utf-8 -*-
import base64
import json
import logging
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, abort
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.llm.provider.errors import ValidateFailedError
from extensions.ext_database import db
from libs import rsa
from models.provider import Provider, ProviderType, ProviderName
from services.provider_service import ProviderService
class ProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
"""
If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version:,
azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 bits in plaintext, and the
rest is replaced by * and the last two bits are displayed in plaintext
If the type is other, decode and return the Token field directly, the field displays the first 6 bits in
plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
"""
ProviderService.init_supported_provider(current_user.current_tenant, "cloud")
providers = Provider.query.filter_by(tenant_id=tenant_id).all()
provider_list = [
{
'provider_name': p.provider_name,
'provider_type': p.provider_type,
'is_valid': p.is_valid,
'last_used': p.last_used,
'is_enabled': p.is_enabled,
**({
'quota_type': p.quota_type,
'quota_limit': p.quota_limit,
'quota_used': p.quota_used
} if p.provider_type == ProviderType.SYSTEM.value else {}),
'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
ProviderName(p.provider_name))
}
for p in providers
]
return provider_list
class ProviderTokenApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
if provider not in [p.value for p in ProviderName]:
abort(404)
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
logging.log(logging.ERROR,
f'User {current_user.id} is not authorized to update provider token, current_role is {current_user.current_tenant.current_role}')
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('token', type=ProviderService.get_token_type(
tenant=current_user.current_tenant,
provider_name=ProviderName(provider)
), required=True, nullable=False, location='json')
args = parser.parse_args()
if not args['token']:
raise ValueError('Token is empty')
try:
ProviderService.validate_provider_configs(
tenant=current_user.current_tenant,
provider_name=ProviderName(provider),
configs=args['token']
)
token_is_valid = True
except ValidateFailedError:
token_is_valid = False
tenant = current_user.current_tenant
base64_encrypted_token = ProviderService.get_encrypted_token(
tenant=current_user.current_tenant,
provider_name=ProviderName(provider),
configs=args['token']
)
provider_model = Provider.query.filter_by(tenant_id=tenant.id, provider_name=provider,
provider_type=ProviderType.CUSTOM.value).first()
# Only allow updating token for CUSTOM provider type
if provider_model:
provider_model.encrypted_config = base64_encrypted_token
provider_model.is_valid = token_is_valid
else:
provider_model = Provider(tenant_id=tenant.id, provider_name=provider,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=base64_encrypted_token,
is_valid=token_is_valid)
db.session.add(provider_model)
db.session.commit()
if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
ProviderName.HUGGINGFACEHUB.value]:
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201
return {'result': 'success'}, 201
class ProviderTokenValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
if provider not in [p.value for p in ProviderName]:
abort(404)
parser = reqparse.RequestParser()
parser.add_argument('token', type=ProviderService.get_token_type(
tenant=current_user.current_tenant,
provider_name=ProviderName(provider)
), required=True, nullable=False, location='json')
args = parser.parse_args()
# todo: remove this when the provider is supported
if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
ProviderName.HUGGINGFACEHUB.value]:
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
result = True
error = None
try:
ProviderService.validate_provider_configs(
tenant=current_user.current_tenant,
provider_name=ProviderName(provider),
configs=args['token']
)
except ValidateFailedError as e:
result = False
error = str(e)
response = {'result': 'success' if result else 'error'}
if not result:
response['error'] = error
return response
class ProviderSystemApi(Resource):
@setup_required
@login_required
@account_initialization_required
def put(self, provider):
if provider not in [p.value for p in ProviderName]:
abort(404)
parser = reqparse.RequestParser()
parser.add_argument('is_enabled', type=bool, required=True, location='json')
args = parser.parse_args()
tenant = current_user.current_tenant_id
provider_model = Provider.query.filter_by(tenant_id=tenant.id, provider_name=provider).first()
if provider_model and provider_model.provider_type == ProviderType.SYSTEM.value:
provider_model.is_valid = args['is_enabled']
db.session.commit()
elif not provider_model:
ProviderService.create_system_provider(tenant, provider, args['is_enabled'])
else:
abort(403)
return {'result': 'success'}
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
if provider not in [p.value for p in ProviderName]:
abort(404)
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
provider_model = db.session.query(Provider).filter(Provider.tenant_id == current_user.current_tenant_id,
Provider.provider_name == provider,
Provider.provider_type == ProviderType.SYSTEM.value).first()
system_model = None
if provider_model:
system_model = {
'result': 'success',
'provider': {
'provider_name': provider_model.provider_name,
'provider_type': provider_model.provider_type,
'is_valid': provider_model.is_valid,
'last_used': provider_model.last_used,
'is_enabled': provider_model.is_enabled,
'quota_type': provider_model.quota_type,
'quota_limit': provider_model.quota_limit,
'quota_used': provider_model.quota_used
}
}
else:
abort(404)
return system_model
api.add_resource(ProviderTokenApi, '/providers/<provider>/token',
endpoint='current_providers_token') # Deprecated
api.add_resource(ProviderTokenValidateApi, '/providers/<provider>/token-validate',
endpoint='current_providers_token_validate') # Deprecated
api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
endpoint='workspaces_current_providers_token') # PUT for updating provider token
api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
endpoint='workspaces_current_providers_token_validate') # POST for validating provider token
api.add_resource(ProviderListApi, '/workspaces/current/providers') # GET for getting providers list
api.add_resource(ProviderSystemApi, '/workspaces/current/providers/<provider>/system',
endpoint='workspaces_current_providers_system') # GET for getting provider quota, PUT for updating provider status

View File

@@ -0,0 +1,97 @@
# -*- coding:utf-8 -*-
import logging
from flask import request
from flask_login import login_required, current_user
from flask_restful import Resource, fields, marshal_with, reqparse, marshal
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import account_initialization_required
from libs.helper import TimestampField
from extensions.ext_database import db
from models.account import Tenant
from services.account_service import TenantService
from services.workspace_service import WorkspaceService
provider_fields = {
'provider_name': fields.String,
'provider_type': fields.String,
'is_valid': fields.Boolean,
'token_is_set': fields.Boolean,
}
tenant_fields = {
'id': fields.String,
'name': fields.String,
'plan': fields.String,
'status': fields.String,
'created_at': TimestampField,
'role': fields.String,
'providers': fields.List(fields.Nested(provider_fields)),
'in_trail': fields.Boolean,
'trial_end_reason': fields.String,
}
tenants_fields = {
'id': fields.String,
'name': fields.String,
'plan': fields.String,
'status': fields.String,
'created_at': TimestampField,
'current': fields.Boolean
}
class TenantListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenants = TenantService.get_join_tenants(current_user)
for tenant in tenants:
if tenant.id == current_user.current_tenant_id:
tenant.current = True # Set current=True for current tenant
return {'workspaces': marshal(tenants, tenants_fields)}, 200
class TenantApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(tenant_fields)
def get(self):
if request.path == '/info':
logging.warning('Deprecated URL /info was used.')
tenant = current_user.current_tenant
return WorkspaceService.get_tenant_info(tenant), 200
class SwitchWorkspaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('tenant_id', type=str, required=True, location='json')
args = parser.parse_args()
# check if tenant_id is valid, 403 if not
try:
TenantService.switch_tenant(current_user, args['tenant_id'])
except Exception:
raise AccountNotLinkTenantError("Account not link tenant")
new_tenant = db.session.query(Tenant).get(args['tenant_id']) # Get new tenant
return {'result': 'success', 'new_tenant': marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)}
api.add_resource(TenantListApi, '/workspaces') # GET for getting all tenants
api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current') # GET for getting current tenant info
api.add_resource(TenantApi, '/info', endpoint='info') # Deprecated
api.add_resource(SwitchWorkspaceApi, '/workspaces/switch') # POST for switching tenant