Initial commit
api/core/generator/llm_generator.py (Normal file, 120 lines added)
@@ -0,0 +1,120 @@
import logging

from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage

from core.constant import llm_constant
from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_open_ai import StreamableOpenAI
from core.llm.token_calculator import TokenCalculator

from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import OutLinePromptTemplate
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT


# gpt-3.5-turbo does not work well for these generation tasks
generate_base_model = 'text-davinci-003'


class LLMGenerator:
    @classmethod
    def generate_conversation_name(cls, tenant_id: str, query, answer):
        # Derive a short conversation title from the first query/answer pair.
        prompt = CONVERSATION_TITLE_PROMPT
        prompt = prompt.format(query=query, answer=answer)
        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
            max_tokens=50
        )

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text
        return answer.strip()

    @classmethod
    def generate_conversation_summary(cls, tenant_id: str, messages):
        # Summarize a conversation, packing as much history as fits the model's context window.
        max_tokens = 200

        prompt = CONVERSATION_SUMMARY_PROMPT
        prompt_with_empty_context = prompt.format(context='')
        prompt_tokens = TokenCalculator.get_num_tokens(generate_base_model, prompt_with_empty_context)
        rest_tokens = llm_constant.max_context_token_length[generate_base_model] - prompt_tokens - max_tokens

        context = ''
        for message in messages:
            if not message.answer:
                continue

            message_qa_text = "Human:" + message.query + "\nAI:" + message.answer + "\n"
            if rest_tokens - TokenCalculator.get_num_tokens(generate_base_model, context + message_qa_text) > 0:
                context += message_qa_text

        prompt = prompt.format(context=context)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
            max_tokens=max_tokens
        )

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text
        return answer.strip()

    @classmethod
    def generate_introduction(cls, tenant_id: str, pre_prompt: str):
        # Generate an opening message based on an app's pre-prompt.
        prompt = INTRODUCTION_GENERATE_PROMPT
        prompt = prompt.format(prompt=pre_prompt)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
        )

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text
        return answer.strip()

    @classmethod
    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
        # Suggest follow-up questions from the conversation history, parsed via the output parser.
        output_parser = SuggestedQuestionsAfterAnswerOutputParser()
        format_instructions = output_parser.get_format_instructions()

        prompt = OutLinePromptTemplate(
            template="{histories}\n{format_instructions}\nquestions:\n",
            input_variables=["histories"],
            partial_variables={"format_instructions": format_instructions}
        )

        _input = prompt.format_prompt(histories=histories)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
            temperature=0,
            max_tokens=256
        )

        if isinstance(llm, BaseChatModel):
            query = [HumanMessage(content=_input.to_string())]
        else:
            query = _input.to_string()

        try:
            output = llm(query)
            questions = output_parser.parse(output)
        except Exception:
            logging.exception("Error generating suggested questions after answer")
            questions = []

        return questions
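For context, a minimal usage sketch follows (not part of the commit itself): it calls two of the generators above with placeholder values. The tenant id, query, answer, and pre-prompt strings are invented for illustration, and it assumes LLMBuilder has valid provider credentials configured for that tenant.

# Hypothetical usage sketch; all literal values below are placeholders.
from core.generator.llm_generator import LLMGenerator

tenant_id = "example-tenant-id"  # placeholder; a real workspace/tenant id is expected

# Derive a short title for a conversation from its first exchange.
name = LLMGenerator.generate_conversation_name(
    tenant_id,
    query="How do I reset my password?",
    answer="Open Settings > Account and click 'Reset password'.",
)

# Generate an opening message from an app's pre-prompt.
introduction = LLMGenerator.generate_introduction(
    tenant_id,
    pre_prompt="You are a helpful customer support assistant.",
)

print(name)
print(introduction)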