feat: add api-based extension & external data tool & moderation backend (#1403)

Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
Garfield Dai
2023-11-06 19:36:16 +08:00
committed by GitHub
parent 7699621983
commit db43ed6f41
50 changed files with 1624 additions and 273 deletions

View File

@@ -1,13 +1,18 @@
import concurrent
import json
import logging
from typing import Optional, List, Union
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, List, Union, Tuple
from flask import current_app, Flask
from requests.exceptions import ChunkedEncodingError
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.external_data_tool.factory import ExternalDataToolFactory
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
@@ -18,6 +23,8 @@ from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompt_transform import PromptTransform
from models.model import App, AppModelConfig, Account, Conversation, EndUser
from core.moderation.base import ModerationException, ModerationAction
from core.moderation.factory import ModerationFactory
class Completion:
@@ -76,26 +83,35 @@ class Completion:
)
try:
# parse sensitive_word_avoidance_chain
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(
final_model_instance, [chain_callback])
if sensitive_word_avoidance_chain:
try:
query = sensitive_word_avoidance_chain.run(query)
except SensitiveWordAvoidanceError as ex:
cls.run_final_llm(
model_instance=final_model_instance,
mode=app.mode,
app_model_config=app_model_config,
query=query,
inputs=inputs,
agent_execute_result=None,
conversation_message_task=conversation_message_task,
memory=memory,
fake_response=ex.message
)
return
try:
# process sensitive_word_avoidance
inputs, query = cls.moderation_for_inputs(app.id, app.tenant_id, app_model_config, inputs, query)
except ModerationException as e:
cls.run_final_llm(
model_instance=final_model_instance,
mode=app.mode,
app_model_config=app_model_config,
query=query,
inputs=inputs,
agent_execute_result=None,
conversation_message_task=conversation_message_task,
memory=memory,
fake_response=str(e)
)
return
# fill in variable inputs from external data tools if exists
external_data_tools = app_model_config.external_data_tools_list
if external_data_tools:
inputs = cls.fill_in_inputs_from_external_data_tools(
tenant_id=app.tenant_id,
app_id=app.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
)
# get agent executor
agent_executor = orchestrator_rule_parser.to_agent_executor(
@@ -135,19 +151,110 @@ class Completion:
memory=memory,
fake_response=fake_response
)
except ConversationTaskStoppedException:
except (ConversationTaskInterruptException, ConversationTaskStoppedException):
return
except ChunkedEncodingError as e:
# Interrupt by LLM (like OpenAI), handle it.
logging.warning(f'ChunkedEncodingError: {e}')
conversation_message_task.end()
return
@classmethod
def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str):
if not app_model_config.sensitive_word_avoidance_dict['enabled']:
return inputs, query
type = app_model_config.sensitive_word_avoidance_dict['type']
moderation = ModerationFactory(type, app_id, tenant_id, app_model_config.sensitive_word_avoidance_dict['config'])
moderation_result = moderation.moderation_for_inputs(inputs, query)
if not moderation_result.flagged:
return inputs, query
if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
raise ModerationException(moderation_result.preset_response)
elif moderation_result.action == ModerationAction.OVERRIDED:
inputs = moderation_result.inputs
query = moderation_result.query
return inputs, query
@classmethod
def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, external_data_tools: list[dict],
inputs: dict, query: str) -> dict:
"""
Fill in variable inputs from external data tools if exists.
:param tenant_id: workspace id
:param app_id: app id
:param external_data_tools: external data tools configs
:param inputs: the inputs
:param query: the query
:return: the filled inputs
"""
# Group tools by type and config
grouped_tools = {}
for tool in external_data_tools:
if not tool.get("enabled"):
continue
tool_key = (tool.get("type"), json.dumps(tool.get("config"), sort_keys=True))
grouped_tools.setdefault(tool_key, []).append(tool)
results = {}
with ThreadPoolExecutor() as executor:
futures = {}
for tools in grouped_tools.values():
# Only query the first tool in each group
first_tool = tools[0]
future = executor.submit(
cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, first_tool,
inputs, query
)
for tool in tools:
futures[future] = tool
for future in concurrent.futures.as_completed(futures):
tool_key, result = future.result()
if tool_key in grouped_tools:
for tool in grouped_tools[tool_key]:
results[tool['variable']] = result
inputs.update(results)
return inputs
@classmethod
def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, external_data_tool: dict,
inputs: dict, query: str) -> Tuple[Optional[str], Optional[str]]:
with flask_app.app_context():
tool_variable = external_data_tool.get("variable")
tool_type = external_data_tool.get("type")
tool_config = external_data_tool.get("config")
external_data_tool_factory = ExternalDataToolFactory(
name=tool_type,
tenant_id=tenant_id,
app_id=app_id,
variable=tool_variable,
config=tool_config
)
# query external data tool
result = external_data_tool_factory.query(
inputs=inputs,
query=query
)
tool_key = (external_data_tool.get("type"), json.dumps(external_data_tool.get("config"), sort_keys=True))
return tool_key, result
@classmethod
def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
if app.mode != 'completion':
return query
return inputs.get(app_model_config.dataset_query_variable, "")
@classmethod