chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -23,4 +23,4 @@ class BaseTraceInstance(ABC):
|
||||
Abstract method to trace activities.
|
||||
Subclasses must implement specific tracing logic for activities.
|
||||
"""
|
||||
...
|
||||
...
|
||||
|
@@ -4,14 +4,15 @@ from pydantic import BaseModel, ValidationInfo, field_validator
|
||||
|
||||
|
||||
class TracingProviderEnum(Enum):
|
||||
LANGFUSE = 'langfuse'
|
||||
LANGSMITH = 'langsmith'
|
||||
LANGFUSE = "langfuse"
|
||||
LANGSMITH = "langsmith"
|
||||
|
||||
|
||||
class BaseTracingConfig(BaseModel):
|
||||
"""
|
||||
Base model class for tracing
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
|
||||
@@ -19,16 +20,17 @@ class LangfuseConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Langfuse tracing config.
|
||||
"""
|
||||
|
||||
public_key: str
|
||||
secret_key: str
|
||||
host: str = 'https://api.langfuse.com'
|
||||
host: str = "https://api.langfuse.com"
|
||||
|
||||
@field_validator("host")
|
||||
def set_value(cls, v, info: ValidationInfo):
|
||||
if v is None or v == "":
|
||||
v = 'https://api.langfuse.com'
|
||||
if not v.startswith('https://') and not v.startswith('http://'):
|
||||
raise ValueError('host must start with https:// or http://')
|
||||
v = "https://api.langfuse.com"
|
||||
if not v.startswith("https://") and not v.startswith("http://"):
|
||||
raise ValueError("host must start with https:// or http://")
|
||||
|
||||
return v
|
||||
|
||||
@@ -37,15 +39,16 @@ class LangSmithConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Langsmith tracing config.
|
||||
"""
|
||||
|
||||
api_key: str
|
||||
project: str
|
||||
endpoint: str = 'https://api.smith.langchain.com'
|
||||
endpoint: str = "https://api.smith.langchain.com"
|
||||
|
||||
@field_validator("endpoint")
|
||||
def set_value(cls, v, info: ValidationInfo):
|
||||
if v is None or v == "":
|
||||
v = 'https://api.smith.langchain.com'
|
||||
if not v.startswith('https://'):
|
||||
raise ValueError('endpoint must start with https://')
|
||||
v = "https://api.smith.langchain.com"
|
||||
if not v.startswith("https://"):
|
||||
raise ValueError("endpoint must start with https://")
|
||||
|
||||
return v
|
||||
|
@@ -23,6 +23,7 @@ class BaseTraceInfo(BaseModel):
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
||||
class WorkflowTraceInfo(BaseTraceInfo):
|
||||
workflow_data: Any
|
||||
conversation_id: Optional[str] = None
|
||||
@@ -98,23 +99,24 @@ class GenerateNameTraceInfo(BaseTraceInfo):
|
||||
conversation_id: Optional[str] = None
|
||||
tenant_id: str
|
||||
|
||||
|
||||
trace_info_info_map = {
|
||||
'WorkflowTraceInfo': WorkflowTraceInfo,
|
||||
'MessageTraceInfo': MessageTraceInfo,
|
||||
'ModerationTraceInfo': ModerationTraceInfo,
|
||||
'SuggestedQuestionTraceInfo': SuggestedQuestionTraceInfo,
|
||||
'DatasetRetrievalTraceInfo': DatasetRetrievalTraceInfo,
|
||||
'ToolTraceInfo': ToolTraceInfo,
|
||||
'GenerateNameTraceInfo': GenerateNameTraceInfo,
|
||||
"WorkflowTraceInfo": WorkflowTraceInfo,
|
||||
"MessageTraceInfo": MessageTraceInfo,
|
||||
"ModerationTraceInfo": ModerationTraceInfo,
|
||||
"SuggestedQuestionTraceInfo": SuggestedQuestionTraceInfo,
|
||||
"DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo,
|
||||
"ToolTraceInfo": ToolTraceInfo,
|
||||
"GenerateNameTraceInfo": GenerateNameTraceInfo,
|
||||
}
|
||||
|
||||
|
||||
class TraceTaskName(str, Enum):
|
||||
CONVERSATION_TRACE = 'conversation'
|
||||
WORKFLOW_TRACE = 'workflow'
|
||||
MESSAGE_TRACE = 'message'
|
||||
MODERATION_TRACE = 'moderation'
|
||||
SUGGESTED_QUESTION_TRACE = 'suggested_question'
|
||||
DATASET_RETRIEVAL_TRACE = 'dataset_retrieval'
|
||||
TOOL_TRACE = 'tool'
|
||||
GENERATE_NAME_TRACE = 'generate_conversation_name'
|
||||
CONVERSATION_TRACE = "conversation"
|
||||
WORKFLOW_TRACE = "workflow"
|
||||
MESSAGE_TRACE = "message"
|
||||
MODERATION_TRACE = "moderation"
|
||||
SUGGESTED_QUESTION_TRACE = "suggested_question"
|
||||
DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
|
||||
TOOL_TRACE = "tool"
|
||||
GENERATE_NAME_TRACE = "generate_conversation_name"
|
||||
|
@@ -35,38 +35,20 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
|
||||
run_type: LangSmithRunType = Field(..., description="Type of the run")
|
||||
start_time: Optional[datetime | str] = Field(None, description="Start time of the run")
|
||||
end_time: Optional[datetime | str] = Field(None, description="End time of the run")
|
||||
extra: Optional[dict[str, Any]] = Field(
|
||||
None, description="Extra information of the run"
|
||||
)
|
||||
extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run")
|
||||
error: Optional[str] = Field(None, description="Error message of the run")
|
||||
serialized: Optional[dict[str, Any]] = Field(
|
||||
None, description="Serialized data of the run"
|
||||
)
|
||||
serialized: Optional[dict[str, Any]] = Field(None, description="Serialized data of the run")
|
||||
parent_run_id: Optional[str] = Field(None, description="Parent run ID")
|
||||
events: Optional[list[dict[str, Any]]] = Field(
|
||||
None, description="Events associated with the run"
|
||||
)
|
||||
events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run")
|
||||
tags: Optional[list[str]] = Field(None, description="Tags associated with the run")
|
||||
trace_id: Optional[str] = Field(
|
||||
None, description="Trace ID associated with the run"
|
||||
)
|
||||
trace_id: Optional[str] = Field(None, description="Trace ID associated with the run")
|
||||
dotted_order: Optional[str] = Field(None, description="Dotted order of the run")
|
||||
id: Optional[str] = Field(None, description="ID of the run")
|
||||
session_id: Optional[str] = Field(
|
||||
None, description="Session ID associated with the run"
|
||||
)
|
||||
session_name: Optional[str] = Field(
|
||||
None, description="Session name associated with the run"
|
||||
)
|
||||
reference_example_id: Optional[str] = Field(
|
||||
None, description="Reference example ID associated with the run"
|
||||
)
|
||||
input_attachments: Optional[dict[str, Any]] = Field(
|
||||
None, description="Input attachments of the run"
|
||||
)
|
||||
output_attachments: Optional[dict[str, Any]] = Field(
|
||||
None, description="Output attachments of the run"
|
||||
)
|
||||
session_id: Optional[str] = Field(None, description="Session ID associated with the run")
|
||||
session_name: Optional[str] = Field(None, description="Session name associated with the run")
|
||||
reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run")
|
||||
input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run")
|
||||
output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
|
||||
|
||||
@field_validator("inputs", "outputs")
|
||||
def ensure_dict(cls, v, info: ValidationInfo):
|
||||
@@ -75,9 +57,9 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
|
||||
if v == {} or v is None:
|
||||
return v
|
||||
usage_metadata = {
|
||||
"input_tokens": values.get('input_tokens', 0),
|
||||
"output_tokens": values.get('output_tokens', 0),
|
||||
"total_tokens": values.get('total_tokens', 0),
|
||||
"input_tokens": values.get("input_tokens", 0),
|
||||
"output_tokens": values.get("output_tokens", 0),
|
||||
"total_tokens": values.get("total_tokens", 0),
|
||||
}
|
||||
file_list = values.get("file_list", [])
|
||||
if isinstance(v, str):
|
||||
@@ -143,25 +125,15 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
|
||||
|
||||
class LangSmithRunUpdateModel(BaseModel):
|
||||
run_id: str = Field(..., description="ID of the run")
|
||||
trace_id: Optional[str] = Field(
|
||||
None, description="Trace ID associated with the run"
|
||||
)
|
||||
trace_id: Optional[str] = Field(None, description="Trace ID associated with the run")
|
||||
dotted_order: Optional[str] = Field(None, description="Dotted order of the run")
|
||||
parent_run_id: Optional[str] = Field(None, description="Parent run ID")
|
||||
end_time: Optional[datetime | str] = Field(None, description="End time of the run")
|
||||
error: Optional[str] = Field(None, description="Error message of the run")
|
||||
inputs: Optional[dict[str, Any]] = Field(None, description="Inputs of the run")
|
||||
outputs: Optional[dict[str, Any]] = Field(None, description="Outputs of the run")
|
||||
events: Optional[list[dict[str, Any]]] = Field(
|
||||
None, description="Events associated with the run"
|
||||
)
|
||||
events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run")
|
||||
tags: Optional[list[str]] = Field(None, description="Tags associated with the run")
|
||||
extra: Optional[dict[str, Any]] = Field(
|
||||
None, description="Extra information of the run"
|
||||
)
|
||||
input_attachments: Optional[dict[str, Any]] = Field(
|
||||
None, description="Input attachments of the run"
|
||||
)
|
||||
output_attachments: Optional[dict[str, Any]] = Field(
|
||||
None, description="Output attachments of the run"
|
||||
)
|
||||
extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run")
|
||||
input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run")
|
||||
output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
|
||||
|
@@ -159,8 +159,8 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
run_type = LangSmithRunType.llm
|
||||
metadata.update(
|
||||
{
|
||||
'ls_provider': process_data.get('model_provider', ''),
|
||||
'ls_model_name': process_data.get('model_name', ''),
|
||||
"ls_provider": process_data.get("model_provider", ""),
|
||||
"ls_model_name": process_data.get("model_name", ""),
|
||||
}
|
||||
)
|
||||
elif node_type == "knowledge-retrieval":
|
||||
@@ -385,12 +385,10 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
start_time=datetime.now(),
|
||||
)
|
||||
|
||||
project_url = self.langsmith_client.get_run_url(run=run_data,
|
||||
project_id=self.project_id,
|
||||
project_name=self.project_name)
|
||||
return project_url.split('/r/')[0]
|
||||
project_url = self.langsmith_client.get_run_url(
|
||||
run=run_data, project_id=self.project_id, project_name=self.project_name
|
||||
)
|
||||
return project_url.split("/r/")[0]
|
||||
except Exception as e:
|
||||
logger.debug(f"LangSmith get run url failed: {str(e)}")
|
||||
raise ValueError(f"LangSmith get run url failed: {str(e)}")
|
||||
|
||||
|
||||
|
@@ -36,17 +36,17 @@ from tasks.ops_trace_task import process_trace_tasks
|
||||
|
||||
provider_config_map = {
|
||||
TracingProviderEnum.LANGFUSE.value: {
|
||||
'config_class': LangfuseConfig,
|
||||
'secret_keys': ['public_key', 'secret_key'],
|
||||
'other_keys': ['host', 'project_key'],
|
||||
'trace_instance': LangFuseDataTrace
|
||||
"config_class": LangfuseConfig,
|
||||
"secret_keys": ["public_key", "secret_key"],
|
||||
"other_keys": ["host", "project_key"],
|
||||
"trace_instance": LangFuseDataTrace,
|
||||
},
|
||||
TracingProviderEnum.LANGSMITH.value: {
|
||||
'config_class': LangSmithConfig,
|
||||
'secret_keys': ['api_key'],
|
||||
'other_keys': ['project', 'endpoint'],
|
||||
'trace_instance': LangSmithDataTrace
|
||||
}
|
||||
"config_class": LangSmithConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "endpoint"],
|
||||
"trace_instance": LangSmithDataTrace,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -64,14 +64,17 @@ class OpsTraceManager:
|
||||
:return: encrypted tracing configuration
|
||||
"""
|
||||
# Get the configuration class and the keys that require encryption
|
||||
config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \
|
||||
provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys']
|
||||
config_class, secret_keys, other_keys = (
|
||||
provider_config_map[tracing_provider]["config_class"],
|
||||
provider_config_map[tracing_provider]["secret_keys"],
|
||||
provider_config_map[tracing_provider]["other_keys"],
|
||||
)
|
||||
|
||||
new_config = {}
|
||||
# Encrypt necessary keys
|
||||
for key in secret_keys:
|
||||
if key in tracing_config:
|
||||
if '*' in tracing_config[key]:
|
||||
if "*" in tracing_config[key]:
|
||||
# If the key contains '*', retain the original value from the current config
|
||||
new_config[key] = current_trace_config.get(key, tracing_config[key])
|
||||
else:
|
||||
@@ -94,8 +97,11 @@ class OpsTraceManager:
|
||||
:param tracing_config: tracing config
|
||||
:return:
|
||||
"""
|
||||
config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \
|
||||
provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys']
|
||||
config_class, secret_keys, other_keys = (
|
||||
provider_config_map[tracing_provider]["config_class"],
|
||||
provider_config_map[tracing_provider]["secret_keys"],
|
||||
provider_config_map[tracing_provider]["other_keys"],
|
||||
)
|
||||
new_config = {}
|
||||
for key in secret_keys:
|
||||
if key in tracing_config:
|
||||
@@ -114,8 +120,11 @@ class OpsTraceManager:
|
||||
:param decrypt_tracing_config: tracing config
|
||||
:return:
|
||||
"""
|
||||
config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \
|
||||
provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys']
|
||||
config_class, secret_keys, other_keys = (
|
||||
provider_config_map[tracing_provider]["config_class"],
|
||||
provider_config_map[tracing_provider]["secret_keys"],
|
||||
provider_config_map[tracing_provider]["other_keys"],
|
||||
)
|
||||
new_config = {}
|
||||
for key in secret_keys:
|
||||
if key in decrypt_tracing_config:
|
||||
@@ -133,9 +142,11 @@ class OpsTraceManager:
|
||||
:param tracing_provider: tracing provider
|
||||
:return:
|
||||
"""
|
||||
trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter(
|
||||
TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider
|
||||
).first()
|
||||
trace_config_data: TraceAppConfig = (
|
||||
db.session.query(TraceAppConfig)
|
||||
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not trace_config_data:
|
||||
return None
|
||||
@@ -164,21 +175,21 @@ class OpsTraceManager:
|
||||
if app_id is None:
|
||||
return None
|
||||
|
||||
app: App = db.session.query(App).filter(
|
||||
App.id == app_id
|
||||
).first()
|
||||
app: App = db.session.query(App).filter(App.id == app_id).first()
|
||||
app_ops_trace_config = json.loads(app.tracing) if app.tracing else None
|
||||
|
||||
if app_ops_trace_config is not None:
|
||||
tracing_provider = app_ops_trace_config.get('tracing_provider')
|
||||
tracing_provider = app_ops_trace_config.get("tracing_provider")
|
||||
else:
|
||||
return None
|
||||
|
||||
# decrypt_token
|
||||
decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider)
|
||||
if app_ops_trace_config.get('enabled'):
|
||||
trace_instance, config_class = provider_config_map[tracing_provider]['trace_instance'], \
|
||||
provider_config_map[tracing_provider]['config_class']
|
||||
if app_ops_trace_config.get("enabled"):
|
||||
trace_instance, config_class = (
|
||||
provider_config_map[tracing_provider]["trace_instance"],
|
||||
provider_config_map[tracing_provider]["config_class"],
|
||||
)
|
||||
tracing_instance = trace_instance(config_class(**decrypt_trace_config))
|
||||
return tracing_instance
|
||||
|
||||
@@ -192,9 +203,11 @@ class OpsTraceManager:
|
||||
conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
|
||||
|
||||
if conversation_data.app_model_config_id:
|
||||
app_model_config = db.session.query(AppModelConfig).filter(
|
||||
AppModelConfig.id == conversation_data.app_model_config_id
|
||||
).first()
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig)
|
||||
.filter(AppModelConfig.id == conversation_data.app_model_config_id)
|
||||
.first()
|
||||
)
|
||||
elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
|
||||
app_model_config = conversation_data.override_model_configs
|
||||
|
||||
@@ -231,10 +244,7 @@ class OpsTraceManager:
|
||||
"""
|
||||
app: App = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app.tracing:
|
||||
return {
|
||||
"enabled": False,
|
||||
"tracing_provider": None
|
||||
}
|
||||
return {"enabled": False, "tracing_provider": None}
|
||||
app_trace_config = json.loads(app.tracing)
|
||||
return app_trace_config
|
||||
|
||||
@@ -246,8 +256,10 @@ class OpsTraceManager:
|
||||
:param tracing_provider: tracing provider
|
||||
:return:
|
||||
"""
|
||||
config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \
|
||||
provider_config_map[tracing_provider]['trace_instance']
|
||||
config_type, trace_instance = (
|
||||
provider_config_map[tracing_provider]["config_class"],
|
||||
provider_config_map[tracing_provider]["trace_instance"],
|
||||
)
|
||||
tracing_config = config_type(**tracing_config)
|
||||
return trace_instance(tracing_config).api_check()
|
||||
|
||||
@@ -259,8 +271,10 @@ class OpsTraceManager:
|
||||
:param tracing_provider: tracing provider
|
||||
:return:
|
||||
"""
|
||||
config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \
|
||||
provider_config_map[tracing_provider]['trace_instance']
|
||||
config_type, trace_instance = (
|
||||
provider_config_map[tracing_provider]["config_class"],
|
||||
provider_config_map[tracing_provider]["trace_instance"],
|
||||
)
|
||||
tracing_config = config_type(**tracing_config)
|
||||
return trace_instance(tracing_config).get_project_key()
|
||||
|
||||
@@ -272,8 +286,10 @@ class OpsTraceManager:
|
||||
:param tracing_provider: tracing provider
|
||||
:return:
|
||||
"""
|
||||
config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \
|
||||
provider_config_map[tracing_provider]['trace_instance']
|
||||
config_type, trace_instance = (
|
||||
provider_config_map[tracing_provider]["config_class"],
|
||||
provider_config_map[tracing_provider]["trace_instance"],
|
||||
)
|
||||
tracing_config = config_type(**tracing_config)
|
||||
return trace_instance(tracing_config).get_project_url()
|
||||
|
||||
@@ -287,7 +303,7 @@ class TraceTask:
|
||||
conversation_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
timer: Optional[Any] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
self.trace_type = trace_type
|
||||
self.message_id = message_id
|
||||
@@ -310,9 +326,7 @@ class TraceTask:
|
||||
self.workflow_run, self.conversation_id, self.user_id
|
||||
),
|
||||
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id),
|
||||
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
|
||||
self.message_id, self.timer, **self.kwargs
|
||||
),
|
||||
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(self.message_id, self.timer, **self.kwargs),
|
||||
TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace(
|
||||
self.message_id, self.timer, **self.kwargs
|
||||
),
|
||||
@@ -337,12 +351,8 @@ class TraceTask:
|
||||
workflow_run_id = workflow_run.id
|
||||
workflow_run_elapsed_time = workflow_run.elapsed_time
|
||||
workflow_run_status = workflow_run.status
|
||||
workflow_run_inputs = (
|
||||
json.loads(workflow_run.inputs) if workflow_run.inputs else {}
|
||||
)
|
||||
workflow_run_outputs = (
|
||||
json.loads(workflow_run.outputs) if workflow_run.outputs else {}
|
||||
)
|
||||
workflow_run_inputs = json.loads(workflow_run.inputs) if workflow_run.inputs else {}
|
||||
workflow_run_outputs = json.loads(workflow_run.outputs) if workflow_run.outputs else {}
|
||||
workflow_run_version = workflow_run.version
|
||||
error = workflow_run.error if workflow_run.error else ""
|
||||
|
||||
@@ -352,17 +362,18 @@ class TraceTask:
|
||||
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
|
||||
|
||||
# get workflow_app_log_id
|
||||
workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
app_id=workflow_run.app_id,
|
||||
workflow_run_id=workflow_run.id
|
||||
).first()
|
||||
workflow_app_log_data = (
|
||||
db.session.query(WorkflowAppLog)
|
||||
.filter_by(tenant_id=tenant_id, app_id=workflow_run.app_id, workflow_run_id=workflow_run.id)
|
||||
.first()
|
||||
)
|
||||
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
|
||||
# get message_id
|
||||
message_data = db.session.query(Message.id).filter_by(
|
||||
conversation_id=conversation_id,
|
||||
workflow_run_id=workflow_run_id
|
||||
).first()
|
||||
message_data = (
|
||||
db.session.query(Message.id)
|
||||
.filter_by(conversation_id=conversation_id, workflow_run_id=workflow_run_id)
|
||||
.first()
|
||||
)
|
||||
message_id = str(message_data.id) if message_data else None
|
||||
|
||||
metadata = {
|
||||
@@ -470,9 +481,9 @@ class TraceTask:
|
||||
# get workflow_app_log_id
|
||||
workflow_app_log_id = None
|
||||
if message_data.workflow_run_id:
|
||||
workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by(
|
||||
workflow_run_id=message_data.workflow_run_id
|
||||
).first()
|
||||
workflow_app_log_data = (
|
||||
db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
|
||||
)
|
||||
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
|
||||
|
||||
moderation_trace_info = ModerationTraceInfo(
|
||||
@@ -510,9 +521,9 @@ class TraceTask:
|
||||
# get workflow_app_log_id
|
||||
workflow_app_log_id = None
|
||||
if message_data.workflow_run_id:
|
||||
workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by(
|
||||
workflow_run_id=message_data.workflow_run_id
|
||||
).first()
|
||||
workflow_app_log_data = (
|
||||
db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
|
||||
)
|
||||
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
|
||||
|
||||
suggested_question_trace_info = SuggestedQuestionTraceInfo(
|
||||
@@ -569,9 +580,9 @@ class TraceTask:
|
||||
return dataset_retrieval_trace_info
|
||||
|
||||
def tool_trace(self, message_id, timer, **kwargs):
|
||||
tool_name = kwargs.get('tool_name')
|
||||
tool_inputs = kwargs.get('tool_inputs')
|
||||
tool_outputs = kwargs.get('tool_outputs')
|
||||
tool_name = kwargs.get("tool_name")
|
||||
tool_inputs = kwargs.get("tool_inputs")
|
||||
tool_outputs = kwargs.get("tool_outputs")
|
||||
message_data = get_message_data(message_id)
|
||||
if not message_data:
|
||||
return {}
|
||||
@@ -586,11 +597,11 @@ class TraceTask:
|
||||
if tool_name in agent_thought.tools:
|
||||
created_time = agent_thought.created_at
|
||||
tool_meta_data = agent_thought.tool_meta.get(tool_name, {})
|
||||
tool_config = tool_meta_data.get('tool_config', {})
|
||||
time_cost = tool_meta_data.get('time_cost', 0)
|
||||
tool_config = tool_meta_data.get("tool_config", {})
|
||||
time_cost = tool_meta_data.get("time_cost", 0)
|
||||
end_time = created_time + timedelta(seconds=time_cost)
|
||||
error = tool_meta_data.get('error', "")
|
||||
tool_parameters = tool_meta_data.get('tool_parameters', {})
|
||||
error = tool_meta_data.get("error", "")
|
||||
tool_parameters = tool_meta_data.get("tool_parameters", {})
|
||||
metadata = {
|
||||
"message_id": message_id,
|
||||
"tool_name": tool_name,
|
||||
@@ -715,9 +726,7 @@ class TraceQueueManager:
|
||||
def start_timer(self):
|
||||
global trace_manager_timer
|
||||
if trace_manager_timer is None or not trace_manager_timer.is_alive():
|
||||
trace_manager_timer = threading.Timer(
|
||||
trace_manager_interval, self.run
|
||||
)
|
||||
trace_manager_timer = threading.Timer(trace_manager_interval, self.run)
|
||||
trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}"
|
||||
trace_manager_timer.daemon = False
|
||||
trace_manager_timer.start()
|
||||
|
@@ -20,19 +20,19 @@ def get_message_data(message_id):
|
||||
|
||||
@contextmanager
|
||||
def measure_time():
|
||||
timing_info = {'start': datetime.now(), 'end': None}
|
||||
timing_info = {"start": datetime.now(), "end": None}
|
||||
try:
|
||||
yield timing_info
|
||||
finally:
|
||||
timing_info['end'] = datetime.now()
|
||||
timing_info["end"] = datetime.now()
|
||||
|
||||
|
||||
def replace_text_with_content(data):
|
||||
if isinstance(data, dict):
|
||||
new_data = {}
|
||||
for key, value in data.items():
|
||||
if key == 'text':
|
||||
new_data['content'] = value
|
||||
if key == "text":
|
||||
new_data["content"] = value
|
||||
else:
|
||||
new_data[key] = replace_text_with_content(value)
|
||||
return new_data
|
||||
|
Reference in New Issue
Block a user