Initial commit

This commit is contained in:
John Wang
2023-05-15 08:51:32 +08:00
commit db896255d6
744 changed files with 56028 additions and 0 deletions

View File

@@ -0,0 +1,10 @@
# -*- coding:utf-8 -*-
from flask import Blueprint
from libs.external_api import ExternalApi
bp = Blueprint('web', __name__, url_prefix='/api')
api = ExternalApi(bp)
from . import completion, app, conversation, message, site, saved_message

View File

@@ -0,0 +1,42 @@
# -*- coding:utf-8 -*-
from flask_restful import marshal_with, fields
from controllers.web import api
from controllers.web.wraps import WebApiResource
class AppParameterApi(WebApiResource):
"""Resource for app variables."""
variable_fields = {
'key': fields.String,
'name': fields.String,
'description': fields.String,
'type': fields.String,
'default': fields.String,
'max_length': fields.Integer,
'options': fields.List(fields.String)
}
parameters_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
'suggested_questions_after_answer': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
}
@marshal_with(parameters_fields)
def get(self, app_model, end_user):
"""Retrieve app parameters."""
app_model_config = app_model.app_model_config
return {
'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
}
api.add_resource(AppParameterApi, '/parameters')

View File

@@ -0,0 +1,175 @@
# -*- coding:utf-8 -*-
import json
import logging
from typing import Generator, Union
from flask import Response, stream_with_context
from flask_restful import reqparse
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.web import api
from controllers.web.error import AppUnavailableError, ConversationCompletedError, \
ProviderNotInitializeError, NotChatAppError, NotCompletionAppError, CompletionRequestError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.web.wraps import WebApiResource
from core.conversation_message_task import PubHandler
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value
from services.completion_service import CompletionService
# define completion api for user
class CompletionApi(WebApiResource):
def post(self, app_model, end_user):
if app_model.mode != 'completion':
raise NotCompletionAppError()
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'
try:
response = CompletionService.completion(
app_model=app_model,
user=end_user,
args=args,
from_source='api',
streaming=streaming
)
return compact_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
class CompletionStopApi(WebApiResource):
def post(self, app_model, end_user, task_id):
if app_model.mode != 'completion':
raise NotCompletionAppError()
PubHandler.stop(end_user, task_id)
return {'result': 'success'}, 200
class ChatApi(WebApiResource):
def post(self, app_model, end_user):
if app_model.mode != 'chat':
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'
try:
response = CompletionService.completion(
app_model=app_model,
user=end_user,
args=args,
from_source='api',
streaming=streaming
)
return compact_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
class ChatStopApi(WebApiResource):
def post(self, app_model, end_user, task_id):
if app_model.mode != 'chat':
raise NotChatAppError()
PubHandler.stop(end_user, task_id)
return {'result': 'success'}, 200
def compact_response(response: Union[dict | Generator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
try:
for chunk in response:
yield chunk
except services.errors.conversation.ConversationNotExistsError:
yield "data: " + json.dumps(api.handle_error(NotFound("Conversation Not Exists.")).get_json()) + "\n\n"
except services.errors.conversation.ConversationCompletedError:
yield "data: " + json.dumps(api.handle_error(ConversationCompletedError()).get_json()) + "\n\n"
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
except ProviderTokenNotInitError:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:
logging.exception("internal server error.")
yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n"
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')
api.add_resource(CompletionApi, '/completion-messages')
api.add_resource(CompletionStopApi, '/completion-messages/<string:task_id>/stop')
api.add_resource(ChatApi, '/chat-messages')
api.add_resource(ChatStopApi, '/chat-messages/<string:task_id>/stop')

View File

@@ -0,0 +1,121 @@
# -*- coding:utf-8 -*-
from flask_restful import fields, reqparse, marshal_with
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound
from controllers.web import api
from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource
from libs.helper import TimestampField, uuid_value
from services.conversation_service import ConversationService
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
from services.web_conversation_service import WebConversationService
conversation_fields = {
'id': fields.String,
'name': fields.String,
'inputs': fields.Raw,
'status': fields.String,
'introduction': fields.String,
'created_at': TimestampField
}
conversation_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_fields))
}
class ConversationListApi(WebApiResource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
if app_model.mode != 'chat':
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
args = parser.parse_args()
pinned = None
if 'pinned' in args and args['pinned'] is not None:
pinned = True if args['pinned'] == 'true' else False
try:
return WebConversationService.pagination_by_last_id(
app_model=app_model,
end_user=end_user,
last_id=args['last_id'],
limit=args['limit'],
pinned=pinned
)
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
class ConversationApi(WebApiResource):
def delete(self, app_model, end_user, c_id):
if app_model.mode != 'chat':
raise NotChatAppError()
conversation_id = str(c_id)
ConversationService.delete(app_model, conversation_id, end_user)
WebConversationService.unpin(app_model, conversation_id, end_user)
return {"result": "success"}, 204
class ConversationRenameApi(WebApiResource):
@marshal_with(conversation_fields)
def post(self, app_model, end_user, c_id):
if app_model.mode != 'chat':
raise NotChatAppError()
conversation_id = str(c_id)
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json')
args = parser.parse_args()
try:
return ConversationService.rename(app_model, conversation_id, end_user, args['name'])
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
class ConversationPinApi(WebApiResource):
def patch(self, app_model, end_user, c_id):
if app_model.mode != 'chat':
raise NotChatAppError()
conversation_id = str(c_id)
try:
WebConversationService.pin(app_model, conversation_id, end_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return {"result": "success"}
class ConversationUnPinApi(WebApiResource):
def patch(self, app_model, end_user, c_id):
if app_model.mode != 'chat':
raise NotChatAppError()
conversation_id = str(c_id)
WebConversationService.unpin(app_model, conversation_id, end_user)
return {"result": "success"}
api.add_resource(ConversationRenameApi, '/conversations/<uuid:c_id>/name', endpoint='web_conversation_name')
api.add_resource(ConversationListApi, '/conversations')
api.add_resource(ConversationApi, '/conversations/<uuid:c_id>')
api.add_resource(ConversationPinApi, '/conversations/<uuid:c_id>/pin')
api.add_resource(ConversationUnPinApi, '/conversations/<uuid:c_id>/unpin')

View File

@@ -0,0 +1,62 @@
# -*- coding:utf-8 -*-
from libs.exception import BaseHTTPException
class AppUnavailableError(BaseHTTPException):
error_code = 'app_unavailable'
description = "App unavailable."
code = 400
class NotCompletionAppError(BaseHTTPException):
error_code = 'not_completion_app'
description = "Not Completion App"
code = 400
class NotChatAppError(BaseHTTPException):
error_code = 'not_chat_app'
description = "Not Chat App"
code = 400
class ConversationCompletedError(BaseHTTPException):
error_code = 'conversation_completed'
description = "Conversation Completed."
code = 400
class ProviderNotInitializeError(BaseHTTPException):
error_code = 'provider_not_initialize'
description = "Provider Token not initialize."
code = 400
class ProviderQuotaExceededError(BaseHTTPException):
error_code = 'provider_quota_exceeded'
description = "Provider quota exceeded."
code = 400
class ProviderModelCurrentlyNotSupportError(BaseHTTPException):
error_code = 'model_currently_not_support'
description = "GPT-4 currently not support."
code = 400
class CompletionRequestError(BaseHTTPException):
error_code = 'completion_request_error'
description = "Completion request failed."
code = 400
class AppMoreLikeThisDisabledError(BaseHTTPException):
error_code = 'app_more_like_this_disabled'
description = "More like this disabled."
code = 403
class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException):
error_code = 'app_suggested_questions_after_answer_disabled'
description = "Function Suggested questions after answer disabled."
code = 403

View File

@@ -0,0 +1,189 @@
# -*- coding:utf-8 -*-
import json
import logging
from typing import Generator, Union
from flask import stream_with_context, Response
from flask_restful import reqparse, fields, marshal_with
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound, InternalServerError
import services
from controllers.web import api
from controllers.web.error import NotChatAppError, CompletionRequestError, ProviderNotInitializeError, \
AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.web.wraps import WebApiResource
from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value, TimestampField
from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService
class MessageListApi(WebApiResource):
feedback_fields = {
'rating': fields.String
}
message_fields = {
'id': fields.String,
'conversation_id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String,
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'created_at': TimestampField
}
message_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(message_fields))
}
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
if app_model.mode != 'chat':
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
parser.add_argument('first_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
args = parser.parse_args()
try:
return MessageService.pagination_by_first_id(app_model, end_user,
args['conversation_id'], args['first_id'], args['limit'])
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.message.FirstMessageNotExistsError:
raise NotFound("First Message Not Exists.")
class MessageFeedbackApi(WebApiResource):
def post(self, app_model, end_user, message_id):
message_id = str(message_id)
parser = reqparse.RequestParser()
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
args = parser.parse_args()
try:
MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {'result': 'success'}
class MessageMoreLikeThisApi(WebApiResource):
def get(self, app_model, end_user, message_id):
if app_model.mode != 'completion':
raise NotCompletionAppError()
message_id = str(message_id)
parser = reqparse.RequestParser()
parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'
try:
response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming)
return compact_response(response)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
except MoreLikeThisDisabledError:
raise AppMoreLikeThisDisabledError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
except Exception:
logging.exception("internal server error.")
raise InternalServerError()
def compact_response(response: Union[dict | Generator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
try:
for chunk in response:
yield chunk
except MessageNotExistsError:
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
except MoreLikeThisDisabledError:
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
except ProviderTokenNotInitError:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:
logging.exception("internal server error.")
yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n"
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')
class MessageSuggestedQuestionApi(WebApiResource):
def get(self, app_model, end_user, message_id):
if app_model.mode != 'chat':
raise NotCompletionAppError()
message_id = str(message_id)
try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model,
user=end_user,
message_id=message_id
)
except MessageNotExistsError:
raise NotFound("Message not found")
except ConversationNotExistsError:
raise NotFound("Conversation not found")
except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
except Exception:
logging.exception("internal server error.")
raise InternalServerError()
return {'data': questions}
api.add_resource(MessageListApi, '/messages')
api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks')
api.add_resource(MessageMoreLikeThisApi, '/messages/<uuid:message_id>/more-like-this')
api.add_resource(MessageSuggestedQuestionApi, '/messages/<uuid:message_id>/suggested-questions')

View File

@@ -0,0 +1,74 @@
from flask_restful import reqparse, marshal_with, fields
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound
from controllers.web import api
from controllers.web.error import NotCompletionAppError
from controllers.web.wraps import WebApiResource
from libs.helper import uuid_value, TimestampField
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
feedback_fields = {
'rating': fields.String
}
message_fields = {
'id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String,
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'created_at': TimestampField
}
class SavedMessageListApi(WebApiResource):
saved_message_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(message_fields))
}
@marshal_with(saved_message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
if app_model.mode != 'completion':
raise NotCompletionAppError()
parser = reqparse.RequestParser()
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
args = parser.parse_args()
return SavedMessageService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit'])
def post(self, app_model, end_user):
if app_model.mode != 'completion':
raise NotCompletionAppError()
parser = reqparse.RequestParser()
parser.add_argument('message_id', type=uuid_value, required=True, location='json')
args = parser.parse_args()
try:
SavedMessageService.save(app_model, end_user, args['message_id'])
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {'result': 'success'}
class SavedMessageApi(WebApiResource):
def delete(self, app_model, end_user, message_id):
message_id = str(message_id)
if app_model.mode != 'completion':
raise NotCompletionAppError()
SavedMessageService.delete(app_model, end_user, message_id)
return {'result': 'success'}
api.add_resource(SavedMessageListApi, '/saved-messages')
api.add_resource(SavedMessageApi, '/saved-messages/<uuid:message_id>')

View File

@@ -0,0 +1,73 @@
# -*- coding:utf-8 -*-
from flask_restful import fields, marshal_with
from werkzeug.exceptions import Forbidden
from controllers.web import api
from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
from models.model import Site
class AppSiteApi(WebApiResource):
"""Resource for app sites."""
model_config_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
'model': fields.Raw(attribute='model_dict'),
'user_input_form': fields.Raw(attribute='user_input_form_list'),
'pre_prompt': fields.String,
}
site_fields = {
'title': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,
'default_language': fields.String,
'prompt_public': fields.Boolean
}
app_fields = {
'app_id': fields.String,
'end_user_id': fields.String,
'enable_site': fields.Boolean,
'site': fields.Nested(site_fields),
'model_config': fields.Nested(model_config_fields, allow_null=True),
'plan': fields.String,
}
@marshal_with(app_fields)
def get(self, app_model, end_user):
"""Retrieve app site info."""
# get site
site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
if not site:
raise Forbidden()
return AppSiteInfo(app_model.tenant, app_model, site, end_user.id)
api.add_resource(AppSiteApi, '/site')
class AppSiteInfo:
"""Class to store site information."""
def __init__(self, tenant, app, site, end_user):
"""Initialize AppSiteInfo instance."""
self.app_id = app.id
self.end_user_id = end_user
self.enable_site = app.enable_site
self.site = site
self.model_config = None
self.plan = tenant.plan
if app.enable_site and site.prompt_public:
app_model_config = app.app_model_config
self.model_config = app_model_config

