feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -1,3 +1,4 @@
from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Any, Optional, Union
@@ -38,8 +39,8 @@ class WorkflowTraceInfo(BaseTraceInfo):
workflow_run_id: str
workflow_run_elapsed_time: Union[int, float]
workflow_run_status: str
workflow_run_inputs: dict[str, Any]
workflow_run_outputs: dict[str, Any]
workflow_run_inputs: Mapping[str, Any]
workflow_run_outputs: Mapping[str, Any]
workflow_run_version: str
error: Optional[str] = None
total_tokens: int

View File

@@ -77,8 +77,8 @@ class LangFuseDataTrace(BaseTraceInstance):
id=trace_id,
user_id=user_id,
name=name,
input=trace_info.workflow_run_inputs,
output=trace_info.workflow_run_outputs,
input=dict(trace_info.workflow_run_inputs),
output=dict(trace_info.workflow_run_outputs),
metadata=metadata,
session_id=trace_info.conversation_id,
tags=["message", "workflow"],
@@ -87,8 +87,8 @@ class LangFuseDataTrace(BaseTraceInstance):
workflow_span_data = LangfuseSpan(
id=trace_info.workflow_run_id,
name=TraceTaskName.WORKFLOW_TRACE.value,
input=trace_info.workflow_run_inputs,
output=trace_info.workflow_run_outputs,
input=dict(trace_info.workflow_run_inputs),
output=dict(trace_info.workflow_run_outputs),
trace_id=trace_id,
start_time=trace_info.start_time,
end_time=trace_info.end_time,
@@ -102,8 +102,8 @@ class LangFuseDataTrace(BaseTraceInstance):
id=trace_id,
user_id=user_id,
name=TraceTaskName.WORKFLOW_TRACE.value,
input=trace_info.workflow_run_inputs,
output=trace_info.workflow_run_outputs,
input=dict(trace_info.workflow_run_inputs),
output=dict(trace_info.workflow_run_outputs),
metadata=metadata,
session_id=trace_info.conversation_id,
tags=["workflow"],

View File

@@ -49,7 +49,6 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run")
input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run")
output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
dotted_order: Optional[str] = Field(None, description="Dotted order of the run")
@field_validator("inputs", "outputs")
@classmethod

View File

@@ -3,6 +3,7 @@ import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import Optional, cast
from langsmith import Client
from langsmith.schemas import RunBase
@@ -63,6 +64,8 @@ class LangSmithDataTrace(BaseTraceInstance):
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.message_id or trace_info.workflow_run_id
if trace_info.start_time is None:
trace_info.start_time = datetime.now()
message_dotted_order = (
generate_dotted_order(trace_info.message_id, trace_info.start_time) if trace_info.message_id else None
)
@@ -78,8 +81,8 @@ class LangSmithDataTrace(BaseTraceInstance):
message_run = LangSmithRunModel(
id=trace_info.message_id,
name=TraceTaskName.MESSAGE_TRACE.value,
inputs=trace_info.workflow_run_inputs,
outputs=trace_info.workflow_run_outputs,
inputs=dict(trace_info.workflow_run_inputs),
outputs=dict(trace_info.workflow_run_outputs),
run_type=LangSmithRunType.chain,
start_time=trace_info.start_time,
end_time=trace_info.end_time,
@@ -90,6 +93,15 @@ class LangSmithDataTrace(BaseTraceInstance):
error=trace_info.error,
trace_id=trace_id,
dotted_order=message_dotted_order,
file_list=[],
serialized=None,
parent_run_id=None,
events=[],
session_id=None,
session_name=None,
reference_example_id=None,
input_attachments={},
output_attachments={},
)
self.add_run(message_run)
@@ -98,11 +110,11 @@ class LangSmithDataTrace(BaseTraceInstance):
total_tokens=trace_info.total_tokens,
id=trace_info.workflow_run_id,
name=TraceTaskName.WORKFLOW_TRACE.value,
inputs=trace_info.workflow_run_inputs,
inputs=dict(trace_info.workflow_run_inputs),
run_type=LangSmithRunType.tool,
start_time=trace_info.workflow_data.created_at,
end_time=trace_info.workflow_data.finished_at,
outputs=trace_info.workflow_run_outputs,
outputs=dict(trace_info.workflow_run_outputs),
extra={
"metadata": metadata,
},
@@ -111,6 +123,13 @@ class LangSmithDataTrace(BaseTraceInstance):
parent_run_id=trace_info.message_id or None,
trace_id=trace_id,
dotted_order=workflow_dotted_order,
serialized=None,
events=[],
session_id=None,
session_name=None,
reference_example_id=None,
input_attachments={},
output_attachments={},
)
self.add_run(langsmith_run)
@@ -211,25 +230,35 @@ class LangSmithDataTrace(BaseTraceInstance):
id=node_execution_id,
trace_id=trace_id,
dotted_order=node_dotted_order,
error="",
serialized=None,
events=[],
session_id=None,
session_name=None,
reference_example_id=None,
input_attachments={},
output_attachments={},
)
self.add_run(langsmith_run)
def message_trace(self, trace_info: MessageTraceInfo):
# get message file data
file_list = trace_info.file_list
message_file_data: MessageFile = trace_info.message_file_data
file_list = cast(list[str], trace_info.file_list) or []
message_file_data: Optional[MessageFile] = trace_info.message_file_data
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
file_list.append(file_url)
metadata = trace_info.metadata
message_data = trace_info.message_data
if message_data is None:
return
message_id = message_data.id
user_id = message_data.from_account_id
metadata["user_id"] = user_id
if message_data.from_end_user_id:
end_user_data: EndUser = (
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
@@ -247,12 +276,20 @@ class LangSmithDataTrace(BaseTraceInstance):
start_time=trace_info.start_time,
end_time=trace_info.end_time,
outputs=message_data.answer,
extra={
"metadata": metadata,
},
extra={"metadata": metadata},
tags=["message", str(trace_info.conversation_mode)],
error=trace_info.error,
file_list=file_list,
serialized=None,
events=[],
session_id=None,
session_name=None,
reference_example_id=None,
input_attachments={},
output_attachments={},
trace_id=None,
dotted_order=None,
parent_run_id=None,
)
self.add_run(message_run)
@@ -267,17 +304,27 @@ class LangSmithDataTrace(BaseTraceInstance):
start_time=trace_info.start_time,
end_time=trace_info.end_time,
outputs=message_data.answer,
extra={
"metadata": metadata,
},
extra={"metadata": metadata},
parent_run_id=message_id,
tags=["llm", str(trace_info.conversation_mode)],
error=trace_info.error,
file_list=file_list,
serialized=None,
events=[],
session_id=None,
session_name=None,
reference_example_id=None,
input_attachments={},
output_attachments={},
trace_id=None,
dotted_order=None,
id=str(uuid.uuid4()),
)
self.add_run(llm_run)
def moderation_trace(self, trace_info: ModerationTraceInfo):
if trace_info.message_data is None:
return
langsmith_run = LangSmithRunModel(
name=TraceTaskName.MODERATION_TRACE.value,
inputs=trace_info.inputs,
@@ -288,48 +335,82 @@ class LangSmithDataTrace(BaseTraceInstance):
"inputs": trace_info.inputs,
},
run_type=LangSmithRunType.tool,
extra={
"metadata": trace_info.metadata,
},
extra={"metadata": trace_info.metadata},
tags=["moderation"],
parent_run_id=trace_info.message_id,
start_time=trace_info.start_time or trace_info.message_data.created_at,
end_time=trace_info.end_time or trace_info.message_data.updated_at,
id=str(uuid.uuid4()),
serialized=None,
events=[],
session_id=None,
session_name=None,
reference_example_id=None,
input_attachments={},
output_attachments={},
trace_id=None,
dotted_order=None,
error="",
file_list=[],
)
self.add_run(langsmith_run)
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
message_data = trace_info.message_data
if message_data is None:
return
suggested_question_run = LangSmithRunModel(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
inputs=trace_info.inputs,
outputs=trace_info.suggested_question,
run_type=LangSmithRunType.tool,
extra={
"metadata": trace_info.metadata,
},
extra={"metadata": trace_info.metadata},
tags=["suggested_question"],
parent_run_id=trace_info.message_id,
start_time=trace_info.start_time or message_data.created_at,
end_time=trace_info.end_time or message_data.updated_at,
id=str(uuid.uuid4()),
serialized=None,
events=[],
session_id=None,
session_name=None,
reference_example_id=None,
input_attachments={},
output_attachments={},
trace_id=None,
dotted_order=None,
error="",
file_list=[],
)
self.add_run(suggested_question_run)
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
return
dataset_retrieval_run = LangSmithRunModel(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
inputs=trace_info.inputs,
outputs={"documents": trace_info.documents},
run_type=LangSmithRunType.retriever,
extra={
"metadata": trace_info.metadata,
},
extra={"metadata": trace_info.metadata},
tags=["dataset_retrieval"],
parent_run_id=trace_info.message_id,
start_time=trace_info.start_time or trace_info.message_data.created_at,
end_time=trace_info.end_time or trace_info.message_data.updated_at,
id=str(uuid.uuid4()),
serialized=None,
events=[],
session_id=None,
session_name=None,
reference_example_id=None,
input_attachments={},
output_attachments={},
trace_id=None,
dotted_order=None,
error="",
file_list=[],
)
self.add_run(dataset_retrieval_run)
@@ -347,7 +428,18 @@ class LangSmithDataTrace(BaseTraceInstance):
parent_run_id=trace_info.message_id,
start_time=trace_info.start_time,
end_time=trace_info.end_time,
file_list=[trace_info.file_url],
file_list=[cast(str, trace_info.file_url)],
id=str(uuid.uuid4()),
serialized=None,
events=[],
session_id=None,
session_name=None,
reference_example_id=None,
input_attachments={},
output_attachments={},
trace_id=None,
dotted_order=None,
error=trace_info.error or "",
)
self.add_run(tool_run)
@@ -358,12 +450,23 @@ class LangSmithDataTrace(BaseTraceInstance):
inputs=trace_info.inputs,
outputs=trace_info.outputs,
run_type=LangSmithRunType.tool,
extra={
"metadata": trace_info.metadata,
},
extra={"metadata": trace_info.metadata},
tags=["generate_name"],
start_time=trace_info.start_time or datetime.now(),
end_time=trace_info.end_time or datetime.now(),
id=str(uuid.uuid4()),
serialized=None,
events=[],
session_id=None,
session_name=None,
reference_example_id=None,
input_attachments={},
output_attachments={},
trace_id=None,
dotted_order=None,
error="",
file_list=[],
parent_run_id=None,
)
self.add_run(name_run)

