feat: add hosted moderation (#1158)

Author: takatost
Date: 2023-09-12 10:26:12 +08:00 (committed by GitHub)
Parent: 983834cd52
Commit: f9082104ed

15 changed files with 240 additions and 69 deletions

View File

@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Union, Optional

 from langchain.agents import openai_functions_agent, openai_functions_multi_agent
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
+from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage

 from core.callback_handler.entity.agent_loop import AgentLoop
 from core.conversation_message_task import ConversationMessageTask
@@ -18,9 +18,9 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
     raise_error: bool = True

-    def __init__(self, model_instant: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
+    def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""
-        self.model_instant = model_instant
+        self.model_instance = model_instance
         self.conversation_message_task = conversation_message_task
         self._agent_loops = []
         self._current_loop = None
@@ -46,6 +46,21 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         """Whether to ignore chain callbacks."""
         return True

+    def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List[BaseMessage]],
+            **kwargs: Any
+    ) -> Any:
+        if not self._current_loop:
+            # the agent starts with an LLM query
+            self._current_loop = AgentLoop(
+                position=len(self._agent_loops) + 1,
+                prompt="\n".join([message.content for message in messages[0]]),
+                status='llm_started',
+                started_at=time.perf_counter()
+            )
+
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
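Chat models report their start through on_chat_model_start rather than on_llm_start, so without this new hook an agent backed by a chat model would never open an AgentLoop. Below is a minimal sketch of the flattening step the hook performs, assuming langchain's 2023-era langchain.schema exports; the sample messages are illustrative only.

# Illustrative only: shows how a chat prompt batch collapses into the
# single string stored on AgentLoop.prompt.
from langchain.schema import HumanMessage, SystemMessage

batch = [[
    SystemMessage(content="You are a helpful agent."),
    HumanMessage(content="What is the weather in Tokyo?"),
]]

# batch[0] is the first prompt in the batch; the handler joins its
# message contents with newlines, exactly as in the hunk above.
prompt = "\n".join(message.content for message in batch[0])
print(prompt)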
@@ -70,7 +85,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             if response.llm_output:
                 self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
             else:
-                self._current_loop.prompt_tokens = self.model_instant.get_num_tokens(
+                self._current_loop.prompt_tokens = self.model_instance.get_num_tokens(
                     [PromptMessage(content=self._current_loop.prompt)]
                 )
             completion_generation = response.generations[0][0]
@@ -87,7 +102,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             if response.llm_output:
                 self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
             else:
-                self._current_loop.completion_tokens = self.model_instant.get_num_tokens(
+                self._current_loop.completion_tokens = self.model_instance.get_num_tokens(
                     [PromptMessage(content=self._current_loop.completion)]
                 )
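The two hunks above share one fallback pattern: trust the provider-reported token_usage when llm_output carries it, otherwise estimate locally with get_num_tokens. A standalone sketch of that pattern follows; the helper name and the tiktoken encoding are assumptions, not part of this commit.

from typing import Optional

import tiktoken


def count_prompt_tokens(llm_output: Optional[dict], prompt: str) -> int:
    # Prefer the exact usage the provider reported.
    if llm_output and 'token_usage' in llm_output:
        return llm_output['token_usage']['prompt_tokens']
    # Otherwise fall back to a local tokenizer estimate
    # (cl100k_base is an assumed stand-in for model_instance.get_num_tokens).
    encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(prompt))


print(count_prompt_tokens(None, "What is the weather in Tokyo?"))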
@@ -162,7 +177,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at

             self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_instant, self._current_loop
+                self._message_agent_thought, self.model_instance, self._current_loop
             )

             self._agent_loops.append(self._current_loop)
@@ -193,7 +208,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             )

             self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_instant, self._current_loop
+                self._message_agent_thought, self.model_instance, self._current_loop
             )

             self._agent_loops.append(self._current_loop)
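With started_at captured when the loop opens (in on_chat_model_start or on_llm_start) and completed_at captured when the agent finishes, per-loop latency is a plain difference of two monotonic-clock readings, immune to wall-clock adjustments. A self-contained sketch of that bookkeeping, with an illustrative stand-in for the model call:

import time


def timed_loop() -> float:
    started_at = time.perf_counter()    # set when the loop opens
    time.sleep(0.01)                    # stand-in for the LLM and tool calls
    completed_at = time.perf_counter()  # set when the agent finishes
    return completed_at - started_at    # latency in seconds, monotonic clock


print(f"latency: {timed_loop():.3f}s")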

View File

@@ -6,4 +6,3 @@ class LLMMessage(BaseModel):
     prompt_tokens: int = 0
     completion: str = ''
     completion_tokens: int = 0
-    latency: float = 0.0
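After this hunk, LLMMessage keeps only prompt and token bookkeeping; latency tracking leaves the entity entirely. A sketch of the resulting model is below; the prompt field and its type are inferred from the assignments in LLMCallbackHandler further down, so treat them as assumptions.

from pydantic import BaseModel


class LLMMessage(BaseModel):
    # `prompt` is assumed to hold {"role": ..., "text": ...} dicts, matching
    # the assignment in on_llm_start below; only the three fields after it
    # are confirmed by the hunk above.
    prompt: list = []
    prompt_tokens: int = 0
    completion: str = ''
    completion_tokens: int = 0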

View File

@@ -1,5 +1,4 @@
 import logging
-import time
 from typing import Any, Dict, List, Union

 from langchain.callbacks.base import BaseCallbackHandler
@@ -32,7 +31,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
             messages: List[List[BaseMessage]],
             **kwargs: Any
     ) -> Any:
-        self.start_at = time.perf_counter()
         real_prompts = []
         for message in messages[0]:
             if message.type == 'human':
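LangChain tags each BaseMessage with a type string ('human', 'ai', 'system'), which this loop folds into role/text dicts; only the 'human' branch survives in the hunk. The sketch below reconstructs the full mapping, so the non-human branches are assumptions:

from langchain.schema import AIMessage, HumanMessage, SystemMessage


def to_real_prompts(messages: list) -> list:
    # Assumed mapping; the hunk above only shows the 'human' case.
    role_by_type = {'human': 'user', 'ai': 'assistant', 'system': 'system'}
    return [{"role": role_by_type.get(m.type, m.type), "text": m.content}
            for m in messages]


print(to_real_prompts([SystemMessage(content="Be brief."),
                       HumanMessage(content="Hi"),
                       AIMessage(content="Hello!")]))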
@@ -53,8 +51,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
-        self.start_at = time.perf_counter()
-
         self.llm_message.prompt = [{
             "role": 'user',
             "text": prompts[0]
@@ -63,9 +59,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
         self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])

     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        end_at = time.perf_counter()
-        self.llm_message.latency = end_at - self.start_at
-
         if not self.conversation_message_task.streaming:
             self.conversation_message_task.append_message_text(response.generations[0][0].text)
         self.llm_message.completion = response.generations[0][0].text
@@ -89,8 +82,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
         """Do nothing."""
         if isinstance(error, ConversationTaskStoppedException):
             if self.conversation_message_task.streaming:
-                end_at = time.perf_counter()
-                self.llm_message.latency = end_at - self.start_at
                 self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                     [PromptMessage(content=self.llm_message.completion)]
                 )
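One job remains on this error path after the timing removal: a user-initiated stop (ConversationTaskStoppedException) is not treated as a failure, and the tokens already streamed are still counted so usage accounting stays consistent. A sketch of that idea with stand-in names, not Dify's actual API:

from typing import Callable


class ConversationTaskStopped(Exception):
    """Stand-in for Dify's ConversationTaskStoppedException."""


def handle_llm_error(error: Exception, streaming: bool, completion: str,
                     count_tokens: Callable[[str], int]) -> int:
    if isinstance(error, ConversationTaskStopped) and streaming:
        # A deliberate stop: keep accounting for the partial completion.
        return count_tokens(completion)
    raise error  # real failures propagate


print(handle_llm_error(ConversationTaskStopped(), True,
                       "The weather in Tokyo is", lambda s: len(s.split())))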