diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index f2d1bd305..c988bf48d 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -98,6 +98,7 @@ class WeaveConfig(BaseTracingConfig): entity: str | None = None project: str endpoint: str = "https://trace.wandb.ai" + host: str | None = None @field_validator("endpoint") @classmethod @@ -109,6 +110,14 @@ class WeaveConfig(BaseTracingConfig): return v + @field_validator("host") + @classmethod + def validate_host(cls, v, info: ValidationInfo): + if v is not None and v != "": + if not v.startswith(("https://", "http://")): + raise ValueError("host must start with https:// or http://") + return v + OPS_FILE_PATH = "ops_trace/" OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE" diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index dc4cfc48d..e0dfe0c31 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -81,7 +81,7 @@ class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]): return { "config_class": WeaveConfig, "secret_keys": ["api_key"], - "other_keys": ["project", "entity", "endpoint"], + "other_keys": ["project", "entity", "endpoint", "host"], "trace_instance": WeaveDataTrace, } diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index cfc8a505b..3917348a9 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -40,9 +40,14 @@ class WeaveDataTrace(BaseTraceInstance): self.weave_api_key = weave_config.api_key self.project_name = weave_config.project self.entity = weave_config.entity + self.host = weave_config.host + + # Login with API key first, including host if provided + if self.host: + login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host) + else: + login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True) - # Login with API key first - login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True) if not login_status: logger.error("Failed to login to Weights & Biases with the provided API key") raise ValueError("Weave login failed") @@ -386,7 +391,11 @@ class WeaveDataTrace(BaseTraceInstance): def api_check(self): try: - login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True) + if self.host: + login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host) + else: + login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True) + if not login_status: raise ValueError("Weave login failed") else: diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx index c0b52a9b1..b6c97add4 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx @@ -55,6 +55,7 @@ const weaveConfigTemplate = { entity: '', project: '', endpoint: '', + host: '', } const ProviderConfigModal: FC = ({ @@ -226,6 +227,13 @@ const ProviderConfigModal: FC = ({ onChange={handleConfigChange('endpoint')} placeholder={'https://trace.wandb.ai/'} /> + )} {type === TracingProvider.langSmith && ( diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/type.ts b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/type.ts index 386c58974..ed468caf6 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/type.ts +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/type.ts @@ -29,4 +29,5 @@ export type WeaveConfig = { entity: string project: string endpoint: string + host: string }