feat: add reasoning format processing to LLMNode for <think> tag handling (#23313)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Author: taewoong Kim
Date: 2025-09-05 19:15:35 +09:00
committed by GitHub
parent 05cd7e2d8a
commit edf4a1b652
30 changed files with 366 additions and 5 deletions

View File

@@ -156,6 +156,7 @@ class LLMResult(BaseModel):
message: AssistantPromptMessage
usage: LLMUsage
system_fingerprint: Optional[str] = None
reasoning_content: Optional[str] = None
class LLMStructuredOutput(BaseModel):

View File

@@ -30,6 +30,7 @@ class ModelInvokeCompletedEvent(BaseModel):
text: str
usage: LLMUsage
finish_reason: str | None = None
reasoning_content: str | None = None
class RunRetryEvent(BaseModel):
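# Hedged sketch (my addition, not part of the diff): what a fully populated event
# could look like with the new field. LLMUsage.empty_usage() is used elsewhere in
# this commit; the literal values here are made up for illustration only.
example_event = ModelInvokeCompletedEvent(
    text="Dify is an open source AI platform.",
    usage=LLMUsage.empty_usage(),
    finish_reason="stop",
    reasoning_content="I need to explain what Dify is.",
)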

View File

@@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Optional
from typing import Any, Literal, Optional
from pydantic import BaseModel, Field, field_validator
@@ -68,6 +68,23 @@ class LLMNodeData(BaseNodeData):
structured_output: Mapping[str, Any] | None = None
# We used 'structured_output_enabled' in the past, but it's not a good name.
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
reasoning_format: Literal["separated", "tagged"] = Field(
    # Keep "tagged" as the default for backward compatibility
    default="tagged",
    description=(
        """
        Strategy for handling model reasoning output.
        separated: Return clean text (without <think> tags) plus a reasoning_content field.
            Recommended for new workflows; enables safe downstream parsing and workflow
            variable access via {{#node_id.reasoning_content#}}.
        tagged: Return the original text (with <think> tags) and an empty reasoning_content field.
            Maintains full backward compatibility; frontend thinking panels keep working as before.
        """
    ),
)
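# Illustrative sketch (my addition, not part of this commit): opting a node into the
# new behaviour. The constructor arguments mirror the test fixture added later in this
# diff; the title is arbitrary and imports are omitted here.
example_node_data = LLMNodeData(
    title="Answer with visible reasoning",
    model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
    prompt_template=[],
    context=ContextConfig(enabled=False),
    reasoning_format="separated",  # omit to keep the legacy "tagged" behaviour
)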
@field_validator("prompt_config", mode="before")
@classmethod

View File

@@ -2,8 +2,9 @@ import base64
import io
import json
import logging
import re
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
@@ -99,6 +100,9 @@ class LLMNode(BaseNode):
_node_data: LLMNodeData
# Compiled regex for extracting <think> blocks (tolerates attributes on the opening tag)
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
# Instance attributes specific to LLMNode.
# Output variable for file
_file_outputs: list["File"]
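# Quick standalone illustration (my addition, not in the diff) of what the pattern
# above accepts: attributes on the opening tag, mixed case, and newlines inside the block.
import re

think_pattern = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
sample = '<THINK reason="draft">step 1\nstep 2</think>Final answer.'
print(think_pattern.findall(sample))  # ['step 1\nstep 2']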
@@ -167,6 +171,7 @@ class LLMNode(BaseNode):
result_text = ""
usage = LLMUsage.empty_usage()
finish_reason = None
reasoning_content = None
variable_pool = self.graph_runtime_state.variable_pool
try:
@@ -256,6 +261,7 @@ class LLMNode(BaseNode):
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
reasoning_format=self._node_data.reasoning_format,
)
structured_output: LLMStructuredOutput | None = None
@@ -264,9 +270,20 @@ class LLMNode(BaseNode):
if isinstance(event, RunStreamChunkEvent):
    yield event
elif isinstance(event, ModelInvokeCompletedEvent):
    # Raw text
    result_text = event.text
    usage = event.usage
    finish_reason = event.finish_reason
    reasoning_content = event.reasoning_content or ""
    # For downstream nodes, determine clean text based on reasoning_format
    if self._node_data.reasoning_format == "tagged":
        # Keep <think> tags for backward compatibility
        clean_text = result_text
    else:
        # Extract clean text from <think> tags
        clean_text, _ = LLMNode._split_reasoning(result_text, self._node_data.reasoning_format)
    # deduct quota
    llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
    break
@@ -284,7 +301,12 @@ class LLMNode(BaseNode):
"model_name": model_config.model,
}
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
outputs = {
    "text": clean_text,
    "reasoning_content": reasoning_content,
    "usage": jsonable_encoder(usage),
    "finish_reason": finish_reason,
}
if structured_output:
    outputs["structured_output"] = structured_output.structured_output
if self._file_outputs is not None:
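# Hedged, runnable sketch (my addition, not in the diff): how the two modes shape the
# "text" and "reasoning_content" outputs above. Uses LLMNode._split_reasoning, which
# this commit adds further down; the reply string is invented for illustration.
reply = "<think>Check the docs first.</think>Use the REST API."

clean, reasoning = LLMNode._split_reasoning(reply, "separated")
assert clean == "Use the REST API."
assert reasoning == "Check the docs first."

clean, reasoning = LLMNode._split_reasoning(reply, "tagged")
assert clean == reply   # tags preserved for legacy consumers
assert reasoning == ""  # tagged mode leaves reasoning_content empty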
@@ -338,6 +360,7 @@ class LLMNode(BaseNode):
file_saver: LLMFileSaver,
file_outputs: list["File"],
node_id: str,
reasoning_format: Literal["separated", "tagged"] = "tagged",
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
model_schema = model_instance.model_type_instance.get_model_schema(
node_data_model.name, model_instance.credentials
@@ -374,6 +397,7 @@ class LLMNode(BaseNode):
file_saver=file_saver,
file_outputs=file_outputs,
node_id=node_id,
reasoning_format=reasoning_format,
)
@staticmethod
@@ -383,6 +407,7 @@ class LLMNode(BaseNode):
file_saver: LLMFileSaver,
file_outputs: list["File"],
node_id: str,
reasoning_format: Literal["separated", "tagged"] = "tagged",
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
# For blocking mode
if isinstance(invoke_result, LLMResult):
@@ -390,6 +415,7 @@ class LLMNode(BaseNode):
invoke_result=invoke_result,
saver=file_saver,
file_outputs=file_outputs,
reasoning_format=reasoning_format,
)
yield event
return
@@ -430,13 +456,66 @@ class LLMNode(BaseNode):
except OutputParserError as e:
    raise LLMNodeError(f"Failed to parse structured output: {e}")
yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
# Extract reasoning content from <think> tags in the main text
full_text = full_text_buffer.getvalue()
if reasoning_format == "tagged":
    # Keep <think> tags in text for backward compatibility
    clean_text = full_text
    reasoning_content = ""
else:
    # Extract clean text and reasoning from <think> tags
    clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
yield ModelInvokeCompletedEvent(
    # Use clean_text for separated mode, full_text for tagged mode
    text=clean_text if reasoning_format == "separated" else full_text,
    usage=usage,
    finish_reason=finish_reason,
    # Reasoning content for workflow variables and downstream nodes
    reasoning_content=reasoning_content,
)
@staticmethod
def _image_file_to_markdown(file: "File", /):
    text_chunk = f"![]({file.generate_url()})"
    return text_chunk

@classmethod
def _split_reasoning(
    cls, text: str, reasoning_format: Literal["separated", "tagged"] = "tagged"
) -> tuple[str, str]:
    """
    Split reasoning content from text based on the reasoning_format strategy.

    Args:
        text: Full text that may contain <think> blocks
        reasoning_format: Strategy for handling reasoning content
            - "separated": Remove <think> tags and return clean text + reasoning_content
            - "tagged": Keep <think> tags in text, return empty reasoning_content

    Returns:
        Tuple of (clean_text, reasoning_content)
    """
    if reasoning_format == "tagged":
        return text, ""

    # Find all <think>...</think> blocks (case-insensitive)
    matches = cls._THINK_PATTERN.findall(text)

    # Extract reasoning content from all <think> blocks
    reasoning_content = "\n".join(match.strip() for match in matches) if matches else ""

    # Remove all <think>...</think> blocks from the original text
    clean_text = cls._THINK_PATTERN.sub("", text)

    # Clean up extra whitespace
    clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()

    # Separated mode: always return clean text and reasoning_content
    return clean_text, reasoning_content or ""
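# Usage sketch (my addition, not in the diff): multiple <think> blocks are joined
# with newlines, and the blank lines they leave behind are collapsed.
raw = "<think>plan the answer</think>Hello.\n\n<think>double-check</think>Bye."
clean, reasoning = LLMNode._split_reasoning(raw, "separated")
assert clean == "Hello.\n\nBye."
assert reasoning == "plan the answer\ndouble-check"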
def _transform_chat_messages(
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
@@ -964,6 +1043,7 @@ class LLMNode(BaseNode):
invoke_result: LLMResult,
saver: LLMFileSaver,
file_outputs: list["File"],
reasoning_format: Literal["separated", "tagged"] = "tagged",
) -> ModelInvokeCompletedEvent:
buffer = io.StringIO()
for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
@@ -973,10 +1053,24 @@ class LLMNode(BaseNode):
):
    buffer.write(text_part)

# Extract reasoning content from <think> tags in the main text
full_text = buffer.getvalue()
if reasoning_format == "tagged":
    # Keep <think> tags in text for backward compatibility
    clean_text = full_text
    reasoning_content = ""
else:
    # Extract clean text and reasoning from <think> tags
    clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
return ModelInvokeCompletedEvent(
    text=buffer.getvalue(),
    # Use clean_text for separated mode, full_text for tagged mode
    text=clean_text if reasoning_format == "separated" else full_text,
    usage=invoke_result.usage,
    finish_reason=None,
    # Reasoning content for workflow variables and downstream nodes
    reasoning_content=reasoning_content,
)
@staticmethod

View File

@@ -69,6 +69,7 @@ def llm_node_data() -> LLMNodeData:
detail=ImagePromptMessageContent.DETAIL.HIGH,
),
),
reasoning_format="tagged",
)
@@ -689,3 +690,66 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
assert list(gen) == []
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
class TestReasoningFormat:
    """Test cases for reasoning_format functionality"""

    def test_split_reasoning_separated_mode(self):
        """Test separated mode: tags are removed and content is extracted"""
        text_with_think = """
<think>I need to explain what Dify is. It's an open source AI platform.
</think>Dify is an open source AI platform.
"""
        clean_text, reasoning_content = LLMNode._split_reasoning(text_with_think, "separated")

        assert clean_text == "Dify is an open source AI platform."
        assert reasoning_content == "I need to explain what Dify is. It's an open source AI platform."

    def test_split_reasoning_tagged_mode(self):
        """Test tagged mode: original text is preserved"""
        text_with_think = """
<think>I need to explain what Dify is. It's an open source AI platform.
</think>Dify is an open source AI platform.
"""
        clean_text, reasoning_content = LLMNode._split_reasoning(text_with_think, "tagged")

        # Original text unchanged
        assert clean_text == text_with_think
        # Empty reasoning content in tagged mode
        assert reasoning_content == ""

    def test_split_reasoning_no_think_blocks(self):
        """Test behavior when no <think> tags are present"""
        text_without_think = "This is a simple answer without any thinking blocks."
        clean_text, reasoning_content = LLMNode._split_reasoning(text_without_think, "separated")

        assert clean_text == text_without_think
        assert reasoning_content == ""

    def test_reasoning_format_default_value(self):
        """Test that reasoning_format defaults to 'tagged' for backward compatibility"""
        node_data = LLMNodeData(
            title="Test LLM",
            model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
            prompt_template=[],
            context=ContextConfig(enabled=False),
        )
        assert node_data.reasoning_format == "tagged"

        text_with_think = """
<think>I need to explain what Dify is. It's an open source AI platform.
</think>Dify is an open source AI platform.
"""
        clean_text, reasoning_content = LLMNode._split_reasoning(text_with_think, node_data.reasoning_format)
        assert clean_text == text_with_think
        assert reasoning_content == ""
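    # A further test one might add (sketch only, not part of this commit): the
    # compiled pattern tolerates attributes and mixed case on the <think> tag.
    def test_split_reasoning_tag_with_attributes(self):
        text = '<THINK trace="v1">inspect the request</THINK>Looks good.'
        clean_text, reasoning_content = LLMNode._split_reasoning(text, "separated")

        assert clean_text == "Looks good."
        assert reasoning_content == "inspect the request"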