feat: add reasoning format processing to LLMNode for <think> tag handling (#23313)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
@@ -156,6 +156,7 @@ class LLMResult(BaseModel):
     message: AssistantPromptMessage
     usage: LLMUsage
     system_fingerprint: Optional[str] = None
+    reasoning_content: Optional[str] = None


 class LLMStructuredOutput(BaseModel):
@@ -30,6 +30,7 @@ class ModelInvokeCompletedEvent(BaseModel):
     text: str
     usage: LLMUsage
     finish_reason: str | None = None
+    reasoning_content: str | None = None


 class RunRetryEvent(BaseModel):
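Both result models gain an optional reasoning_content field with a None default, so payloads produced before this change still validate. A minimal standalone sketch of that contract (ResultSketch is a hypothetical stand-in for illustration, not one of the classes above):

from typing import Optional

from pydantic import BaseModel


class ResultSketch(BaseModel):
    text: str
    reasoning_content: Optional[str] = None  # stays None unless the node extracts reasoning


# Old-style payload without reasoning still validates.
assert ResultSketch(text="Dify is an open source AI platform.").reasoning_content is None

# New payloads can carry the extracted reasoning alongside the answer.
event = ResultSketch(
    text="Dify is an open source AI platform.",
    reasoning_content="I need to explain what Dify is.",
)
assert event.reasoning_content is not None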
@@ -1,5 +1,5 @@
 from collections.abc import Mapping, Sequence
-from typing import Any, Optional
+from typing import Any, Literal, Optional

 from pydantic import BaseModel, Field, field_validator

@@ -68,6 +68,23 @@ class LLMNodeData(BaseNodeData):
     structured_output: Mapping[str, Any] | None = None
     # We used 'structured_output_enabled' in the past, but it's not a good name.
     structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
+    reasoning_format: Literal["separated", "tagged"] = Field(
+        # Keep tagged as default for backward compatibility
+        default="tagged",
+        description=(
+            """
+            Strategy for handling model reasoning output.
+
+            separated: Return clean text (without <think> tags) + reasoning_content field.
+                Recommended for new workflows. Enables safe downstream parsing and
+                workflow variable access: {{#node_id.reasoning_content#}}
+
+            tagged: Return original text (with <think> tags) + reasoning_content field.
+                Maintains full backward compatibility while still providing reasoning_content
+                for workflow automation. Frontend thinking panels work as before.
+            """
+        ),
+    )

     @field_validator("prompt_config", mode="before")
     @classmethod
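The field accepts only the two literal values and defaults to "tagged", so existing node configs keep their behaviour unless they opt in to "separated". A short sketch mirroring the unit-test fixture at the bottom of this diff (the import path is an assumption; follow the test module's actual imports):

from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig  # import path assumed

# Default: backward-compatible tagged behaviour.
node_data = LLMNodeData(
    title="Test LLM",
    model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
    prompt_template=[],
    context=ContextConfig(enabled=False),
)
assert node_data.reasoning_format == "tagged"

# Opt in per node when the workflow reads {{#node_id.reasoning_content#}} downstream.
node_data = LLMNodeData(
    title="Test LLM",
    model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
    prompt_template=[],
    context=ContextConfig(enabled=False),
    reasoning_format="separated",
)
assert node_data.reasoning_format == "separated"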
@@ -2,8 +2,9 @@ import base64
 import io
 import json
 import logging
+import re
 from collections.abc import Generator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union

 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.file import FileType, file_manager
@@ -99,6 +100,9 @@ class LLMNode(BaseNode):

     _node_data: LLMNodeData

+    # Compiled regex for extracting <think> blocks (with compatibility for attributes)
+    _THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
+
     # Instance attributes specific to LLMNode.
     # Output variable for file
     _file_outputs: list["File"]
@@ -167,6 +171,7 @@ class LLMNode(BaseNode):
         result_text = ""
         usage = LLMUsage.empty_usage()
         finish_reason = None
+        reasoning_content = None
         variable_pool = self.graph_runtime_state.variable_pool

         try:
@@ -256,6 +261,7 @@
                 file_saver=self._llm_file_saver,
                 file_outputs=self._file_outputs,
                 node_id=self.node_id,
+                reasoning_format=self._node_data.reasoning_format,
             )

             structured_output: LLMStructuredOutput | None = None
@@ -264,9 +270,20 @@
                 if isinstance(event, RunStreamChunkEvent):
                     yield event
                 elif isinstance(event, ModelInvokeCompletedEvent):
+                    # Raw text
                     result_text = event.text
                     usage = event.usage
                     finish_reason = event.finish_reason
+                    reasoning_content = event.reasoning_content or ""
+
+                    # For downstream nodes, determine clean text based on reasoning_format
+                    if self._node_data.reasoning_format == "tagged":
+                        # Keep <think> tags for backward compatibility
+                        clean_text = result_text
+                    else:
+                        # Extract clean text from <think> tags
+                        clean_text, _ = LLMNode._split_reasoning(result_text, self._node_data.reasoning_format)
+
                     # deduct quota
                     llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
                     break
@@ -284,7 +301,12 @@
                 "model_name": model_config.model,
             }

-            outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
+            outputs = {
+                "text": clean_text,
+                "reasoning_content": reasoning_content,
+                "usage": jsonable_encoder(usage),
+                "finish_reason": finish_reason,
+            }
             if structured_output:
                 outputs["structured_output"] = structured_output.structured_output
             if self._file_outputs is not None:
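For downstream nodes the practical difference is the shape of these outputs. An illustrative sketch of the mapping under each mode (values are made-up examples; usage is abridged, the real value is jsonable_encoder(usage)):

# separated: text is cleaned, reasoning lands in its own key.
outputs_separated = {
    "text": "Dify is an open source AI platform.",
    "reasoning_content": "I need to explain what Dify is.",
    "usage": {"total_tokens": 42},  # abridged example
    "finish_reason": "stop",
}

# tagged: text keeps the <think> block; _split_reasoning returns "" for reasoning_content.
outputs_tagged = {
    "text": "<think>I need to explain what Dify is.</think>Dify is an open source AI platform.",
    "reasoning_content": "",
    "usage": {"total_tokens": 42},  # abridged example
    "finish_reason": "stop",
}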
@@ -338,6 +360,7 @@
         file_saver: LLMFileSaver,
         file_outputs: list["File"],
         node_id: str,
+        reasoning_format: Literal["separated", "tagged"] = "tagged",
     ) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
         model_schema = model_instance.model_type_instance.get_model_schema(
             node_data_model.name, model_instance.credentials
@@ -374,6 +397,7 @@
             file_saver=file_saver,
             file_outputs=file_outputs,
             node_id=node_id,
+            reasoning_format=reasoning_format,
         )

     @staticmethod
@@ -383,6 +407,7 @@
         file_saver: LLMFileSaver,
         file_outputs: list["File"],
         node_id: str,
+        reasoning_format: Literal["separated", "tagged"] = "tagged",
     ) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
         # For blocking mode
         if isinstance(invoke_result, LLMResult):
@@ -390,6 +415,7 @@
                 invoke_result=invoke_result,
                 saver=file_saver,
                 file_outputs=file_outputs,
+                reasoning_format=reasoning_format,
             )
             yield event
             return
@@ -430,13 +456,66 @@
         except OutputParserError as e:
             raise LLMNodeError(f"Failed to parse structured output: {e}")

-        yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
+        # Extract reasoning content from <think> tags in the main text
+        full_text = full_text_buffer.getvalue()
+
+        if reasoning_format == "tagged":
+            # Keep <think> tags in text for backward compatibility
+            clean_text = full_text
+            reasoning_content = ""
+        else:
+            # Extract clean text and reasoning from <think> tags
+            clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
+
+        yield ModelInvokeCompletedEvent(
+            # Use clean_text for separated mode, full_text for tagged mode
+            text=clean_text if reasoning_format == "separated" else full_text,
+            usage=usage,
+            finish_reason=finish_reason,
+            # Reasoning content for workflow variables and downstream nodes
+            reasoning_content=reasoning_content,
+        )

     @staticmethod
     def _image_file_to_markdown(file: "File", /):
         text_chunk = f"![]({file.generate_url()})"
         return text_chunk

+    @classmethod
+    def _split_reasoning(
+        cls, text: str, reasoning_format: Literal["separated", "tagged"] = "tagged"
+    ) -> tuple[str, str]:
+        """
+        Split reasoning content from text based on reasoning_format strategy.
+
+        Args:
+            text: Full text that may contain <think> blocks
+            reasoning_format: Strategy for handling reasoning content
+                - "separated": Remove <think> tags and return clean text + reasoning_content field
+                - "tagged": Keep <think> tags in text, return empty reasoning_content
+
+        Returns:
+            tuple of (clean_text, reasoning_content)
+        """
+
+        if reasoning_format == "tagged":
+            return text, ""
+
+        # Find all <think>...</think> blocks (case-insensitive)
+        matches = cls._THINK_PATTERN.findall(text)
+
+        # Extract reasoning content from all <think> blocks
+        reasoning_content = "\n".join(match.strip() for match in matches) if matches else ""
+
+        # Remove all <think>...</think> blocks from original text
+        clean_text = cls._THINK_PATTERN.sub("", text)
+
+        # Clean up extra whitespace
+        clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
+
+        # Separated mode: always return clean text and reasoning_content
+        return clean_text, reasoning_content or ""
+
     def _transform_chat_messages(
         self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
     ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
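The extraction itself is plain regex over the buffered text. A self-contained sketch of the same pattern and cleanup, runnable outside the node class (split_reasoning here is a local helper for illustration, not the project API):

import re

THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)


def split_reasoning(text: str) -> tuple[str, str]:
    """Mirror of LLMNode._split_reasoning in 'separated' mode."""
    matches = THINK_PATTERN.findall(text)
    reasoning = "\n".join(m.strip() for m in matches) if matches else ""
    clean = THINK_PATTERN.sub("", text)
    clean = re.sub(r"\n\s*\n", "\n\n", clean).strip()
    return clean, reasoning


clean, reasoning = split_reasoning(
    "<think>User asked about Dify.</think>Dify is an open source AI platform."
)
assert clean == "Dify is an open source AI platform."
assert reasoning == "User asked about Dify."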
@@ -964,6 +1043,7 @@
         invoke_result: LLMResult,
         saver: LLMFileSaver,
         file_outputs: list["File"],
+        reasoning_format: Literal["separated", "tagged"] = "tagged",
     ) -> ModelInvokeCompletedEvent:
         buffer = io.StringIO()
         for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
@@ -973,10 +1053,24 @@
         ):
             buffer.write(text_part)

+        # Extract reasoning content from <think> tags in the main text
+        full_text = buffer.getvalue()
+
+        if reasoning_format == "tagged":
+            # Keep <think> tags in text for backward compatibility
+            clean_text = full_text
+            reasoning_content = ""
+        else:
+            # Extract clean text and reasoning from <think> tags
+            clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
+
         return ModelInvokeCompletedEvent(
-            text=buffer.getvalue(),
+            # Use clean_text for separated mode, full_text for tagged mode
+            text=clean_text if reasoning_format == "separated" else full_text,
             usage=invoke_result.usage,
             finish_reason=None,
+            # Reasoning content for workflow variables and downstream nodes
+            reasoning_content=reasoning_content,
         )

     @staticmethod
@@ -69,6 +69,7 @@ def llm_node_data() -> LLMNodeData:
                 detail=ImagePromptMessageContent.DETAIL.HIGH,
             ),
         ),
+        reasoning_format="tagged",
     )


@@ -689,3 +690,66 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
         assert list(gen) == []
         mock_file_saver.save_binary_string.assert_not_called()
         mock_file_saver.save_remote_url.assert_not_called()
+
+
+class TestReasoningFormat:
+    """Test cases for reasoning_format functionality"""
+
+    def test_split_reasoning_separated_mode(self):
+        """Test separated mode: tags are removed and content is extracted"""
+
+        text_with_think = """
+<think>I need to explain what Dify is. It's an open source AI platform.
+</think>Dify is an open source AI platform.
+"""
+
+        clean_text, reasoning_content = LLMNode._split_reasoning(text_with_think, "separated")
+
+        assert clean_text == "Dify is an open source AI platform."
+        assert reasoning_content == "I need to explain what Dify is. It's an open source AI platform."
+
+    def test_split_reasoning_tagged_mode(self):
+        """Test tagged mode: original text is preserved"""
+
+        text_with_think = """
+<think>I need to explain what Dify is. It's an open source AI platform.
+</think>Dify is an open source AI platform.
+"""
+
+        clean_text, reasoning_content = LLMNode._split_reasoning(text_with_think, "tagged")
+
+        # Original text unchanged
+        assert clean_text == text_with_think
+        # Empty reasoning content in tagged mode
+        assert reasoning_content == ""
+
+    def test_split_reasoning_no_think_blocks(self):
+        """Test behavior when no <think> tags are present"""
+
+        text_without_think = "This is a simple answer without any thinking blocks."
+
+        clean_text, reasoning_content = LLMNode._split_reasoning(text_without_think, "separated")
+
+        assert clean_text == text_without_think
+        assert reasoning_content == ""
+
+    def test_reasoning_format_default_value(self):
+        """Test that reasoning_format defaults to 'tagged' for backward compatibility"""
+
+        node_data = LLMNodeData(
+            title="Test LLM",
+            model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
+            prompt_template=[],
+            context=ContextConfig(enabled=False),
+        )
+
+        assert node_data.reasoning_format == "tagged"
+
+        text_with_think = """
+<think>I need to explain what Dify is. It's an open source AI platform.
+</think>Dify is an open source AI platform.
+"""
+        clean_text, reasoning_content = LLMNode._split_reasoning(text_with_think, node_data.reasoning_format)
+
+        assert clean_text == text_with_think
+        assert reasoning_content == ""
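One case the tests above leave implicit: when a response contains several <think> blocks, _split_reasoning joins their contents with newlines and strips all of them from the text. A quick illustrative check using the same regex as the diff (standalone sketch, not a proposed test):

import re

pattern = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)

text = (
    "<think>First pass: recall what Dify is.</think>Dify is an open source AI platform."
    "<think>Second pass: keep the answer short.</think>"
)

reasoning = "\n".join(m.strip() for m in pattern.findall(text))
clean = re.sub(r"\n\s*\n", "\n\n", pattern.sub("", text)).strip()

assert reasoning == "First pass: recall what Dify is.\nSecond pass: keep the answer short."
assert clean == "Dify is an open source AI platform."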