Chore: optimize the code of PromptTransform (#16143)
@@ -93,7 +93,7 @@ class SimplePromptTransform(PromptTransform):
 
         return prompt_messages, stops
 
-    def get_prompt_str_and_rules(
+    def _get_prompt_str_and_rules(
         self,
         app_mode: AppMode,
         model_config: ModelConfigWithCredentialsEntity,
@@ -184,7 +184,7 @@ class SimplePromptTransform(PromptTransform):
         prompt_messages: list[PromptMessage] = []
 
         # get prompt
-        prompt, _ = self.get_prompt_str_and_rules(
+        prompt, _ = self._get_prompt_str_and_rules(
             app_mode=app_mode,
             model_config=model_config,
             pre_prompt=pre_prompt,
@@ -209,9 +209,9 @@ class SimplePromptTransform(PromptTransform):
         )
 
         if query:
-            prompt_messages.append(self.get_last_user_message(query, files, image_detail_config))
+            prompt_messages.append(self._get_last_user_message(query, files, image_detail_config))
         else:
-            prompt_messages.append(self.get_last_user_message(prompt, files, image_detail_config))
+            prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config))
 
         return prompt_messages, None
 
@@ -228,7 +228,7 @@ class SimplePromptTransform(PromptTransform):
         image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
     ) -> tuple[list[PromptMessage], Optional[list[str]]]:
         # get prompt
-        prompt, prompt_rules = self.get_prompt_str_and_rules(
+        prompt, prompt_rules = self._get_prompt_str_and_rules(
             app_mode=app_mode,
             model_config=model_config,
             pre_prompt=pre_prompt,
@@ -254,7 +254,7 @@ class SimplePromptTransform(PromptTransform):
         )
 
         # get prompt
-        prompt, prompt_rules = self.get_prompt_str_and_rules(
+        prompt, prompt_rules = self._get_prompt_str_and_rules(
             app_mode=app_mode,
             model_config=model_config,
             pre_prompt=pre_prompt,
@@ -268,9 +268,9 @@ class SimplePromptTransform(PromptTransform):
         if stops is not None and len(stops) == 0:
             stops = None
 
-        return [self.get_last_user_message(prompt, files, image_detail_config)], stops
+        return [self._get_last_user_message(prompt, files, image_detail_config)], stops
 
-    def get_last_user_message(
+    def _get_last_user_message(
         self,
         prompt: str,
         files: Sequence["File"],
@@ -64,12 +64,10 @@ def test_get_prompt():
     transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
     result = transform.get_prompt()
 
-    assert len(result) <= max_token_limit
     assert len(result) == 4
 
     max_token_limit = 20
     transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
     result = transform.get_prompt()
 
-    assert len(result) <= max_token_limit
     assert len(result) == 12
@@ -84,7 +84,6 @@ def test_get_baichuan_completion_app_prompt_template_with_pcq():
         query_in_prompt=True,
         with_memory_prompt=False,
     )
-    print(prompt_template["prompt_template"].template)
     prompt_rules = prompt_template["prompt_rules"]
     assert prompt_template["prompt_template"].template == (
         prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"]
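The renames in this commit follow the common Python convention of prefixing class-internal helpers with a single underscore, keeping the public surface of SimplePromptTransform limited to its entry-point methods. A minimal sketch of that convention, assuming a simplified class whose name and method bodies are illustrative and not taken from the Dify codebase:

class PromptBuilder:
    """Sketch of the underscore-prefix convention; not the actual Dify implementation."""

    def build(self, query: str) -> list[str]:
        # Public entry point: the only method external callers are expected to use.
        prompt = self._get_prompt_str(query)
        return [self._get_last_user_message(prompt)]

    def _get_prompt_str(self, query: str) -> str:
        # Leading underscore marks this as an internal helper, not part of the public API.
        return f"Human: {query}\nAssistant: "

    def _get_last_user_message(self, prompt: str) -> str:
        # Internal helper: wraps the assembled prompt as the final user message.
        return prompt


# Usage: only the public method is called from outside the class.
print(PromptBuilder().build("What is Dify?"))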