chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -16,31 +16,32 @@ _TEXT_COLOR_MAPPING = {
"red": "31;1",
}
def get_colored_text(text: str, color: str) -> str:
"""Get colored text."""
color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
def print_text(
text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None:
def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None:
"""Print text with highlighting and no end characters."""
text_to_print = get_colored_text(text, color) if color else text
print(text_to_print, end=end, file=file)
if file:
file.flush() # ensure all printed content are written to file
class DifyAgentCallbackHandler(BaseModel):
"""Callback Handler that prints to std out."""
color: Optional[str] = ''
color: Optional[str] = ""
current_loop: int = 1
def __init__(self, color: Optional[str] = None) -> None:
super().__init__()
"""Initialize callback handler."""
# use a specific color is not specified
self.color = color or 'green'
self.color = color or "green"
self.current_loop = 1
def on_tool_start(
@@ -58,7 +59,7 @@ class DifyAgentCallbackHandler(BaseModel):
tool_outputs: Sequence[ToolInvokeMessage],
message_id: Optional[str] = None,
timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None
trace_manager: Optional[TraceQueueManager] = None,
) -> None:
"""If not the final action, print out observation."""
print_text("\n[on_tool_end]\n", color=self.color)
@@ -79,26 +80,21 @@ class DifyAgentCallbackHandler(BaseModel):
)
)
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
"""Do nothing."""
print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='red')
print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red")
def on_agent_start(
self, thought: str
) -> None:
def on_agent_start(self, thought: str) -> None:
"""Run on agent start."""
if thought:
print_text("\n[on_agent_start] \nCurrent Loop: " + \
str(self.current_loop) + \
"\nThought: " + thought + "\n", color=self.color)
print_text(
"\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\nThought: " + thought + "\n",
color=self.color,
)
else:
print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color)
def on_agent_finish(
self, color: Optional[str] = None, **kwargs: Any
) -> None:
def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any) -> None:
"""Run on agent end."""
print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color)
@@ -107,9 +103,9 @@ class DifyAgentCallbackHandler(BaseModel):
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true"
@property
def ignore_chat_model(self) -> bool:
"""Whether to ignore chat model callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true"

View File

@@ -1,4 +1,3 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
@@ -11,11 +10,9 @@ from models.model import DatasetRetrieverResource
class DatasetIndexToolCallbackHandler:
"""Callback handler for dataset tool."""
def __init__(self, queue_manager: AppQueueManager,
app_id: str,
message_id: str,
user_id: str,
invoke_from: InvokeFrom) -> None:
def __init__(
self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom
) -> None:
self._queue_manager = queue_manager
self._app_id = app_id
self._message_id = message_id
@@ -29,11 +26,12 @@ class DatasetIndexToolCallbackHandler:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=query,
source='app',
source="app",
source_app_id=self._app_id,
created_by_role=('account'
if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
created_by=self._user_id
created_by_role=(
"account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
),
created_by=self._user_id,
)
db.session.add(dataset_query)
@@ -43,18 +41,15 @@ class DatasetIndexToolCallbackHandler:
"""Handle tool end."""
for document in documents:
query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata['doc_id']
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
# if 'dataset_id' in document.metadata:
if 'dataset_id' in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment
query.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False
)
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
db.session.commit()
@@ -64,26 +59,25 @@ class DatasetIndexToolCallbackHandler:
for item in resource:
dataset_retriever_resource = DatasetRetrieverResource(
message_id=self._message_id,
position=item.get('position'),
dataset_id=item.get('dataset_id'),
dataset_name=item.get('dataset_name'),
document_id=item.get('document_id'),
document_name=item.get('document_name'),
data_source_type=item.get('data_source_type'),
segment_id=item.get('segment_id'),
score=item.get('score') if 'score' in item else None,
hit_count=item.get('hit_count') if 'hit_count' else None,
word_count=item.get('word_count') if 'word_count' in item else None,
segment_position=item.get('segment_position') if 'segment_position' in item else None,
index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
content=item.get('content'),
retriever_from=item.get('retriever_from'),
created_by=self._user_id
position=item.get("position"),
dataset_id=item.get("dataset_id"),
dataset_name=item.get("dataset_name"),
document_id=item.get("document_id"),
document_name=item.get("document_name"),
data_source_type=item.get("data_source_type"),
segment_id=item.get("segment_id"),
score=item.get("score") if "score" in item else None,
hit_count=item.get("hit_count") if "hit_count" else None,
word_count=item.get("word_count") if "word_count" in item else None,
segment_position=item.get("segment_position") if "segment_position" in item else None,
index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None,
content=item.get("content"),
retriever_from=item.get("retriever_from"),
created_by=self._user_id,
)
db.session.add(dataset_retriever_resource)
db.session.commit()
self._queue_manager.publish(
QueueRetrieverResourcesEvent(retriever_resources=resource),
PublishFrom.APPLICATION_MANAGER
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
)

View File

@@ -2,4 +2,4 @@ from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackH
class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler):
"""Callback Handler that prints to std out."""
"""Callback Handler that prints to std out."""