refactor: Fix some type error (#22594)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -4,7 +4,7 @@ from controllers.console import api
|
||||
from controllers.console.datasets.error import WebsiteCrawlError
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from services.website_service import WebsiteService
|
||||
from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService
|
||||
|
||||
|
||||
class WebsiteCrawlApi(Resource):
|
||||
@@ -24,10 +24,16 @@ class WebsiteCrawlApi(Resource):
|
||||
parser.add_argument("url", type=str, required=True, nullable=True, location="json")
|
||||
parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
WebsiteService.document_create_args_validate(args)
|
||||
# crawl url
|
||||
|
||||
# Create typed request and validate
|
||||
try:
|
||||
result = WebsiteService.crawl_url(args)
|
||||
api_request = WebsiteCrawlApiRequest.from_args(args)
|
||||
except ValueError as e:
|
||||
raise WebsiteCrawlError(str(e))
|
||||
|
||||
# Crawl URL using typed request
|
||||
try:
|
||||
result = WebsiteService.crawl_url(api_request)
|
||||
except Exception as e:
|
||||
raise WebsiteCrawlError(str(e))
|
||||
return result, 200
|
||||
@@ -43,9 +49,16 @@ class WebsiteCrawlStatusApi(Resource):
|
||||
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
# get crawl status
|
||||
|
||||
# Create typed request and validate
|
||||
try:
|
||||
result = WebsiteService.get_crawl_status(job_id, args["provider"])
|
||||
api_request = WebsiteCrawlStatusApiRequest.from_args(args, job_id)
|
||||
except ValueError as e:
|
||||
raise WebsiteCrawlError(str(e))
|
||||
|
||||
# Get crawl status using typed request
|
||||
try:
|
||||
result = WebsiteService.get_crawl_status_typed(api_request)
|
||||
except Exception as e:
|
||||
raise WebsiteCrawlError(str(e))
|
||||
return result, 200
|
||||
|
@@ -21,7 +21,7 @@ def encrypt_token(tenant_id: str, token: str):
|
||||
return base64.b64encode(encrypted_token).decode()
|
||||
|
||||
|
||||
def decrypt_token(tenant_id: str, token: str):
|
||||
def decrypt_token(tenant_id: str, token: str) -> str:
|
||||
return rsa.decrypt(base64.b64decode(token), tenant_id)
|
||||
|
||||
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import hashlib
|
||||
from typing import Union
|
||||
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.PublicKey import RSA
|
||||
@@ -9,7 +10,7 @@ from extensions.ext_storage import storage
|
||||
from libs import gmpy2_pkcs10aep_cipher
|
||||
|
||||
|
||||
def generate_key_pair(tenant_id):
|
||||
def generate_key_pair(tenant_id: str) -> str:
|
||||
private_key = RSA.generate(2048)
|
||||
public_key = private_key.publickey()
|
||||
|
||||
@@ -26,7 +27,7 @@ def generate_key_pair(tenant_id):
|
||||
prefix_hybrid = b"HYBRID:"
|
||||
|
||||
|
||||
def encrypt(text, public_key):
|
||||
def encrypt(text: str, public_key: Union[str, bytes]) -> bytes:
|
||||
if isinstance(public_key, str):
|
||||
public_key = public_key.encode()
|
||||
|
||||
@@ -38,14 +39,14 @@ def encrypt(text, public_key):
|
||||
rsa_key = RSA.import_key(public_key)
|
||||
cipher_rsa = gmpy2_pkcs10aep_cipher.new(rsa_key)
|
||||
|
||||
enc_aes_key = cipher_rsa.encrypt(aes_key)
|
||||
enc_aes_key: bytes = cipher_rsa.encrypt(aes_key)
|
||||
|
||||
encrypted_data = enc_aes_key + cipher_aes.nonce + tag + ciphertext
|
||||
|
||||
return prefix_hybrid + encrypted_data
|
||||
|
||||
|
||||
def get_decrypt_decoding(tenant_id):
|
||||
def get_decrypt_decoding(tenant_id: str) -> tuple[RSA.RsaKey, object]:
|
||||
filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"
|
||||
|
||||
cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest())
|
||||
@@ -64,7 +65,7 @@ def get_decrypt_decoding(tenant_id):
|
||||
return rsa_key, cipher_rsa
|
||||
|
||||
|
||||
def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa):
|
||||
def decrypt_token_with_decoding(encrypted_text: bytes, rsa_key: RSA.RsaKey, cipher_rsa) -> str:
|
||||
if encrypted_text.startswith(prefix_hybrid):
|
||||
encrypted_text = encrypted_text[len(prefix_hybrid) :]
|
||||
|
||||
@@ -83,10 +84,10 @@ def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa):
|
||||
return decrypted_text.decode()
|
||||
|
||||
|
||||
def decrypt(encrypted_text, tenant_id):
|
||||
def decrypt(encrypted_text: bytes, tenant_id: str) -> str:
|
||||
rsa_key, cipher_rsa = get_decrypt_decoding(tenant_id)
|
||||
|
||||
return decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa)
|
||||
return decrypt_token_with_decoding(encrypted_text=encrypted_text, rsa_key=rsa_key, cipher_rsa=cipher_rsa)
|
||||
|
||||
|
||||
class PrivkeyNotFoundError(Exception):
|
||||
|
@@ -196,7 +196,7 @@ class Tenant(Base):
|
||||
__tablename__ = "tenants"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
encrypt_public_key = db.Column(db.Text)
|
||||
plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying"))
|
||||
|
@@ -334,21 +334,33 @@ class ToolTransformService:
|
||||
)
|
||||
|
||||
# get tool parameters
|
||||
parameters = tool.entity.parameters or []
|
||||
base_parameters = tool.entity.parameters or []
|
||||
# get tool runtime parameters
|
||||
runtime_parameters = tool.get_runtime_parameters()
|
||||
# override parameters
|
||||
current_parameters = parameters.copy()
|
||||
for runtime_parameter in runtime_parameters:
|
||||
found = False
|
||||
for index, parameter in enumerate(current_parameters):
|
||||
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
|
||||
current_parameters[index] = runtime_parameter
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
current_parameters.append(runtime_parameter)
|
||||
# merge parameters using a functional approach to avoid type issues
|
||||
merged_parameters: list[ToolParameter] = []
|
||||
|
||||
# create a mapping of runtime parameters for quick lookup
|
||||
runtime_param_map = {(rp.name, rp.form): rp for rp in runtime_parameters}
|
||||
|
||||
# process base parameters, replacing with runtime versions if they exist
|
||||
for base_param in base_parameters:
|
||||
key = (base_param.name, base_param.form)
|
||||
if key in runtime_param_map:
|
||||
merged_parameters.append(runtime_param_map[key])
|
||||
else:
|
||||
merged_parameters.append(base_param)
|
||||
|
||||
# add any runtime parameters that weren't in base parameters
|
||||
for runtime_parameter in runtime_parameters:
|
||||
if runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
# check if this parameter is already in merged_parameters
|
||||
already_exists = any(
|
||||
p.name == runtime_parameter.name and p.form == runtime_parameter.form for p in merged_parameters
|
||||
)
|
||||
if not already_exists:
|
||||
merged_parameters.append(runtime_parameter)
|
||||
|
||||
return ToolApiEntity(
|
||||
author=tool.entity.identity.author,
|
||||
@@ -356,10 +368,10 @@ class ToolTransformService:
|
||||
label=tool.entity.identity.label,
|
||||
description=tool.entity.description.human if tool.entity.description else I18nObject(en_US=""),
|
||||
output_schema=tool.entity.output_schema,
|
||||
parameters=current_parameters,
|
||||
parameters=merged_parameters,
|
||||
labels=labels or [],
|
||||
)
|
||||
if isinstance(tool, ApiToolBundle):
|
||||
elif isinstance(tool, ApiToolBundle):
|
||||
return ToolApiEntity(
|
||||
author=tool.author,
|
||||
name=tool.operation_id or "",
|
||||
@@ -368,6 +380,9 @@ class ToolTransformService:
|
||||
parameters=tool.parameters,
|
||||
labels=labels or [],
|
||||
)
|
||||
else:
|
||||
# Handle WorkflowTool case
|
||||
raise ValueError(f"Unsupported tool type: {type(tool)}")
|
||||
|
||||
@staticmethod
|
||||
def convert_builtin_provider_to_credential_entity(
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import datetime
|
||||
import json
|
||||
from typing import Any
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
from flask_login import current_user
|
||||
@@ -13,241 +14,392 @@ from extensions.ext_storage import storage
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
|
||||
class WebsiteService:
|
||||
@classmethod
|
||||
def document_create_args_validate(cls, args: dict):
|
||||
if "url" not in args or not args["url"]:
|
||||
raise ValueError("url is required")
|
||||
if "options" not in args or not args["options"]:
|
||||
raise ValueError("options is required")
|
||||
if "limit" not in args["options"] or not args["options"]["limit"]:
|
||||
raise ValueError("limit is required")
|
||||
@dataclass
|
||||
class CrawlOptions:
|
||||
"""Options for crawling operations."""
|
||||
|
||||
limit: int = 1
|
||||
crawl_sub_pages: bool = False
|
||||
only_main_content: bool = False
|
||||
includes: Optional[str] = None
|
||||
excludes: Optional[str] = None
|
||||
max_depth: Optional[int] = None
|
||||
use_sitemap: bool = True
|
||||
|
||||
def get_include_paths(self) -> list[str]:
|
||||
"""Get list of include paths from comma-separated string."""
|
||||
return self.includes.split(",") if self.includes else []
|
||||
|
||||
def get_exclude_paths(self) -> list[str]:
|
||||
"""Get list of exclude paths from comma-separated string."""
|
||||
return self.excludes.split(",") if self.excludes else []
|
||||
|
||||
|
||||
@dataclass
|
||||
class CrawlRequest:
|
||||
"""Request container for crawling operations."""
|
||||
|
||||
url: str
|
||||
provider: str
|
||||
options: CrawlOptions
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScrapeRequest:
|
||||
"""Request container for scraping operations."""
|
||||
|
||||
provider: str
|
||||
url: str
|
||||
tenant_id: str
|
||||
only_main_content: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebsiteCrawlApiRequest:
|
||||
"""Request container for website crawl API arguments."""
|
||||
|
||||
provider: str
|
||||
url: str
|
||||
options: dict[str, Any]
|
||||
|
||||
def to_crawl_request(self) -> CrawlRequest:
|
||||
"""Convert API request to internal CrawlRequest."""
|
||||
options = CrawlOptions(
|
||||
limit=self.options.get("limit", 1),
|
||||
crawl_sub_pages=self.options.get("crawl_sub_pages", False),
|
||||
only_main_content=self.options.get("only_main_content", False),
|
||||
includes=self.options.get("includes"),
|
||||
excludes=self.options.get("excludes"),
|
||||
max_depth=self.options.get("max_depth"),
|
||||
use_sitemap=self.options.get("use_sitemap", True),
|
||||
)
|
||||
return CrawlRequest(url=self.url, provider=self.provider, options=options)
|
||||
|
||||
@classmethod
|
||||
def crawl_url(cls, args: dict) -> dict:
|
||||
provider = args.get("provider", "")
|
||||
def from_args(cls, args: dict) -> "WebsiteCrawlApiRequest":
|
||||
"""Create from Flask-RESTful parsed arguments."""
|
||||
provider = args.get("provider")
|
||||
url = args.get("url")
|
||||
options = args.get("options", "")
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
|
||||
if provider == "firecrawl":
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
|
||||
)
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
|
||||
crawl_sub_pages = options.get("crawl_sub_pages", False)
|
||||
only_main_content = options.get("only_main_content", False)
|
||||
if not crawl_sub_pages:
|
||||
params = {
|
||||
"includePaths": [],
|
||||
"excludePaths": [],
|
||||
"limit": 1,
|
||||
"scrapeOptions": {"onlyMainContent": only_main_content},
|
||||
}
|
||||
else:
|
||||
includes = options.get("includes").split(",") if options.get("includes") else []
|
||||
excludes = options.get("excludes").split(",") if options.get("excludes") else []
|
||||
params = {
|
||||
"includePaths": includes,
|
||||
"excludePaths": excludes,
|
||||
"limit": options.get("limit", 1),
|
||||
"scrapeOptions": {"onlyMainContent": only_main_content},
|
||||
}
|
||||
if options.get("max_depth"):
|
||||
params["maxDepth"] = options.get("max_depth")
|
||||
job_id = firecrawl_app.crawl_url(url, params)
|
||||
website_crawl_time_cache_key = f"website_crawl_{job_id}"
|
||||
time = str(datetime.datetime.now().timestamp())
|
||||
redis_client.setex(website_crawl_time_cache_key, 3600, time)
|
||||
return {"status": "active", "job_id": job_id}
|
||||
elif provider == "watercrawl":
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
|
||||
)
|
||||
return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).crawl_url(url, options)
|
||||
options = args.get("options", {})
|
||||
|
||||
elif provider == "jinareader":
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
|
||||
)
|
||||
crawl_sub_pages = options.get("crawl_sub_pages", False)
|
||||
if not crawl_sub_pages:
|
||||
response = requests.get(
|
||||
f"https://r.jina.ai/{url}",
|
||||
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
if response.json().get("code") != 200:
|
||||
raise ValueError("Failed to crawl")
|
||||
return {"status": "active", "data": response.json().get("data")}
|
||||
else:
|
||||
response = requests.post(
|
||||
"https://adaptivecrawl-kir3wx7b3a-uc.a.run.app",
|
||||
json={
|
||||
"url": url,
|
||||
"maxPages": options.get("limit", 1),
|
||||
"useSitemap": options.get("use_sitemap", True),
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
},
|
||||
)
|
||||
if response.json().get("code") != 200:
|
||||
raise ValueError("Failed to crawl")
|
||||
return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")}
|
||||
if not provider:
|
||||
raise ValueError("Provider is required")
|
||||
if not url:
|
||||
raise ValueError("URL is required")
|
||||
if not options:
|
||||
raise ValueError("Options are required")
|
||||
|
||||
return cls(provider=provider, url=url, options=options)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebsiteCrawlStatusApiRequest:
|
||||
"""Request container for website crawl status API arguments."""
|
||||
|
||||
provider: str
|
||||
job_id: str
|
||||
|
||||
@classmethod
|
||||
def from_args(cls, args: dict, job_id: str) -> "WebsiteCrawlStatusApiRequest":
|
||||
"""Create from Flask-RESTful parsed arguments."""
|
||||
provider = args.get("provider")
|
||||
|
||||
if not provider:
|
||||
raise ValueError("Provider is required")
|
||||
if not job_id:
|
||||
raise ValueError("Job ID is required")
|
||||
|
||||
return cls(provider=provider, job_id=job_id)
|
||||
|
||||
|
||||
class WebsiteService:
|
||||
"""Service class for website crawling operations using different providers."""
|
||||
|
||||
@classmethod
|
||||
def _get_credentials_and_config(cls, tenant_id: str, provider: str) -> tuple[dict, dict]:
|
||||
"""Get and validate credentials for a provider."""
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
|
||||
if not credentials or "config" not in credentials:
|
||||
raise ValueError("No valid credentials found for the provider")
|
||||
return credentials, credentials["config"]
|
||||
|
||||
@classmethod
|
||||
def _get_decrypted_api_key(cls, tenant_id: str, config: dict) -> str:
|
||||
"""Decrypt and return the API key from config."""
|
||||
api_key = config.get("api_key")
|
||||
if not api_key:
|
||||
raise ValueError("API key not found in configuration")
|
||||
return encrypter.decrypt_token(tenant_id=tenant_id, token=api_key)
|
||||
|
||||
@classmethod
|
||||
def document_create_args_validate(cls, args: dict) -> None:
|
||||
"""Validate arguments for document creation."""
|
||||
try:
|
||||
WebsiteCrawlApiRequest.from_args(args)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid arguments: {e}")
|
||||
|
||||
@classmethod
|
||||
def crawl_url(cls, api_request: WebsiteCrawlApiRequest) -> dict[str, Any]:
|
||||
"""Crawl a URL using the specified provider with typed request."""
|
||||
request = api_request.to_crawl_request()
|
||||
|
||||
_, config = cls._get_credentials_and_config(current_user.current_tenant_id, request.provider)
|
||||
api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
|
||||
|
||||
if request.provider == "firecrawl":
|
||||
return cls._crawl_with_firecrawl(request=request, api_key=api_key, config=config)
|
||||
elif request.provider == "watercrawl":
|
||||
return cls._crawl_with_watercrawl(request=request, api_key=api_key, config=config)
|
||||
elif request.provider == "jinareader":
|
||||
return cls._crawl_with_jinareader(request=request, api_key=api_key)
|
||||
else:
|
||||
raise ValueError("Invalid provider")
|
||||
|
||||
@classmethod
|
||||
def get_crawl_status(cls, job_id: str, provider: str) -> dict:
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
|
||||
if provider == "firecrawl":
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
|
||||
)
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
|
||||
result = firecrawl_app.check_crawl_status(job_id)
|
||||
crawl_status_data = {
|
||||
"status": result.get("status", "active"),
|
||||
"job_id": job_id,
|
||||
"total": result.get("total", 0),
|
||||
"current": result.get("current", 0),
|
||||
"data": result.get("data", []),
|
||||
def _crawl_with_firecrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]:
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
|
||||
|
||||
if not request.options.crawl_sub_pages:
|
||||
params = {
|
||||
"includePaths": [],
|
||||
"excludePaths": [],
|
||||
"limit": 1,
|
||||
"scrapeOptions": {"onlyMainContent": request.options.only_main_content},
|
||||
}
|
||||
if crawl_status_data["status"] == "completed":
|
||||
website_crawl_time_cache_key = f"website_crawl_{job_id}"
|
||||
start_time = redis_client.get(website_crawl_time_cache_key)
|
||||
if start_time:
|
||||
end_time = datetime.datetime.now().timestamp()
|
||||
time_consuming = abs(end_time - float(start_time))
|
||||
crawl_status_data["time_consuming"] = f"{time_consuming:.2f}"
|
||||
redis_client.delete(website_crawl_time_cache_key)
|
||||
elif provider == "watercrawl":
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
|
||||
else:
|
||||
params = {
|
||||
"includePaths": request.options.get_include_paths(),
|
||||
"excludePaths": request.options.get_exclude_paths(),
|
||||
"limit": request.options.limit,
|
||||
"scrapeOptions": {"onlyMainContent": request.options.only_main_content},
|
||||
}
|
||||
if request.options.max_depth:
|
||||
params["maxDepth"] = request.options.max_depth
|
||||
|
||||
job_id = firecrawl_app.crawl_url(request.url, params)
|
||||
website_crawl_time_cache_key = f"website_crawl_{job_id}"
|
||||
time = str(datetime.datetime.now().timestamp())
|
||||
redis_client.setex(website_crawl_time_cache_key, 3600, time)
|
||||
return {"status": "active", "job_id": job_id}
|
||||
|
||||
@classmethod
|
||||
def _crawl_with_watercrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]:
|
||||
# Convert CrawlOptions back to dict format for WaterCrawlProvider
|
||||
options = {
|
||||
"limit": request.options.limit,
|
||||
"crawl_sub_pages": request.options.crawl_sub_pages,
|
||||
"only_main_content": request.options.only_main_content,
|
||||
"includes": request.options.includes,
|
||||
"excludes": request.options.excludes,
|
||||
"max_depth": request.options.max_depth,
|
||||
"use_sitemap": request.options.use_sitemap,
|
||||
}
|
||||
return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).crawl_url(
|
||||
url=request.url, options=options
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]:
|
||||
if not request.options.crawl_sub_pages:
|
||||
response = requests.get(
|
||||
f"https://r.jina.ai/{request.url}",
|
||||
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
crawl_status_data = WaterCrawlProvider(
|
||||
api_key, credentials.get("config").get("base_url", None)
|
||||
).get_crawl_status(job_id)
|
||||
elif provider == "jinareader":
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
|
||||
if response.json().get("code") != 200:
|
||||
raise ValueError("Failed to crawl")
|
||||
return {"status": "active", "data": response.json().get("data")}
|
||||
else:
|
||||
response = requests.post(
|
||||
"https://adaptivecrawl-kir3wx7b3a-uc.a.run.app",
|
||||
json={
|
||||
"url": request.url,
|
||||
"maxPages": request.options.limit,
|
||||
"useSitemap": request.options.use_sitemap,
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
},
|
||||
)
|
||||
if response.json().get("code") != 200:
|
||||
raise ValueError("Failed to crawl")
|
||||
return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")}
|
||||
|
||||
@classmethod
|
||||
def get_crawl_status(cls, job_id: str, provider: str) -> dict[str, Any]:
|
||||
"""Get crawl status using string parameters."""
|
||||
api_request = WebsiteCrawlStatusApiRequest(provider=provider, job_id=job_id)
|
||||
return cls.get_crawl_status_typed(api_request)
|
||||
|
||||
@classmethod
|
||||
def get_crawl_status_typed(cls, api_request: WebsiteCrawlStatusApiRequest) -> dict[str, Any]:
|
||||
"""Get crawl status using typed request."""
|
||||
_, config = cls._get_credentials_and_config(current_user.current_tenant_id, api_request.provider)
|
||||
api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
|
||||
|
||||
if api_request.provider == "firecrawl":
|
||||
return cls._get_firecrawl_status(api_request.job_id, api_key, config)
|
||||
elif api_request.provider == "watercrawl":
|
||||
return cls._get_watercrawl_status(api_request.job_id, api_key, config)
|
||||
elif api_request.provider == "jinareader":
|
||||
return cls._get_jinareader_status(api_request.job_id, api_key)
|
||||
else:
|
||||
raise ValueError("Invalid provider")
|
||||
|
||||
@classmethod
|
||||
def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]:
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
|
||||
result = firecrawl_app.check_crawl_status(job_id)
|
||||
crawl_status_data = {
|
||||
"status": result.get("status", "active"),
|
||||
"job_id": job_id,
|
||||
"total": result.get("total", 0),
|
||||
"current": result.get("current", 0),
|
||||
"data": result.get("data", []),
|
||||
}
|
||||
if crawl_status_data["status"] == "completed":
|
||||
website_crawl_time_cache_key = f"website_crawl_{job_id}"
|
||||
start_time = redis_client.get(website_crawl_time_cache_key)
|
||||
if start_time:
|
||||
end_time = datetime.datetime.now().timestamp()
|
||||
time_consuming = abs(end_time - float(start_time))
|
||||
crawl_status_data["time_consuming"] = f"{time_consuming:.2f}"
|
||||
redis_client.delete(website_crawl_time_cache_key)
|
||||
return crawl_status_data
|
||||
|
||||
@classmethod
|
||||
def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]:
|
||||
return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id)
|
||||
|
||||
@classmethod
|
||||
def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]:
|
||||
response = requests.post(
|
||||
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
||||
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
json={"taskId": job_id},
|
||||
)
|
||||
data = response.json().get("data", {})
|
||||
crawl_status_data = {
|
||||
"status": data.get("status", "active"),
|
||||
"job_id": job_id,
|
||||
"total": len(data.get("urls", [])),
|
||||
"current": len(data.get("processed", [])) + len(data.get("failed", [])),
|
||||
"data": [],
|
||||
"time_consuming": data.get("duration", 0) / 1000,
|
||||
}
|
||||
|
||||
if crawl_status_data["status"] == "completed":
|
||||
response = requests.post(
|
||||
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
||||
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
json={"taskId": job_id},
|
||||
json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
|
||||
)
|
||||
data = response.json().get("data", {})
|
||||
crawl_status_data = {
|
||||
"status": data.get("status", "active"),
|
||||
"job_id": job_id,
|
||||
"total": len(data.get("urls", [])),
|
||||
"current": len(data.get("processed", [])) + len(data.get("failed", [])),
|
||||
"data": [],
|
||||
"time_consuming": data.get("duration", 0) / 1000,
|
||||
}
|
||||
|
||||
if crawl_status_data["status"] == "completed":
|
||||
response = requests.post(
|
||||
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
||||
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
|
||||
)
|
||||
data = response.json().get("data", {})
|
||||
formatted_data = [
|
||||
{
|
||||
"title": item.get("data", {}).get("title"),
|
||||
"source_url": item.get("data", {}).get("url"),
|
||||
"description": item.get("data", {}).get("description"),
|
||||
"markdown": item.get("data", {}).get("content"),
|
||||
}
|
||||
for item in data.get("processed", {}).values()
|
||||
]
|
||||
crawl_status_data["data"] = formatted_data
|
||||
else:
|
||||
raise ValueError("Invalid provider")
|
||||
formatted_data = [
|
||||
{
|
||||
"title": item.get("data", {}).get("title"),
|
||||
"source_url": item.get("data", {}).get("url"),
|
||||
"description": item.get("data", {}).get("description"),
|
||||
"markdown": item.get("data", {}).get("content"),
|
||||
}
|
||||
for item in data.get("processed", {}).values()
|
||||
]
|
||||
crawl_status_data["data"] = formatted_data
|
||||
return crawl_status_data
|
||||
|
||||
@classmethod
|
||||
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[str, Any] | None:
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
|
||||
_, config = cls._get_credentials_and_config(tenant_id, provider)
|
||||
api_key = cls._get_decrypted_api_key(tenant_id, config)
|
||||
|
||||
if provider == "firecrawl":
|
||||
crawl_data: list[dict[str, Any]] | None = None
|
||||
file_key = "website_files/" + job_id + ".txt"
|
||||
if storage.exists(file_key):
|
||||
stored_data = storage.load_once(file_key)
|
||||
if stored_data:
|
||||
crawl_data = json.loads(stored_data.decode("utf-8"))
|
||||
else:
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
|
||||
result = firecrawl_app.check_crawl_status(job_id)
|
||||
if result.get("status") != "completed":
|
||||
raise ValueError("Crawl job is not completed")
|
||||
crawl_data = result.get("data")
|
||||
|
||||
if crawl_data:
|
||||
for item in crawl_data:
|
||||
if item.get("source_url") == url:
|
||||
return dict(item)
|
||||
return None
|
||||
return cls._get_firecrawl_url_data(job_id, url, api_key, config)
|
||||
elif provider == "watercrawl":
|
||||
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
|
||||
return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).get_crawl_url_data(
|
||||
job_id, url
|
||||
)
|
||||
return cls._get_watercrawl_url_data(job_id, url, api_key, config)
|
||||
elif provider == "jinareader":
|
||||
if not job_id:
|
||||
response = requests.get(
|
||||
f"https://r.jina.ai/{url}",
|
||||
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
if response.json().get("code") != 200:
|
||||
raise ValueError("Failed to crawl")
|
||||
return dict(response.json().get("data", {}))
|
||||
else:
|
||||
# Get crawl status first
|
||||
status_response = requests.post(
|
||||
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
||||
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
json={"taskId": job_id},
|
||||
)
|
||||
status_data = status_response.json().get("data", {})
|
||||
if status_data.get("status") != "completed":
|
||||
raise ValueError("Crawl job is not completed")
|
||||
|
||||
# Get processed data
|
||||
data_response = requests.post(
|
||||
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
||||
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())},
|
||||
)
|
||||
processed_data = data_response.json().get("data", {})
|
||||
for item in processed_data.get("processed", {}).values():
|
||||
if item.get("data", {}).get("url") == url:
|
||||
return dict(item.get("data", {}))
|
||||
return None
|
||||
return cls._get_jinareader_url_data(job_id, url, api_key)
|
||||
else:
|
||||
raise ValueError("Invalid provider")
|
||||
|
||||
@classmethod
|
||||
def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict:
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
|
||||
if provider == "firecrawl":
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
|
||||
params = {"onlyMainContent": only_main_content}
|
||||
result = firecrawl_app.scrape_url(url, params)
|
||||
return result
|
||||
elif provider == "watercrawl":
|
||||
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
|
||||
return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).scrape_url(url)
|
||||
def _get_firecrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None:
|
||||
crawl_data: list[dict[str, Any]] | None = None
|
||||
file_key = "website_files/" + job_id + ".txt"
|
||||
if storage.exists(file_key):
|
||||
stored_data = storage.load_once(file_key)
|
||||
if stored_data:
|
||||
crawl_data = json.loads(stored_data.decode("utf-8"))
|
||||
else:
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
|
||||
result = firecrawl_app.check_crawl_status(job_id)
|
||||
if result.get("status") != "completed":
|
||||
raise ValueError("Crawl job is not completed")
|
||||
crawl_data = result.get("data")
|
||||
|
||||
if crawl_data:
|
||||
for item in crawl_data:
|
||||
if item.get("source_url") == url:
|
||||
return dict(item)
|
||||
return None
|
||||
|
||||
@classmethod
def _get_watercrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None:
    """Look up the crawled data for *url* within a WaterCrawl job."""
    provider = WaterCrawlProvider(api_key, config.get("base_url"))
    return provider.get_crawl_url_data(job_id, url)