View File

@@ -33,11 +33,11 @@ from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
from core.ops.utils import get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageAgentThought, MessageFile, TraceAppConfig
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
from models.workflow import WorkflowAppLog, WorkflowRun
from tasks.ops_trace_task import process_trace_tasks
provider_config_map = {
provider_config_map: dict[str, dict[str, Any]] = {
TracingProviderEnum.LANGFUSE.value: {
"config_class": LangfuseConfig,
"secret_keys": ["public_key", "secret_key"],
@@ -145,7 +145,7 @@ class OpsTraceManager:
:param tracing_provider: tracing provider
:return:
"""
trace_config_data: TraceAppConfig = (
trace_config_data: Optional[TraceAppConfig] = (
db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
@@ -155,7 +155,11 @@ class OpsTraceManager:
return None
# decrypt_token
tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
app = db.session.query(App).filter(App.id == app_id).first()
if not app:
raise ValueError("App not found")
tenant_id = app.tenant_id
decrypt_tracing_config = cls.decrypt_tracing_config(
tenant_id, tracing_provider, trace_config_data.tracing_config
)
@@ -178,7 +182,7 @@ class OpsTraceManager:
if app_id is None:
return None
app: App = db.session.query(App).filter(App.id == app_id).first()
app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
if app is None:
return None
@@ -209,8 +213,12 @@ class OpsTraceManager:
def get_app_config_through_message_id(cls, message_id: str):
app_model_config = None
message_data = db.session.query(Message).filter(Message.id == message_id).first()
if not message_data:
return None
conversation_id = message_data.conversation_id
conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
if not conversation_data:
return None
if conversation_data.app_model_config_id:
app_model_config = (
@@ -236,7 +244,9 @@ class OpsTraceManager:
if tracing_provider not in provider_config_map and tracing_provider is not None:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
app_config: App = db.session.query(App).filter(App.id == app_id).first()
app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
if not app_config:
raise ValueError("App not found")
app_config.tracing = json.dumps(
{
"enabled": enabled,
@@ -252,7 +262,9 @@ class OpsTraceManager:
:param app_id: app id
:return:
"""
app: App = db.session.query(App).filter(App.id == app_id).first()
app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
if not app:
raise ValueError("App not found")
if not app.tracing:
return {"enabled": False, "tracing_provider": None}
app_trace_config = json.loads(app.tracing)
@@ -483,6 +495,8 @@ class TraceTask:
def moderation_trace(self, message_id, timer, **kwargs):
moderation_result = kwargs.get("moderation_result")
if not moderation_result:
return {}
inputs = kwargs.get("inputs")
message_data = get_message_data(message_id)
if not message_data:
@@ -518,7 +532,7 @@ class TraceTask:
return moderation_trace_info
def suggested_question_trace(self, message_id, timer, **kwargs):
suggested_question = kwargs.get("suggested_question")
suggested_question = kwargs.get("suggested_question", [])
message_data = get_message_data(message_id)
if not message_data:
return {}
@@ -586,7 +600,7 @@ class TraceTask:
dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
message_id=message_id,
inputs=message_data.query or message_data.inputs,
documents=[doc.model_dump() for doc in documents],
documents=[doc.model_dump() for doc in documents] if documents else [],
start_time=timer.get("start"),
end_time=timer.get("end"),
metadata=metadata,
@@ -596,9 +610,9 @@ class TraceTask:
return dataset_retrieval_trace_info
def tool_trace(self, message_id, timer, **kwargs):
tool_name = kwargs.get("tool_name")
tool_inputs = kwargs.get("tool_inputs")
tool_outputs = kwargs.get("tool_outputs")
tool_name = kwargs.get("tool_name", "")
tool_inputs = kwargs.get("tool_inputs", {})
tool_outputs = kwargs.get("tool_outputs", {})
message_data = get_message_data(message_id)
if not message_data:
return {}
@@ -608,7 +622,7 @@ class TraceTask:
tool_parameters = {}
created_time = message_data.created_at
end_time = message_data.updated_at
agent_thoughts: list[MessageAgentThought] = message_data.agent_thoughts
agent_thoughts = message_data.agent_thoughts
for agent_thought in agent_thoughts:
if tool_name in agent_thought.tools:
created_time = agent_thought.created_at
@@ -672,6 +686,8 @@ class TraceTask:
generate_conversation_name = kwargs.get("generate_conversation_name")
inputs = kwargs.get("inputs")
tenant_id = kwargs.get("tenant_id")
if not tenant_id:
return {}
start_time = timer.get("start")
end_time = timer.get("end")
@@ -693,8 +709,8 @@ class TraceTask:
return generate_name_trace_info
trace_manager_timer = None
trace_manager_queue = queue.Queue()
trace_manager_timer: Optional[threading.Timer] = None
trace_manager_queue: queue.Queue = queue.Queue()
trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5))
trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))
@@ -706,7 +722,7 @@ class TraceQueueManager:
self.app_id = app_id
self.user_id = user_id
self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
self.flask_app = current_app._get_current_object()
self.flask_app = current_app._get_current_object() # type: ignore
if trace_manager_timer is None:
self.start_timer()
@@ -723,7 +739,7 @@ class TraceQueueManager:
def collect_tasks(self):
global trace_manager_queue
tasks = []
tasks: list[TraceTask] = []
while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty():
task = trace_manager_queue.get_nowait()
tasks.append(task)
@@ -749,6 +765,8 @@ class TraceQueueManager:
def send_to_celery(self, tasks: list[TraceTask]):
with self.flask_app.app_context():
for task in tasks:
if task.app_id is None:
continue
file_id = uuid4().hex
trace_info = task.execute()
task_data = TaskData(