feat: claude api support (#572)

Author: John Wang
Date: 2023-07-17 00:14:19 +08:00
Committed by: GitHub
Parent: 510389909c
Commit: 7599f79a17
52 changed files with 637 additions and 349 deletions


@@ -118,6 +118,7 @@ class Completion:
prompt, stop_words = cls.get_main_llm_prompt(
mode=mode,
llm=final_llm,
+ model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt,
query=query,
inputs=inputs,
@@ -129,6 +130,7 @@ class Completion:
cls.recale_llm_max_tokens(
final_llm=final_llm,
+ model=app_model_config.model_dict,
prompt=prompt,
mode=mode
)
@@ -138,7 +140,8 @@ class Completion:
return response
@classmethod
- def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
+ def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
+ pre_prompt: str, query: str, inputs: dict,
chain_output: Optional[str],
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
@@ -151,10 +154,11 @@ class Completion:
if mode == 'completion':
prompt_template = JinjaPromptTemplate.from_template(
template=("""Use the following CONTEXT as your learned knowledge:
[CONTEXT]
template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
{{context}}
- [END CONTEXT]
+ </context>
When answer to user:
- If you don't know, just say that you don't know.
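
Note: this hunk replaces the ad-hoc [CONTEXT] … [END CONTEXT] delimiters with <context></context> XML tags, the delimiter style Anthropic recommends for Claude prompts; the chat-mode prompt below gets the same treatment. A minimal sketch of what the new completion-mode template renders to, with a hypothetical context value:

# Sketch only: rendered shape of the new template.
# The context string is hypothetical; in the real code it comes from
# chain_output / the dataset retrieval step.
context = "Dify apps can be backed by OpenAI or Claude models."
rendered = (
    "Use the following context as your learned knowledge, "
    "inside <context></context> XML tags.\n"
    "<context>\n"
    f"{context}\n"
    "</context>\n"
)
print(rendered)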
@@ -204,10 +208,11 @@ And answer according to the language of the user's question.
if chain_output:
human_inputs['context'] = chain_output
human_message_prompt += """Use the following CONTEXT as your learned knowledge.
[CONTEXT]
human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
{{context}}
- [END CONTEXT]
+ </context>
When answer to user:
- If you don't know, just say that you don't know.
@@ -219,7 +224,7 @@ And answer according to the language of the user's question.
if pre_prompt:
human_message_prompt += pre_prompt
query_prompt = "\nHuman: {{query}}\nAI: "
query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "
if memory:
# append chat histories
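
Note: Claude's text-completions API expects prompts formatted as "\n\nHuman: …\n\nAssistant:" turns, which is why the "\nHuman: …\nAI: " suffix is rewritten. A small illustration with a hypothetical query:

# Sketch only: the query suffix before vs. after this commit.
query = "Which providers are supported?"  # hypothetical
old_suffix = "\nHuman: {{query}}\nAI: ".replace("{{query}}", query)
new_suffix = "\n\nHuman: {{query}}\n\nAssistant: ".replace("{{query}}", query)
# Claude requires the exact "\n\nHuman:" / "\n\nAssistant:" delimiters,
# hence the doubled newlines and the AI -> Assistant rename.
print(repr(new_suffix))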
@@ -228,9 +233,11 @@ And answer according to the language of the user's question.
inputs=human_inputs
)
- curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
- rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
- - memory.llm.max_tokens - curr_message_tokens
+ curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
+ model_name = model['name']
+ max_tokens = model.get("completion_params").get('max_tokens')
+ rest_tokens = llm_constant.max_context_token_length[model_name] \
+ - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
histories = cls.get_history_messages_from_memory(memory, rest_tokens)
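
Note: the history token budget no longer reads model_name and max_tokens off the LLM instance; both now come from the app's model config dict, and the budget is clamped at zero so a large completion setting cannot push it negative. A worked example with hypothetical numbers:

# Sketch only: hypothetical dict shaped like app_model_config.model_dict.
model = {"name": "claude-2", "completion_params": {"max_tokens": 512}}
max_context_token_length = {"claude-2": 100_000}  # stand-in for llm_constant

curr_message_tokens = 200  # hypothetical size of the rendered human message
rest_tokens = (max_context_token_length[model["name"]]
               - model["completion_params"]["max_tokens"]
               - curr_message_tokens)
rest_tokens = max(rest_tokens, 0)  # never request a negative history budget
print(rest_tokens)  # 99288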
@@ -241,7 +248,10 @@ And answer according to the language of the user's question.
# if histories_param not in human_inputs:
# human_inputs[histories_param] = '{{' + histories_param + '}}'
human_message_prompt += "\n\n" + histories
human_message_prompt += "\n\n" if human_message_prompt else ""
human_message_prompt += "Here is the chat histories between human and assistant, " \
"inside <histories></histories> XML tags.\n\n<histories>"
human_message_prompt += histories + "</histories>"
human_message_prompt += query_prompt
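
Note: chat histories are now wrapped in <histories></histories> XML tags instead of being appended bare, and the leading "\n\n" separator is only emitted when a prompt prefix already exists. The assembled shape, with hypothetical values:

# Sketch only: how the new history block assembles.
human_message_prompt = "You are a helpful assistant."  # hypothetical pre_prompt
histories = "Human: hi\nAssistant: Hello! How can I help?"  # hypothetical

human_message_prompt += "\n\n" if human_message_prompt else ""
human_message_prompt += ("Here is the chat histories between human and assistant, "
                         "inside <histories></histories> XML tags.\n\n<histories>")
human_message_prompt += histories + "</histories>"
print(human_message_prompt)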
@@ -307,13 +317,15 @@ And answer according to the language of the user's question.
model=app_model_config.model_dict
)
- model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
- max_tokens = llm.max_tokens
+ model_name = app_model_config.model_dict.get("name")
+ model_limited_tokens = llm_constant.max_context_token_length[model_name]
+ max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')
# get prompt without memory and context
prompt, _ = cls.get_main_llm_prompt(
mode=mode,
llm=llm,
+ model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt,
query=query,
inputs=inputs,
@@ -332,16 +344,17 @@ And answer according to the language of the user's question.
return rest_tokens
@classmethod
- def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
+ def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
prompt: Union[str, List[BaseMessage]], mode: str):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
- model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name]
- max_tokens = final_llm.max_tokens
+ model_name = model.get("name")
+ model_limited_tokens = llm_constant.max_context_token_length[model_name]
+ max_tokens = model.get("completion_params").get('max_tokens')
if mode == 'completion' and isinstance(final_llm, BaseLLM):
prompt_tokens = final_llm.get_num_tokens(prompt)
else:
- prompt_tokens = final_llm.get_messages_tokens(prompt)
+ prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)
if prompt_tokens + max_tokens > model_limited_tokens:
max_tokens = max(model_limited_tokens - prompt_tokens, 16)
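
Note: the clamp itself is unchanged: if the prompt plus the requested completion budget would overflow the context window, max_tokens shrinks to whatever room is left (floored at 16). Its inputs now come from the model config dict, and message token counting moves to LangChain's get_num_tokens_from_messages. A worked example with hypothetical numbers:

# Sketch only: recalculating max_tokens for an oversized prompt.
model_limited_tokens = 4096  # hypothetical 4k-context model
prompt_tokens = 3900         # hypothetical prompt size
max_tokens = 512             # requested completion budget

if prompt_tokens + max_tokens > model_limited_tokens:
    # leave at least 16 tokens for the completion
    max_tokens = max(model_limited_tokens - prompt_tokens, 16)
print(max_tokens)  # 196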
@@ -350,9 +363,10 @@ And answer according to the language of the user's question.
@classmethod
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
app_model_config: AppModelConfig, user: Account, streaming: bool):
- llm: StreamableOpenAI = LLMBuilder.to_llm(
+ llm = LLMBuilder.to_llm_from_model(
tenant_id=app.tenant_id,
- model_name='gpt-3.5-turbo',
+ model=app_model_config.model_dict,
streaming=streaming
)
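
Note: "more like this" generation previously hardcoded gpt-3.5-turbo; it now builds the LLM from the app's configured model, so Claude-backed apps regenerate with Claude as well. A hypothetical sketch of the provider dispatch such a builder implies (not the actual LLMBuilder code):

# Hypothetical illustration only; the real LLMBuilder returns an LLM instance.
def to_llm_from_model(tenant_id: str, model: dict, streaming: bool) -> str:
    name = model["name"]
    provider = "anthropic" if name.startswith("claude") else "openai"
    return f"{provider}:{name} (tenant={tenant_id}, streaming={streaming})"

print(to_llm_from_model("tenant-123", {"name": "claude-2"}, streaming=True))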
@@ -360,6 +374,7 @@ And answer according to the language of the user's question.
original_prompt, _ = cls.get_main_llm_prompt(
mode="completion",
llm=llm,
+ model=app_model_config.model_dict,
pre_prompt=pre_prompt,
query=message.query,
inputs=message.inputs,
@@ -390,6 +405,7 @@ And answer according to the language of the user's question.
cls.recale_llm_max_tokens(
final_llm=llm,
+ model=app_model_config.model_dict,
prompt=prompt,
mode='completion'
)