|
||||
|
||||
@classmethod
def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None:
    """Return the crawled record for *url* from Jina Reader.

    With no *job_id* the URL is scraped on demand via ``r.jina.ai``; with a
    *job_id* the completed crawl task is queried for its processed data.

    Args:
        job_id: Jina Reader crawl task id, or empty/None for an ad-hoc scrape.
        url: Source URL to scrape or to look up in the crawl results.
        api_key: Jina Reader API key (sent as a Bearer token).

    Returns:
        The page data dict, or None when *url* is not among the processed
        results of the crawl job.

    Raises:
        ValueError: if the ad-hoc scrape fails or the crawl job is not yet
            completed.
    """
    # NOTE(review): these requests set no timeout, so a stalled endpoint
    # hangs the caller indefinitely — consider passing ``timeout=``.
    if not job_id:
        response = requests.get(
            f"https://r.jina.ai/{url}",
            headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
        )
        # Parse the body once instead of calling response.json() twice.
        payload = response.json()
        if payload.get("code") != 200:
            raise ValueError("Failed to crawl")
        return dict(payload.get("data", {}))
    else:
        # Get crawl status first
        status_response = requests.post(
            "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
            headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
            json={"taskId": job_id},
        )
        status_data = status_response.json().get("data", {})
        if status_data.get("status") != "completed":
            raise ValueError("Crawl job is not completed")

        # Fetch processed page data for every URL the crawl reported.
        data_response = requests.post(
            "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
            headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
            json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())},
        )
        processed_data = data_response.json().get("data", {})
        for item in processed_data.get("processed", {}).values():
            if item.get("data", {}).get("url") == url:
                return dict(item.get("data", {}))
        return None
