chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -23,4 +23,4 @@ class BaseTraceInstance(ABC):
Abstract method to trace activities.
Subclasses must implement specific tracing logic for activities.
"""
...
...

View File

@@ -4,14 +4,15 @@ from pydantic import BaseModel, ValidationInfo, field_validator
class TracingProviderEnum(Enum):
LANGFUSE = 'langfuse'
LANGSMITH = 'langsmith'
LANGFUSE = "langfuse"
LANGSMITH = "langsmith"
class BaseTracingConfig(BaseModel):
"""
Base model class for tracing
"""
...
@@ -19,16 +20,17 @@ class LangfuseConfig(BaseTracingConfig):
"""
Model class for Langfuse tracing config.
"""
public_key: str
secret_key: str
host: str = 'https://api.langfuse.com'
host: str = "https://api.langfuse.com"
@field_validator("host")
def set_value(cls, v, info: ValidationInfo):
if v is None or v == "":
v = 'https://api.langfuse.com'
if not v.startswith('https://') and not v.startswith('http://'):
raise ValueError('host must start with https:// or http://')
v = "https://api.langfuse.com"
if not v.startswith("https://") and not v.startswith("http://"):
raise ValueError("host must start with https:// or http://")
return v
@@ -37,15 +39,16 @@ class LangSmithConfig(BaseTracingConfig):
"""
Model class for Langsmith tracing config.
"""
api_key: str
project: str
endpoint: str = 'https://api.smith.langchain.com'
endpoint: str = "https://api.smith.langchain.com"
@field_validator("endpoint")
def set_value(cls, v, info: ValidationInfo):
if v is None or v == "":
v = 'https://api.smith.langchain.com'
if not v.startswith('https://'):
raise ValueError('endpoint must start with https://')
v = "https://api.smith.langchain.com"
if not v.startswith("https://"):
raise ValueError("endpoint must start with https://")
return v

View File

@@ -23,6 +23,7 @@ class BaseTraceInfo(BaseModel):
else:
return ""
class WorkflowTraceInfo(BaseTraceInfo):
workflow_data: Any
conversation_id: Optional[str] = None
@@ -98,23 +99,24 @@ class GenerateNameTraceInfo(BaseTraceInfo):
conversation_id: Optional[str] = None
tenant_id: str
trace_info_info_map = {
'WorkflowTraceInfo': WorkflowTraceInfo,
'MessageTraceInfo': MessageTraceInfo,
'ModerationTraceInfo': ModerationTraceInfo,
'SuggestedQuestionTraceInfo': SuggestedQuestionTraceInfo,
'DatasetRetrievalTraceInfo': DatasetRetrievalTraceInfo,
'ToolTraceInfo': ToolTraceInfo,
'GenerateNameTraceInfo': GenerateNameTraceInfo,
"WorkflowTraceInfo": WorkflowTraceInfo,
"MessageTraceInfo": MessageTraceInfo,
"ModerationTraceInfo": ModerationTraceInfo,
"SuggestedQuestionTraceInfo": SuggestedQuestionTraceInfo,
"DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo,
"ToolTraceInfo": ToolTraceInfo,
"GenerateNameTraceInfo": GenerateNameTraceInfo,
}
class TraceTaskName(str, Enum):
CONVERSATION_TRACE = 'conversation'
WORKFLOW_TRACE = 'workflow'
MESSAGE_TRACE = 'message'
MODERATION_TRACE = 'moderation'
SUGGESTED_QUESTION_TRACE = 'suggested_question'
DATASET_RETRIEVAL_TRACE = 'dataset_retrieval'
TOOL_TRACE = 'tool'
GENERATE_NAME_TRACE = 'generate_conversation_name'
CONVERSATION_TRACE = "conversation"
WORKFLOW_TRACE = "workflow"
MESSAGE_TRACE = "message"
MODERATION_TRACE = "moderation"
SUGGESTED_QUESTION_TRACE = "suggested_question"
DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
TOOL_TRACE = "tool"
GENERATE_NAME_TRACE = "generate_conversation_name"

View File

@@ -35,38 +35,20 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
run_type: LangSmithRunType = Field(..., description="Type of the run")
start_time: Optional[datetime | str] = Field(None, description="Start time of the run")
end_time: Optional[datetime | str] = Field(None, description="End time of the run")
extra: Optional[dict[str, Any]] = Field(
None, description="Extra information of the run"
)
extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run")
error: Optional[str] = Field(None, description="Error message of the run")
serialized: Optional[dict[str, Any]] = Field(
None, description="Serialized data of the run"
)
serialized: Optional[dict[str, Any]] = Field(None, description="Serialized data of the run")
parent_run_id: Optional[str] = Field(None, description="Parent run ID")
events: Optional[list[dict[str, Any]]] = Field(
None, description="Events associated with the run"
)
events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run")
tags: Optional[list[str]] = Field(None, description="Tags associated with the run")
trace_id: Optional[str] = Field(
None, description="Trace ID associated with the run"
)
trace_id: Optional[str] = Field(None, description="Trace ID associated with the run")
dotted_order: Optional[str] = Field(None, description="Dotted order of the run")
id: Optional[str] = Field(None, description="ID of the run")
session_id: Optional[str] = Field(
None, description="Session ID associated with the run"
)
session_name: Optional[str] = Field(
None, description="Session name associated with the run"
)
reference_example_id: Optional[str] = Field(
None, description="Reference example ID associated with the run"
)
input_attachments: Optional[dict[str, Any]] = Field(
None, description="Input attachments of the run"
)
output_attachments: Optional[dict[str, Any]] = Field(
None, description="Output attachments of the run"
)
session_id: Optional[str] = Field(None, description="Session ID associated with the run")
session_name: Optional[str] = Field(None, description="Session name associated with the run")
reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run")
input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run")
output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
@field_validator("inputs", "outputs")
def ensure_dict(cls, v, info: ValidationInfo):
@@ -75,9 +57,9 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
if v == {} or v is None:
return v
usage_metadata = {
"input_tokens": values.get('input_tokens', 0),
"output_tokens": values.get('output_tokens', 0),
"total_tokens": values.get('total_tokens', 0),
"input_tokens": values.get("input_tokens", 0),
"output_tokens": values.get("output_tokens", 0),
"total_tokens": values.get("total_tokens", 0),
}
file_list = values.get("file_list", [])
if isinstance(v, str):
@@ -143,25 +125,15 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
class LangSmithRunUpdateModel(BaseModel):
run_id: str = Field(..., description="ID of the run")
trace_id: Optional[str] = Field(
None, description="Trace ID associated with the run"
)
trace_id: Optional[str] = Field(None, description="Trace ID associated with the run")
dotted_order: Optional[str] = Field(None, description="Dotted order of the run")
parent_run_id: Optional[str] = Field(None, description="Parent run ID")
end_time: Optional[datetime | str] = Field(None, description="End time of the run")
error: Optional[str] = Field(None, description="Error message of the run")
inputs: Optional[dict[str, Any]] = Field(None, description="Inputs of the run")
outputs: Optional[dict[str, Any]] = Field(None, description="Outputs of the run")
events: Optional[list[dict[str, Any]]] = Field(
None, description="Events associated with the run"
)
events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run")
tags: Optional[list[str]] = Field(None, description="Tags associated with the run")
extra: Optional[dict[str, Any]] = Field(
None, description="Extra information of the run"
)
input_attachments: Optional[dict[str, Any]] = Field(
None, description="Input attachments of the run"
)
output_attachments: Optional[dict[str, Any]] = Field(
None, description="Output attachments of the run"
)
extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run")
input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run")
output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")

View File

@@ -159,8 +159,8 @@ class LangSmithDataTrace(BaseTraceInstance):
run_type = LangSmithRunType.llm
metadata.update(
{
'ls_provider': process_data.get('model_provider', ''),
'ls_model_name': process_data.get('model_name', ''),
"ls_provider": process_data.get("model_provider", ""),
"ls_model_name": process_data.get("model_name", ""),
}
)
elif node_type == "knowledge-retrieval":
@@ -385,12 +385,10 @@ class LangSmithDataTrace(BaseTraceInstance):
start_time=datetime.now(),
)
project_url = self.langsmith_client.get_run_url(run=run_data,
project_id=self.project_id,
project_name=self.project_name)
return project_url.split('/r/')[0]
project_url = self.langsmith_client.get_run_url(
run=run_data, project_id=self.project_id, project_name=self.project_name
)
return project_url.split("/r/")[0]
except Exception as e:
logger.debug(f"LangSmith get run url failed: {str(e)}")
raise ValueError(f"LangSmith get run url failed: {str(e)}")

View File

@@ -36,17 +36,17 @@ from tasks.ops_trace_task import process_trace_tasks
provider_config_map = {
TracingProviderEnum.LANGFUSE.value: {
'config_class': LangfuseConfig,
'secret_keys': ['public_key', 'secret_key'],
'other_keys': ['host', 'project_key'],
'trace_instance': LangFuseDataTrace
"config_class": LangfuseConfig,
"secret_keys": ["public_key", "secret_key"],
"other_keys": ["host", "project_key"],
"trace_instance": LangFuseDataTrace,
},
TracingProviderEnum.LANGSMITH.value: {
'config_class': LangSmithConfig,
'secret_keys': ['api_key'],
'other_keys': ['project', 'endpoint'],
'trace_instance': LangSmithDataTrace
}
"config_class": LangSmithConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "endpoint"],
"trace_instance": LangSmithDataTrace,
},
}
@@ -64,14 +64,17 @@ class OpsTraceManager:
:return: encrypted tracing configuration
"""
# Get the configuration class and the keys that require encryption
config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \
provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys']
config_class, secret_keys, other_keys = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["secret_keys"],
provider_config_map[tracing_provider]["other_keys"],
)
new_config = {}
# Encrypt necessary keys
for key in secret_keys:
if key in tracing_config:
if '*' in tracing_config[key]:
if "*" in tracing_config[key]:
# If the key contains '*', retain the original value from the current config
new_config[key] = current_trace_config.get(key, tracing_config[key])
else:
@@ -94,8 +97,11 @@ class OpsTraceManager:
:param tracing_config: tracing config
:return:
"""
config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \
provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys']
config_class, secret_keys, other_keys = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["secret_keys"],
provider_config_map[tracing_provider]["other_keys"],
)
new_config = {}
for key in secret_keys:
if key in tracing_config:
@@ -114,8 +120,11 @@ class OpsTraceManager:
:param decrypt_tracing_config: tracing config
:return:
"""
config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \
provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys']
config_class, secret_keys, other_keys = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["secret_keys"],
provider_config_map[tracing_provider]["other_keys"],
)
new_config = {}
for key in secret_keys:
if key in decrypt_tracing_config:
@@ -133,9 +142,11 @@ class OpsTraceManager:
:param tracing_provider: tracing provider
:return:
"""
trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter(
TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider
).first()
trace_config_data: TraceAppConfig = (
db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
)
if not trace_config_data:
return None
@@ -164,21 +175,21 @@ class OpsTraceManager:
if app_id is None:
return None
app: App = db.session.query(App).filter(
App.id == app_id
).first()
app: App = db.session.query(App).filter(App.id == app_id).first()
app_ops_trace_config = json.loads(app.tracing) if app.tracing else None
if app_ops_trace_config is not None:
tracing_provider = app_ops_trace_config.get('tracing_provider')
tracing_provider = app_ops_trace_config.get("tracing_provider")
else:
return None
# decrypt_token
decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider)
if app_ops_trace_config.get('enabled'):
trace_instance, config_class = provider_config_map[tracing_provider]['trace_instance'], \
provider_config_map[tracing_provider]['config_class']
if app_ops_trace_config.get("enabled"):
trace_instance, config_class = (
provider_config_map[tracing_provider]["trace_instance"],
provider_config_map[tracing_provider]["config_class"],
)
tracing_instance = trace_instance(config_class(**decrypt_trace_config))
return tracing_instance
@@ -192,9 +203,11 @@ class OpsTraceManager:
conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
if conversation_data.app_model_config_id:
app_model_config = db.session.query(AppModelConfig).filter(
AppModelConfig.id == conversation_data.app_model_config_id
).first()
app_model_config = (
db.session.query(AppModelConfig)
.filter(AppModelConfig.id == conversation_data.app_model_config_id)
.first()
)
elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
app_model_config = conversation_data.override_model_configs
@@ -231,10 +244,7 @@ class OpsTraceManager:
"""
app: App = db.session.query(App).filter(App.id == app_id).first()
if not app.tracing:
return {
"enabled": False,
"tracing_provider": None
}
return {"enabled": False, "tracing_provider": None}
app_trace_config = json.loads(app.tracing)
return app_trace_config
@@ -246,8 +256,10 @@ class OpsTraceManager:
:param tracing_provider: tracing provider
:return:
"""
config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \
provider_config_map[tracing_provider]['trace_instance']
config_type, trace_instance = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["trace_instance"],
)
tracing_config = config_type(**tracing_config)
return trace_instance(tracing_config).api_check()
@@ -259,8 +271,10 @@ class OpsTraceManager:
:param tracing_provider: tracing provider
:return:
"""
config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \
provider_config_map[tracing_provider]['trace_instance']
config_type, trace_instance = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["trace_instance"],
)
tracing_config = config_type(**tracing_config)
return trace_instance(tracing_config).get_project_key()
@@ -272,8 +286,10 @@ class OpsTraceManager:
:param tracing_provider: tracing provider
:return:
"""
config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \
provider_config_map[tracing_provider]['trace_instance']
config_type, trace_instance = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["trace_instance"],
)
tracing_config = config_type(**tracing_config)
return trace_instance(tracing_config).get_project_url()
@@ -287,7 +303,7 @@ class TraceTask:
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
timer: Optional[Any] = None,
**kwargs
**kwargs,
):
self.trace_type = trace_type
self.message_id = message_id
@@ -310,9 +326,7 @@ class TraceTask:
self.workflow_run, self.conversation_id, self.user_id
),
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id),
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
self.message_id, self.timer, **self.kwargs
),
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(self.message_id, self.timer, **self.kwargs),
TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace(
self.message_id, self.timer, **self.kwargs
),
@@ -337,12 +351,8 @@ class TraceTask:
workflow_run_id = workflow_run.id
workflow_run_elapsed_time = workflow_run.elapsed_time
workflow_run_status = workflow_run.status
workflow_run_inputs = (
json.loads(workflow_run.inputs) if workflow_run.inputs else {}
)
workflow_run_outputs = (
json.loads(workflow_run.outputs) if workflow_run.outputs else {}
)
workflow_run_inputs = json.loads(workflow_run.inputs) if workflow_run.inputs else {}
workflow_run_outputs = json.loads(workflow_run.outputs) if workflow_run.outputs else {}
workflow_run_version = workflow_run.version
error = workflow_run.error if workflow_run.error else ""
@@ -352,17 +362,18 @@ class TraceTask:
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
# get workflow_app_log_id
workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by(
tenant_id=tenant_id,
app_id=workflow_run.app_id,
workflow_run_id=workflow_run.id
).first()
workflow_app_log_data = (
db.session.query(WorkflowAppLog)
.filter_by(tenant_id=tenant_id, app_id=workflow_run.app_id, workflow_run_id=workflow_run.id)
.first()
)
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
# get message_id
message_data = db.session.query(Message.id).filter_by(
conversation_id=conversation_id,
workflow_run_id=workflow_run_id
).first()
message_data = (
db.session.query(Message.id)
.filter_by(conversation_id=conversation_id, workflow_run_id=workflow_run_id)
.first()
)
message_id = str(message_data.id) if message_data else None
metadata = {
@@ -470,9 +481,9 @@ class TraceTask:
# get workflow_app_log_id
workflow_app_log_id = None
if message_data.workflow_run_id:
workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by(
workflow_run_id=message_data.workflow_run_id
).first()
workflow_app_log_data = (
db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
)
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
moderation_trace_info = ModerationTraceInfo(
@@ -510,9 +521,9 @@ class TraceTask:
# get workflow_app_log_id
workflow_app_log_id = None
if message_data.workflow_run_id:
workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by(
workflow_run_id=message_data.workflow_run_id
).first()
workflow_app_log_data = (
db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
)
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
suggested_question_trace_info = SuggestedQuestionTraceInfo(
@@ -569,9 +580,9 @@ class TraceTask:
return dataset_retrieval_trace_info
def tool_trace(self, message_id, timer, **kwargs):
tool_name = kwargs.get('tool_name')
tool_inputs = kwargs.get('tool_inputs')
tool_outputs = kwargs.get('tool_outputs')
tool_name = kwargs.get("tool_name")
tool_inputs = kwargs.get("tool_inputs")
tool_outputs = kwargs.get("tool_outputs")
message_data = get_message_data(message_id)
if not message_data:
return {}
@@ -586,11 +597,11 @@ class TraceTask:
if tool_name in agent_thought.tools:
created_time = agent_thought.created_at
tool_meta_data = agent_thought.tool_meta.get(tool_name, {})
tool_config = tool_meta_data.get('tool_config', {})
time_cost = tool_meta_data.get('time_cost', 0)
tool_config = tool_meta_data.get("tool_config", {})
time_cost = tool_meta_data.get("time_cost", 0)
end_time = created_time + timedelta(seconds=time_cost)
error = tool_meta_data.get('error', "")
tool_parameters = tool_meta_data.get('tool_parameters', {})
error = tool_meta_data.get("error", "")
tool_parameters = tool_meta_data.get("tool_parameters", {})
metadata = {
"message_id": message_id,
"tool_name": tool_name,
@@ -715,9 +726,7 @@ class TraceQueueManager:
def start_timer(self):
global trace_manager_timer
if trace_manager_timer is None or not trace_manager_timer.is_alive():
trace_manager_timer = threading.Timer(
trace_manager_interval, self.run
)
trace_manager_timer = threading.Timer(trace_manager_interval, self.run)
trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}"
trace_manager_timer.daemon = False
trace_manager_timer.start()

View File

@@ -20,19 +20,19 @@ def get_message_data(message_id):
@contextmanager
def measure_time():
timing_info = {'start': datetime.now(), 'end': None}
timing_info = {"start": datetime.now(), "end": None}
try:
yield timing_info
finally:
timing_info['end'] = datetime.now()
timing_info["end"] = datetime.now()
def replace_text_with_content(data):
if isinstance(data, dict):
new_data = {}
for key, value in data.items():
if key == 'text':
new_data['content'] = value
if key == "text":
new_data["content"] = value
else:
new_data[key] = replace_text_with_content(value)
return new_data