View File

@@ -0,0 +1,107 @@
# -*- coding:utf-8 -*-
import uuid
from functools import wraps
from flask import request, session
from flask_restful import Resource
from werkzeug.exceptions import NotFound, Unauthorized
from extensions.ext_database import db
from models.model import App, Site, EndUser
def validate_token(view=None):
def decorator(view):
@wraps(view)
def decorated(*args, **kwargs):
site = validate_and_get_site()
app_model = db.session.query(App).get(site.app_id)
if not app_model:
raise NotFound()
if app_model.status != 'normal':
raise NotFound()
if not app_model.enable_site:
raise NotFound()
end_user = create_or_update_end_user_for_session(app_model)
return view(app_model, end_user, *args, **kwargs)
return decorated
if view:
return decorator(view)
return decorator
def validate_and_get_site():
"""
Validate and get API token.
"""
auth_header = request.headers.get('Authorization')
if auth_header is None:
raise Unauthorized()
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized()
site = db.session.query(Site).filter(
Site.code == auth_token,
Site.status == 'normal'
).first()
if not site:
raise NotFound()
return site
def create_or_update_end_user_for_session(app_model):
"""
Create or update session terminal based on session ID.
"""
if 'session_id' not in session:
session['session_id'] = generate_session_id()
session_id = session.get('session_id')
end_user = db.session.query(EndUser) \
.filter(
EndUser.session_id == session_id,
EndUser.type == 'browser'
).first()
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type='browser',
is_anonymous=True,
session_id=session_id
)
db.session.add(end_user)
db.session.commit()
return end_user
def generate_session_id():
"""
Generate a unique session ID.
"""
count = 1
session_id = ''
while count != 0:
session_id = str(uuid.uuid4())
count = db.session.query(EndUser) \
.filter(EndUser.session_id == session_id).count()
return session_id
class WebApiResource(Resource):
method_decorators = [validate_token]