|
||||
|
||||
@classmethod
def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict[str, Any]:
    """Scrape a single URL through the tenant's configured website provider.

    Builds a typed ScrapeRequest, resolves the tenant's credentials, and
    dispatches to the provider-specific scraper.

    Raises:
        ValueError: if *provider* is not a supported provider.
    """
    request = ScrapeRequest(provider=provider, url=url, tenant_id=tenant_id, only_main_content=only_main_content)

    _, config = cls._get_credentials_and_config(tenant_id=request.tenant_id, provider=request.provider)
    api_key = cls._get_decrypted_api_key(tenant_id=request.tenant_id, config=config)

    # Provider name -> scraper implementation.
    scrapers = {
        "firecrawl": cls._scrape_with_firecrawl,
        "watercrawl": cls._scrape_with_watercrawl,
    }
    scraper = scrapers.get(request.provider)
    if scraper is None:
        raise ValueError("Invalid provider")
    return scraper(request=request, api_key=api_key, config=config)
|
||||
|
||||
@classmethod
def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
    """Scrape ``request.url`` with Firecrawl, honoring the only-main-content flag."""
    app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
    return app.scrape_url(url=request.url, params={"onlyMainContent": request.only_main_content})
|
||||
|
||||
@classmethod
def _scrape_with_watercrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
    """Scrape ``request.url`` with WaterCrawl."""
    client = WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url"))
    return client.scrape_url(request.url)
|
||||
|
@@ -31,7 +31,7 @@ class WorkspaceService:
|
||||
assert tenant_account_join is not None, "TenantAccountJoin not found"
|
||||
tenant_info["role"] = tenant_account_join.role
|
||||
|
||||
can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo
|
||||
can_replace_logo = FeatureService.get_features(tenant.id).can_replace_logo
|
||||
|
||||
if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]):
|
||||
base_url = dify_config.FILES_URL
|
||||
|
0
api/tests/unit_tests/services/tools/__init__.py
Normal file
0
api/tests/unit_tests/services/tools/__init__.py
Normal file
@@ -0,0 +1,301 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.api_entities import ToolApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolParameter
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
class TestToolTransformService:
    """Test cases for ToolTransformService.convert_tool_entity_to_api_entity method."""

    @staticmethod
    def _make_param(name: str, label: str, form=ToolParameter.ToolParameterForm.FORM) -> Mock:
        """Build a mock ToolParameter carrying the fields the converter reads."""
        param = Mock(spec=ToolParameter)
        param.name = name
        param.form = form
        param.type = "string"
        param.label = label
        return param

    @staticmethod
    def _make_tool(base_params, runtime_params) -> Mock:
        """Build a mock Tool with fixed identity/description and the given parameters."""
        tool = Mock(spec=Tool)
        tool.entity = Mock()
        tool.entity.parameters = base_params
        tool.entity.identity = Mock()
        tool.entity.identity.author = "test_author"
        tool.entity.identity.name = "test_tool"
        tool.entity.identity.label = I18nObject(en_US="Test Tool")
        tool.entity.description = Mock()
        tool.entity.description.human = I18nObject(en_US="Test description")
        tool.entity.output_schema = {}
        tool.get_runtime_parameters.return_value = runtime_params
        # fork_tool_runtime returns the same mock so runtime parameters resolve.
        tool.fork_tool_runtime.return_value = tool
        return tool

    def test_convert_tool_with_parameter_override(self):
        """Runtime parameters correctly override base parameters."""
        base1 = self._make_param("param1", "Base Param 1")
        base2 = self._make_param("param2", "Base Param 2")
        runtime1 = self._make_param("param1", "Runtime Param 1")  # different label to verify override
        mock_tool = self._make_tool([base1, base2], [runtime1])

        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        assert isinstance(result, ToolApiEntity)
        assert result.author == "test_author"
        assert result.name == "test_tool"
        assert result.parameters is not None
        assert len(result.parameters) == 2

        # The overridden parameter must carry the runtime label.
        overridden = next((p for p in result.parameters if p.name == "param1"), None)
        assert overridden is not None
        assert overridden.label == "Runtime Param 1"

        # The non-overridden parameter keeps its base label.
        original = next((p for p in result.parameters if p.name == "param2"), None)
        assert original is not None
        assert original.label == "Base Param 2"

    def test_convert_tool_with_additional_runtime_parameters(self):
        """Additional runtime parameters are added to the final list."""
        base1 = self._make_param("param1", "Base Param 1")
        runtime1 = self._make_param("param1", "Runtime Param 1")
        runtime2 = self._make_param("runtime_only", "Runtime Only Param")
        mock_tool = self._make_tool([base1], [runtime1, runtime2])

        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        assert isinstance(result, ToolApiEntity)
        assert result.parameters is not None
        assert len(result.parameters) == 2

        # Both the overriding and the brand-new parameter are present.
        param_names = [p.name for p in result.parameters]
        assert "param1" in param_names
        assert "runtime_only" in param_names

        overridden = next((p for p in result.parameters if p.name == "param1"), None)
        assert overridden is not None
        assert overridden.label == "Runtime Param 1"

        new_param = next((p for p in result.parameters if p.name == "runtime_only"), None)
        assert new_param is not None
        assert new_param.label == "Runtime Only Param"

    def test_convert_tool_with_non_form_runtime_parameters(self):
        """Non-FORM runtime parameters are not added as new parameters."""
        base1 = self._make_param("param1", "Base Param 1")
        runtime1 = self._make_param("param1", "Runtime Param 1")
        llm_param = self._make_param("llm_param", "LLM Param", form=ToolParameter.ToolParameterForm.LLM)
        mock_tool = self._make_tool([base1], [runtime1, llm_param])

        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        assert isinstance(result, ToolApiEntity)
        assert result.parameters is not None
        assert len(result.parameters) == 1  # only the FORM parameter should be present

        param_names = [p.name for p in result.parameters]
        assert "param1" in param_names
        assert "llm_param" not in param_names

    def test_convert_tool_with_empty_parameters(self):
        """Conversion with empty base and runtime parameters yields no parameters."""
        mock_tool = self._make_tool([], [])

        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        assert isinstance(result, ToolApiEntity)
        assert result.parameters is not None
        assert len(result.parameters) == 0

    def test_convert_tool_with_none_parameters(self):
        """Conversion when base parameters is None yields no parameters."""
        mock_tool = self._make_tool(None, [])

        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        assert isinstance(result, ToolApiEntity)
        assert result.parameters is not None
        assert len(result.parameters) == 0

    def test_convert_tool_parameter_order_preserved(self):
        """Base order is kept, with new runtime parameters appended at the end."""
        bases = [self._make_param(f"param{i}", f"Base Param {i}") for i in (1, 2, 3)]
        runtime2 = self._make_param("param2", "Runtime Param 2")  # overrides the middle one
        runtime4 = self._make_param("param4", "Runtime Param 4")  # brand-new parameter
        mock_tool = self._make_tool(bases, [runtime2, runtime4])

        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        assert isinstance(result, ToolApiEntity)
        assert result.parameters is not None
        assert len(result.parameters) == 4

        # Base parameters first (in order), then new runtime parameters.
        assert [p.name for p in result.parameters] == ["param1", "param2", "param3", "param4"]

        # The middle parameter was replaced by its runtime version.
        param2 = result.parameters[1]
        assert param2.name == "param2"
        assert param2.label == "Runtime Param 2"
|
Reference in New Issue
Block a user