diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py index 4200a5170..fcdc91ec6 100644 --- a/api/controllers/console/datasets/website.py +++ b/api/controllers/console/datasets/website.py @@ -4,7 +4,7 @@ from controllers.console import api from controllers.console.datasets.error import WebsiteCrawlError from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required -from services.website_service import WebsiteService +from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService class WebsiteCrawlApi(Resource): @@ -24,10 +24,16 @@ class WebsiteCrawlApi(Resource): parser.add_argument("url", type=str, required=True, nullable=True, location="json") parser.add_argument("options", type=dict, required=True, nullable=True, location="json") args = parser.parse_args() - WebsiteService.document_create_args_validate(args) - # crawl url + + # Create typed request and validate try: - result = WebsiteService.crawl_url(args) + api_request = WebsiteCrawlApiRequest.from_args(args) + except ValueError as e: + raise WebsiteCrawlError(str(e)) + + # Crawl URL using typed request + try: + result = WebsiteService.crawl_url(api_request) except Exception as e: raise WebsiteCrawlError(str(e)) return result, 200 @@ -43,9 +49,16 @@ class WebsiteCrawlStatusApi(Resource): "provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args" ) args = parser.parse_args() - # get crawl status + + # Create typed request and validate try: - result = WebsiteService.get_crawl_status(job_id, args["provider"]) + api_request = WebsiteCrawlStatusApiRequest.from_args(args, job_id) + except ValueError as e: + raise WebsiteCrawlError(str(e)) + + # Get crawl status using typed request + try: + result = WebsiteService.get_crawl_status_typed(api_request) except Exception as e: raise WebsiteCrawlError(str(e)) return result, 200 diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index 744fce1cf..1e40997a8 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -21,7 +21,7 @@ def encrypt_token(tenant_id: str, token: str): return base64.b64encode(encrypted_token).decode() -def decrypt_token(tenant_id: str, token: str): +def decrypt_token(tenant_id: str, token: str) -> str: return rsa.decrypt(base64.b64decode(token), tenant_id) diff --git a/api/libs/rsa.py b/api/libs/rsa.py index 637bcc4a1..da279eb32 100644 --- a/api/libs/rsa.py +++ b/api/libs/rsa.py @@ -1,4 +1,5 @@ import hashlib +from typing import Union from Crypto.Cipher import AES from Crypto.PublicKey import RSA @@ -9,7 +10,7 @@ from extensions.ext_storage import storage from libs import gmpy2_pkcs10aep_cipher -def generate_key_pair(tenant_id): +def generate_key_pair(tenant_id: str) -> str: private_key = RSA.generate(2048) public_key = private_key.publickey() @@ -26,7 +27,7 @@ def generate_key_pair(tenant_id): prefix_hybrid = b"HYBRID:" -def encrypt(text, public_key): +def encrypt(text: str, public_key: Union[str, bytes]) -> bytes: if isinstance(public_key, str): public_key = public_key.encode() @@ -38,14 +39,14 @@ def encrypt(text, public_key): rsa_key = RSA.import_key(public_key) cipher_rsa = gmpy2_pkcs10aep_cipher.new(rsa_key) - enc_aes_key = cipher_rsa.encrypt(aes_key) + enc_aes_key: bytes = cipher_rsa.encrypt(aes_key) encrypted_data = enc_aes_key + cipher_aes.nonce + tag + ciphertext return prefix_hybrid + encrypted_data -def get_decrypt_decoding(tenant_id): +def get_decrypt_decoding(tenant_id: str) -> tuple[RSA.RsaKey, object]: filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem" cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest()) @@ -64,7 +65,7 @@ def get_decrypt_decoding(tenant_id): return rsa_key, cipher_rsa -def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa): +def decrypt_token_with_decoding(encrypted_text: bytes, rsa_key: RSA.RsaKey, cipher_rsa) -> str: if encrypted_text.startswith(prefix_hybrid): encrypted_text = encrypted_text[len(prefix_hybrid) :] @@ -83,10 +84,10 @@ def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa): return decrypted_text.decode() -def decrypt(encrypted_text, tenant_id): +def decrypt(encrypted_text: bytes, tenant_id: str) -> str: rsa_key, cipher_rsa = get_decrypt_decoding(tenant_id) - return decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa) + return decrypt_token_with_decoding(encrypted_text=encrypted_text, rsa_key=rsa_key, cipher_rsa=cipher_rsa) class PrivkeyNotFoundError(Exception): diff --git a/api/models/account.py b/api/models/account.py index 7ffeefa98..1af571bc0 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -196,7 +196,7 @@ class Tenant(Base): __tablename__ = "tenants" __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) name = db.Column(db.String(255), nullable=False) encrypt_public_key = db.Column(db.Text) plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying")) diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 36b892e20..2d192e6f7 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -334,21 +334,33 @@ class ToolTransformService: ) # get tool parameters - parameters = tool.entity.parameters or [] + base_parameters = tool.entity.parameters or [] # get tool runtime parameters runtime_parameters = tool.get_runtime_parameters() - # override parameters - current_parameters = parameters.copy() - for runtime_parameter in runtime_parameters: - found = False - for index, parameter in enumerate(current_parameters): - if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form: - current_parameters[index] = runtime_parameter - found = True - break - if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: - current_parameters.append(runtime_parameter) + # merge parameters using a functional approach to avoid type issues + merged_parameters: list[ToolParameter] = [] + + # create a mapping of runtime parameters for quick lookup + runtime_param_map = {(rp.name, rp.form): rp for rp in runtime_parameters} + + # process base parameters, replacing with runtime versions if they exist + for base_param in base_parameters: + key = (base_param.name, base_param.form) + if key in runtime_param_map: + merged_parameters.append(runtime_param_map[key]) + else: + merged_parameters.append(base_param) + + # add any runtime parameters that weren't in base parameters + for runtime_parameter in runtime_parameters: + if runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: + # check if this parameter is already in merged_parameters + already_exists = any( + p.name == runtime_parameter.name and p.form == runtime_parameter.form for p in merged_parameters + ) + if not already_exists: + merged_parameters.append(runtime_parameter) return ToolApiEntity( author=tool.entity.identity.author, @@ -356,10 +368,10 @@ class ToolTransformService: label=tool.entity.identity.label, description=tool.entity.description.human if tool.entity.description else I18nObject(en_US=""), output_schema=tool.entity.output_schema, - parameters=current_parameters, + parameters=merged_parameters, labels=labels or [], ) - if isinstance(tool, ApiToolBundle): + elif isinstance(tool, ApiToolBundle): return ToolApiEntity( author=tool.author, name=tool.operation_id or "", @@ -368,6 +380,9 @@ class ToolTransformService: parameters=tool.parameters, labels=labels or [], ) + else: + # Handle WorkflowTool case + raise ValueError(f"Unsupported tool type: {type(tool)}") @staticmethod def convert_builtin_provider_to_credential_entity( diff --git a/api/services/website_service.py b/api/services/website_service.py index 6720932a3..991b66973 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -1,6 +1,7 @@ import datetime import json -from typing import Any +from dataclasses import dataclass +from typing import Any, Optional import requests from flask_login import current_user @@ -13,241 +14,392 @@ from extensions.ext_storage import storage from services.auth.api_key_auth_service import ApiKeyAuthService -class WebsiteService: - @classmethod - def document_create_args_validate(cls, args: dict): - if "url" not in args or not args["url"]: - raise ValueError("url is required") - if "options" not in args or not args["options"]: - raise ValueError("options is required") - if "limit" not in args["options"] or not args["options"]["limit"]: - raise ValueError("limit is required") +@dataclass +class CrawlOptions: + """Options for crawling operations.""" + + limit: int = 1 + crawl_sub_pages: bool = False + only_main_content: bool = False + includes: Optional[str] = None + excludes: Optional[str] = None + max_depth: Optional[int] = None + use_sitemap: bool = True + + def get_include_paths(self) -> list[str]: + """Get list of include paths from comma-separated string.""" + return self.includes.split(",") if self.includes else [] + + def get_exclude_paths(self) -> list[str]: + """Get list of exclude paths from comma-separated string.""" + return self.excludes.split(",") if self.excludes else [] + + +@dataclass +class CrawlRequest: + """Request container for crawling operations.""" + + url: str + provider: str + options: CrawlOptions + + +@dataclass +class ScrapeRequest: + """Request container for scraping operations.""" + + provider: str + url: str + tenant_id: str + only_main_content: bool + + +@dataclass +class WebsiteCrawlApiRequest: + """Request container for website crawl API arguments.""" + + provider: str + url: str + options: dict[str, Any] + + def to_crawl_request(self) -> CrawlRequest: + """Convert API request to internal CrawlRequest.""" + options = CrawlOptions( + limit=self.options.get("limit", 1), + crawl_sub_pages=self.options.get("crawl_sub_pages", False), + only_main_content=self.options.get("only_main_content", False), + includes=self.options.get("includes"), + excludes=self.options.get("excludes"), + max_depth=self.options.get("max_depth"), + use_sitemap=self.options.get("use_sitemap", True), + ) + return CrawlRequest(url=self.url, provider=self.provider, options=options) @classmethod - def crawl_url(cls, args: dict) -> dict: - provider = args.get("provider", "") + def from_args(cls, args: dict) -> "WebsiteCrawlApiRequest": + """Create from Flask-RESTful parsed arguments.""" + provider = args.get("provider") url = args.get("url") - options = args.get("options", "") - credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) - if provider == "firecrawl": - # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") - ) - firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) - crawl_sub_pages = options.get("crawl_sub_pages", False) - only_main_content = options.get("only_main_content", False) - if not crawl_sub_pages: - params = { - "includePaths": [], - "excludePaths": [], - "limit": 1, - "scrapeOptions": {"onlyMainContent": only_main_content}, - } - else: - includes = options.get("includes").split(",") if options.get("includes") else [] - excludes = options.get("excludes").split(",") if options.get("excludes") else [] - params = { - "includePaths": includes, - "excludePaths": excludes, - "limit": options.get("limit", 1), - "scrapeOptions": {"onlyMainContent": only_main_content}, - } - if options.get("max_depth"): - params["maxDepth"] = options.get("max_depth") - job_id = firecrawl_app.crawl_url(url, params) - website_crawl_time_cache_key = f"website_crawl_{job_id}" - time = str(datetime.datetime.now().timestamp()) - redis_client.setex(website_crawl_time_cache_key, 3600, time) - return {"status": "active", "job_id": job_id} - elif provider == "watercrawl": - # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") - ) - return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).crawl_url(url, options) + options = args.get("options", {}) - elif provider == "jinareader": - api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") - ) - crawl_sub_pages = options.get("crawl_sub_pages", False) - if not crawl_sub_pages: - response = requests.get( - f"https://r.jina.ai/{url}", - headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, - ) - if response.json().get("code") != 200: - raise ValueError("Failed to crawl") - return {"status": "active", "data": response.json().get("data")} - else: - response = requests.post( - "https://adaptivecrawl-kir3wx7b3a-uc.a.run.app", - json={ - "url": url, - "maxPages": options.get("limit", 1), - "useSitemap": options.get("use_sitemap", True), - }, - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}", - }, - ) - if response.json().get("code") != 200: - raise ValueError("Failed to crawl") - return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")} + if not provider: + raise ValueError("Provider is required") + if not url: + raise ValueError("URL is required") + if not options: + raise ValueError("Options are required") + + return cls(provider=provider, url=url, options=options) + + +@dataclass +class WebsiteCrawlStatusApiRequest: + """Request container for website crawl status API arguments.""" + + provider: str + job_id: str + + @classmethod + def from_args(cls, args: dict, job_id: str) -> "WebsiteCrawlStatusApiRequest": + """Create from Flask-RESTful parsed arguments.""" + provider = args.get("provider") + + if not provider: + raise ValueError("Provider is required") + if not job_id: + raise ValueError("Job ID is required") + + return cls(provider=provider, job_id=job_id) + + +class WebsiteService: + """Service class for website crawling operations using different providers.""" + + @classmethod + def _get_credentials_and_config(cls, tenant_id: str, provider: str) -> tuple[dict, dict]: + """Get and validate credentials for a provider.""" + credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) + if not credentials or "config" not in credentials: + raise ValueError("No valid credentials found for the provider") + return credentials, credentials["config"] + + @classmethod + def _get_decrypted_api_key(cls, tenant_id: str, config: dict) -> str: + """Decrypt and return the API key from config.""" + api_key = config.get("api_key") + if not api_key: + raise ValueError("API key not found in configuration") + return encrypter.decrypt_token(tenant_id=tenant_id, token=api_key) + + @classmethod + def document_create_args_validate(cls, args: dict) -> None: + """Validate arguments for document creation.""" + try: + WebsiteCrawlApiRequest.from_args(args) + except ValueError as e: + raise ValueError(f"Invalid arguments: {e}") + + @classmethod + def crawl_url(cls, api_request: WebsiteCrawlApiRequest) -> dict[str, Any]: + """Crawl a URL using the specified provider with typed request.""" + request = api_request.to_crawl_request() + + _, config = cls._get_credentials_and_config(current_user.current_tenant_id, request.provider) + api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config) + + if request.provider == "firecrawl": + return cls._crawl_with_firecrawl(request=request, api_key=api_key, config=config) + elif request.provider == "watercrawl": + return cls._crawl_with_watercrawl(request=request, api_key=api_key, config=config) + elif request.provider == "jinareader": + return cls._crawl_with_jinareader(request=request, api_key=api_key) else: raise ValueError("Invalid provider") @classmethod - def get_crawl_status(cls, job_id: str, provider: str) -> dict: - credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) - if provider == "firecrawl": - # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") - ) - firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) - result = firecrawl_app.check_crawl_status(job_id) - crawl_status_data = { - "status": result.get("status", "active"), - "job_id": job_id, - "total": result.get("total", 0), - "current": result.get("current", 0), - "data": result.get("data", []), + def _crawl_with_firecrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]: + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) + + if not request.options.crawl_sub_pages: + params = { + "includePaths": [], + "excludePaths": [], + "limit": 1, + "scrapeOptions": {"onlyMainContent": request.options.only_main_content}, } - if crawl_status_data["status"] == "completed": - website_crawl_time_cache_key = f"website_crawl_{job_id}" - start_time = redis_client.get(website_crawl_time_cache_key) - if start_time: - end_time = datetime.datetime.now().timestamp() - time_consuming = abs(end_time - float(start_time)) - crawl_status_data["time_consuming"] = f"{time_consuming:.2f}" - redis_client.delete(website_crawl_time_cache_key) - elif provider == "watercrawl": - # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") + else: + params = { + "includePaths": request.options.get_include_paths(), + "excludePaths": request.options.get_exclude_paths(), + "limit": request.options.limit, + "scrapeOptions": {"onlyMainContent": request.options.only_main_content}, + } + if request.options.max_depth: + params["maxDepth"] = request.options.max_depth + + job_id = firecrawl_app.crawl_url(request.url, params) + website_crawl_time_cache_key = f"website_crawl_{job_id}" + time = str(datetime.datetime.now().timestamp()) + redis_client.setex(website_crawl_time_cache_key, 3600, time) + return {"status": "active", "job_id": job_id} + + @classmethod + def _crawl_with_watercrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]: + # Convert CrawlOptions back to dict format for WaterCrawlProvider + options = { + "limit": request.options.limit, + "crawl_sub_pages": request.options.crawl_sub_pages, + "only_main_content": request.options.only_main_content, + "includes": request.options.includes, + "excludes": request.options.excludes, + "max_depth": request.options.max_depth, + "use_sitemap": request.options.use_sitemap, + } + return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).crawl_url( + url=request.url, options=options + ) + + @classmethod + def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]: + if not request.options.crawl_sub_pages: + response = requests.get( + f"https://r.jina.ai/{request.url}", + headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, ) - crawl_status_data = WaterCrawlProvider( - api_key, credentials.get("config").get("base_url", None) - ).get_crawl_status(job_id) - elif provider == "jinareader": - api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") + if response.json().get("code") != 200: + raise ValueError("Failed to crawl") + return {"status": "active", "data": response.json().get("data")} + else: + response = requests.post( + "https://adaptivecrawl-kir3wx7b3a-uc.a.run.app", + json={ + "url": request.url, + "maxPages": request.options.limit, + "useSitemap": request.options.use_sitemap, + }, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + }, ) + if response.json().get("code") != 200: + raise ValueError("Failed to crawl") + return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")} + + @classmethod + def get_crawl_status(cls, job_id: str, provider: str) -> dict[str, Any]: + """Get crawl status using string parameters.""" + api_request = WebsiteCrawlStatusApiRequest(provider=provider, job_id=job_id) + return cls.get_crawl_status_typed(api_request) + + @classmethod + def get_crawl_status_typed(cls, api_request: WebsiteCrawlStatusApiRequest) -> dict[str, Any]: + """Get crawl status using typed request.""" + _, config = cls._get_credentials_and_config(current_user.current_tenant_id, api_request.provider) + api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config) + + if api_request.provider == "firecrawl": + return cls._get_firecrawl_status(api_request.job_id, api_key, config) + elif api_request.provider == "watercrawl": + return cls._get_watercrawl_status(api_request.job_id, api_key, config) + elif api_request.provider == "jinareader": + return cls._get_jinareader_status(api_request.job_id, api_key) + else: + raise ValueError("Invalid provider") + + @classmethod + def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]: + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) + result = firecrawl_app.check_crawl_status(job_id) + crawl_status_data = { + "status": result.get("status", "active"), + "job_id": job_id, + "total": result.get("total", 0), + "current": result.get("current", 0), + "data": result.get("data", []), + } + if crawl_status_data["status"] == "completed": + website_crawl_time_cache_key = f"website_crawl_{job_id}" + start_time = redis_client.get(website_crawl_time_cache_key) + if start_time: + end_time = datetime.datetime.now().timestamp() + time_consuming = abs(end_time - float(start_time)) + crawl_status_data["time_consuming"] = f"{time_consuming:.2f}" + redis_client.delete(website_crawl_time_cache_key) + return crawl_status_data + + @classmethod + def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]: + return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id) + + @classmethod + def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]: + response = requests.post( + "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", + headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, + json={"taskId": job_id}, + ) + data = response.json().get("data", {}) + crawl_status_data = { + "status": data.get("status", "active"), + "job_id": job_id, + "total": len(data.get("urls", [])), + "current": len(data.get("processed", [])) + len(data.get("failed", [])), + "data": [], + "time_consuming": data.get("duration", 0) / 1000, + } + + if crawl_status_data["status"] == "completed": response = requests.post( "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, - json={"taskId": job_id}, + json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())}, ) data = response.json().get("data", {}) - crawl_status_data = { - "status": data.get("status", "active"), - "job_id": job_id, - "total": len(data.get("urls", [])), - "current": len(data.get("processed", [])) + len(data.get("failed", [])), - "data": [], - "time_consuming": data.get("duration", 0) / 1000, - } - - if crawl_status_data["status"] == "completed": - response = requests.post( - "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", - headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, - json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())}, - ) - data = response.json().get("data", {}) - formatted_data = [ - { - "title": item.get("data", {}).get("title"), - "source_url": item.get("data", {}).get("url"), - "description": item.get("data", {}).get("description"), - "markdown": item.get("data", {}).get("content"), - } - for item in data.get("processed", {}).values() - ] - crawl_status_data["data"] = formatted_data - else: - raise ValueError("Invalid provider") + formatted_data = [ + { + "title": item.get("data", {}).get("title"), + "source_url": item.get("data", {}).get("url"), + "description": item.get("data", {}).get("description"), + "markdown": item.get("data", {}).get("content"), + } + for item in data.get("processed", {}).values() + ] + crawl_status_data["data"] = formatted_data return crawl_status_data @classmethod def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[str, Any] | None: - credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) - # decrypt api_key - api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) + _, config = cls._get_credentials_and_config(tenant_id, provider) + api_key = cls._get_decrypted_api_key(tenant_id, config) if provider == "firecrawl": - crawl_data: list[dict[str, Any]] | None = None - file_key = "website_files/" + job_id + ".txt" - if storage.exists(file_key): - stored_data = storage.load_once(file_key) - if stored_data: - crawl_data = json.loads(stored_data.decode("utf-8")) - else: - firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) - result = firecrawl_app.check_crawl_status(job_id) - if result.get("status") != "completed": - raise ValueError("Crawl job is not completed") - crawl_data = result.get("data") - - if crawl_data: - for item in crawl_data: - if item.get("source_url") == url: - return dict(item) - return None + return cls._get_firecrawl_url_data(job_id, url, api_key, config) elif provider == "watercrawl": - api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) - return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).get_crawl_url_data( - job_id, url - ) + return cls._get_watercrawl_url_data(job_id, url, api_key, config) elif provider == "jinareader": - if not job_id: - response = requests.get( - f"https://r.jina.ai/{url}", - headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, - ) - if response.json().get("code") != 200: - raise ValueError("Failed to crawl") - return dict(response.json().get("data", {})) - else: - # Get crawl status first - status_response = requests.post( - "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", - headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, - json={"taskId": job_id}, - ) - status_data = status_response.json().get("data", {}) - if status_data.get("status") != "completed": - raise ValueError("Crawl job is not completed") - - # Get processed data - data_response = requests.post( - "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", - headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, - json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())}, - ) - processed_data = data_response.json().get("data", {}) - for item in processed_data.get("processed", {}).values(): - if item.get("data", {}).get("url") == url: - return dict(item.get("data", {})) - return None + return cls._get_jinareader_url_data(job_id, url, api_key) else: raise ValueError("Invalid provider") @classmethod - def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict: - credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) - if provider == "firecrawl": - # decrypt api_key - api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) - firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) - params = {"onlyMainContent": only_main_content} - result = firecrawl_app.scrape_url(url, params) - return result - elif provider == "watercrawl": - api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) - return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).scrape_url(url) + def _get_firecrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None: + crawl_data: list[dict[str, Any]] | None = None + file_key = "website_files/" + job_id + ".txt" + if storage.exists(file_key): + stored_data = storage.load_once(file_key) + if stored_data: + crawl_data = json.loads(stored_data.decode("utf-8")) + else: + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) + result = firecrawl_app.check_crawl_status(job_id) + if result.get("status") != "completed": + raise ValueError("Crawl job is not completed") + crawl_data = result.get("data") + + if crawl_data: + for item in crawl_data: + if item.get("source_url") == url: + return dict(item) + return None + + @classmethod + def _get_watercrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None: + return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_url_data(job_id, url) + + @classmethod + def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None: + if not job_id: + response = requests.get( + f"https://r.jina.ai/{url}", + headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, + ) + if response.json().get("code") != 200: + raise ValueError("Failed to crawl") + return dict(response.json().get("data", {})) + else: + # Get crawl status first + status_response = requests.post( + "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", + headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, + json={"taskId": job_id}, + ) + status_data = status_response.json().get("data", {}) + if status_data.get("status") != "completed": + raise ValueError("Crawl job is not completed") + + # Get processed data + data_response = requests.post( + "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", + headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, + json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())}, + ) + processed_data = data_response.json().get("data", {}) + for item in processed_data.get("processed", {}).values(): + if item.get("data", {}).get("url") == url: + return dict(item.get("data", {})) + return None + + @classmethod + def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict[str, Any]: + request = ScrapeRequest(provider=provider, url=url, tenant_id=tenant_id, only_main_content=only_main_content) + + _, config = cls._get_credentials_and_config(tenant_id=request.tenant_id, provider=request.provider) + api_key = cls._get_decrypted_api_key(tenant_id=request.tenant_id, config=config) + + if request.provider == "firecrawl": + return cls._scrape_with_firecrawl(request=request, api_key=api_key, config=config) + elif request.provider == "watercrawl": + return cls._scrape_with_watercrawl(request=request, api_key=api_key, config=config) else: raise ValueError("Invalid provider") + + @classmethod + def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]: + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) + params = {"onlyMainContent": request.only_main_content} + return firecrawl_app.scrape_url(url=request.url, params=params) + + @classmethod + def _scrape_with_watercrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]: + return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).scrape_url(request.url) diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 125e0c1b1..bb35645c5 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -31,7 +31,7 @@ class WorkspaceService: assert tenant_account_join is not None, "TenantAccountJoin not found" tenant_info["role"] = tenant_account_join.role - can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo + can_replace_logo = FeatureService.get_features(tenant.id).can_replace_logo if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]): base_url = dify_config.FILES_URL diff --git a/api/tests/unit_tests/services/tools/__init__.py b/api/tests/unit_tests/services/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/tests/unit_tests/services/tools/test_tools_transform_service.py b/api/tests/unit_tests/services/tools/test_tools_transform_service.py new file mode 100644 index 000000000..549ad018e --- /dev/null +++ b/api/tests/unit_tests/services/tools/test_tools_transform_service.py @@ -0,0 +1,301 @@ +from unittest.mock import Mock + +from core.tools.__base.tool import Tool +from core.tools.entities.api_entities import ToolApiEntity +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolParameter +from services.tools.tools_transform_service import ToolTransformService + + +class TestToolTransformService: + """Test cases for ToolTransformService.convert_tool_entity_to_api_entity method""" + + def test_convert_tool_with_parameter_override(self): + """Test that runtime parameters correctly override base parameters""" + # Create mock base parameters + base_param1 = Mock(spec=ToolParameter) + base_param1.name = "param1" + base_param1.form = ToolParameter.ToolParameterForm.FORM + base_param1.type = "string" + base_param1.label = "Base Param 1" + + base_param2 = Mock(spec=ToolParameter) + base_param2.name = "param2" + base_param2.form = ToolParameter.ToolParameterForm.FORM + base_param2.type = "string" + base_param2.label = "Base Param 2" + + # Create mock runtime parameters that override base parameters + runtime_param1 = Mock(spec=ToolParameter) + runtime_param1.name = "param1" + runtime_param1.form = ToolParameter.ToolParameterForm.FORM + runtime_param1.type = "string" + runtime_param1.label = "Runtime Param 1" # Different label to verify override + + # Create mock tool + mock_tool = Mock(spec=Tool) + mock_tool.entity = Mock() + mock_tool.entity.parameters = [base_param1, base_param2] + mock_tool.entity.identity = Mock() + mock_tool.entity.identity.author = "test_author" + mock_tool.entity.identity.name = "test_tool" + mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") + mock_tool.entity.description = Mock() + mock_tool.entity.description.human = I18nObject(en_US="Test description") + mock_tool.entity.output_schema = {} + mock_tool.get_runtime_parameters.return_value = [runtime_param1] + + # Mock fork_tool_runtime to return the same tool + mock_tool.fork_tool_runtime.return_value = mock_tool + + # Call the method + result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) + + # Verify the result + assert isinstance(result, ToolApiEntity) + assert result.author == "test_author" + assert result.name == "test_tool" + assert result.parameters is not None + assert len(result.parameters) == 2 + + # Find the overridden parameter + overridden_param = next((p for p in result.parameters if p.name == "param1"), None) + assert overridden_param is not None + assert overridden_param.label == "Runtime Param 1" # Should be runtime version + + # Find the non-overridden parameter + original_param = next((p for p in result.parameters if p.name == "param2"), None) + assert original_param is not None + assert original_param.label == "Base Param 2" # Should be base version + + def test_convert_tool_with_additional_runtime_parameters(self): + """Test that additional runtime parameters are added to the final list""" + # Create mock base parameters + base_param1 = Mock(spec=ToolParameter) + base_param1.name = "param1" + base_param1.form = ToolParameter.ToolParameterForm.FORM + base_param1.type = "string" + base_param1.label = "Base Param 1" + + # Create mock runtime parameters - one that overrides and one that's new + runtime_param1 = Mock(spec=ToolParameter) + runtime_param1.name = "param1" + runtime_param1.form = ToolParameter.ToolParameterForm.FORM + runtime_param1.type = "string" + runtime_param1.label = "Runtime Param 1" + + runtime_param2 = Mock(spec=ToolParameter) + runtime_param2.name = "runtime_only" + runtime_param2.form = ToolParameter.ToolParameterForm.FORM + runtime_param2.type = "string" + runtime_param2.label = "Runtime Only Param" + + # Create mock tool + mock_tool = Mock(spec=Tool) + mock_tool.entity = Mock() + mock_tool.entity.parameters = [base_param1] + mock_tool.entity.identity = Mock() + mock_tool.entity.identity.author = "test_author" + mock_tool.entity.identity.name = "test_tool" + mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") + mock_tool.entity.description = Mock() + mock_tool.entity.description.human = I18nObject(en_US="Test description") + mock_tool.entity.output_schema = {} + mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2] + + # Mock fork_tool_runtime to return the same tool + mock_tool.fork_tool_runtime.return_value = mock_tool + + # Call the method + result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) + + # Verify the result + assert isinstance(result, ToolApiEntity) + assert result.parameters is not None + assert len(result.parameters) == 2 + + # Check that both parameters are present + param_names = [p.name for p in result.parameters] + assert "param1" in param_names + assert "runtime_only" in param_names + + # Verify the overridden parameter has runtime version + overridden_param = next((p for p in result.parameters if p.name == "param1"), None) + assert overridden_param is not None + assert overridden_param.label == "Runtime Param 1" + + # Verify the new runtime parameter is included + new_param = next((p for p in result.parameters if p.name == "runtime_only"), None) + assert new_param is not None + assert new_param.label == "Runtime Only Param" + + def test_convert_tool_with_non_form_runtime_parameters(self): + """Test that non-FORM runtime parameters are not added as new parameters""" + # Create mock base parameters + base_param1 = Mock(spec=ToolParameter) + base_param1.name = "param1" + base_param1.form = ToolParameter.ToolParameterForm.FORM + base_param1.type = "string" + base_param1.label = "Base Param 1" + + # Create mock runtime parameters with different forms + runtime_param1 = Mock(spec=ToolParameter) + runtime_param1.name = "param1" + runtime_param1.form = ToolParameter.ToolParameterForm.FORM + runtime_param1.type = "string" + runtime_param1.label = "Runtime Param 1" + + runtime_param2 = Mock(spec=ToolParameter) + runtime_param2.name = "llm_param" + runtime_param2.form = ToolParameter.ToolParameterForm.LLM + runtime_param2.type = "string" + runtime_param2.label = "LLM Param" + + # Create mock tool + mock_tool = Mock(spec=Tool) + mock_tool.entity = Mock() + mock_tool.entity.parameters = [base_param1] + mock_tool.entity.identity = Mock() + mock_tool.entity.identity.author = "test_author" + mock_tool.entity.identity.name = "test_tool" + mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") + mock_tool.entity.description = Mock() + mock_tool.entity.description.human = I18nObject(en_US="Test description") + mock_tool.entity.output_schema = {} + mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2] + + # Mock fork_tool_runtime to return the same tool + mock_tool.fork_tool_runtime.return_value = mock_tool + + # Call the method + result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) + + # Verify the result + assert isinstance(result, ToolApiEntity) + assert result.parameters is not None + assert len(result.parameters) == 1 # Only the FORM parameter should be present + + # Check that only the FORM parameter is present + param_names = [p.name for p in result.parameters] + assert "param1" in param_names + assert "llm_param" not in param_names + + def test_convert_tool_with_empty_parameters(self): + """Test conversion with empty base and runtime parameters""" + # Create mock tool with no parameters + mock_tool = Mock(spec=Tool) + mock_tool.entity = Mock() + mock_tool.entity.parameters = [] + mock_tool.entity.identity = Mock() + mock_tool.entity.identity.author = "test_author" + mock_tool.entity.identity.name = "test_tool" + mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") + mock_tool.entity.description = Mock() + mock_tool.entity.description.human = I18nObject(en_US="Test description") + mock_tool.entity.output_schema = {} + mock_tool.get_runtime_parameters.return_value = [] + + # Mock fork_tool_runtime to return the same tool + mock_tool.fork_tool_runtime.return_value = mock_tool + + # Call the method + result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) + + # Verify the result + assert isinstance(result, ToolApiEntity) + assert result.parameters is not None + assert len(result.parameters) == 0 + + def test_convert_tool_with_none_parameters(self): + """Test conversion when base parameters is None""" + # Create mock tool with None parameters + mock_tool = Mock(spec=Tool) + mock_tool.entity = Mock() + mock_tool.entity.parameters = None + mock_tool.entity.identity = Mock() + mock_tool.entity.identity.author = "test_author" + mock_tool.entity.identity.name = "test_tool" + mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") + mock_tool.entity.description = Mock() + mock_tool.entity.description.human = I18nObject(en_US="Test description") + mock_tool.entity.output_schema = {} + mock_tool.get_runtime_parameters.return_value = [] + + # Mock fork_tool_runtime to return the same tool + mock_tool.fork_tool_runtime.return_value = mock_tool + + # Call the method + result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) + + # Verify the result + assert isinstance(result, ToolApiEntity) + assert result.parameters is not None + assert len(result.parameters) == 0 + + def test_convert_tool_parameter_order_preserved(self): + """Test that parameter order is preserved correctly""" + # Create mock base parameters in specific order + base_param1 = Mock(spec=ToolParameter) + base_param1.name = "param1" + base_param1.form = ToolParameter.ToolParameterForm.FORM + base_param1.type = "string" + base_param1.label = "Base Param 1" + + base_param2 = Mock(spec=ToolParameter) + base_param2.name = "param2" + base_param2.form = ToolParameter.ToolParameterForm.FORM + base_param2.type = "string" + base_param2.label = "Base Param 2" + + base_param3 = Mock(spec=ToolParameter) + base_param3.name = "param3" + base_param3.form = ToolParameter.ToolParameterForm.FORM + base_param3.type = "string" + base_param3.label = "Base Param 3" + + # Create runtime parameter that overrides middle parameter + runtime_param2 = Mock(spec=ToolParameter) + runtime_param2.name = "param2" + runtime_param2.form = ToolParameter.ToolParameterForm.FORM + runtime_param2.type = "string" + runtime_param2.label = "Runtime Param 2" + + # Create new runtime parameter + runtime_param4 = Mock(spec=ToolParameter) + runtime_param4.name = "param4" + runtime_param4.form = ToolParameter.ToolParameterForm.FORM + runtime_param4.type = "string" + runtime_param4.label = "Runtime Param 4" + + # Create mock tool + mock_tool = Mock(spec=Tool) + mock_tool.entity = Mock() + mock_tool.entity.parameters = [base_param1, base_param2, base_param3] + mock_tool.entity.identity = Mock() + mock_tool.entity.identity.author = "test_author" + mock_tool.entity.identity.name = "test_tool" + mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") + mock_tool.entity.description = Mock() + mock_tool.entity.description.human = I18nObject(en_US="Test description") + mock_tool.entity.output_schema = {} + mock_tool.get_runtime_parameters.return_value = [runtime_param2, runtime_param4] + + # Mock fork_tool_runtime to return the same tool + mock_tool.fork_tool_runtime.return_value = mock_tool + + # Call the method + result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) + + # Verify the result + assert isinstance(result, ToolApiEntity) + assert result.parameters is not None + assert len(result.parameters) == 4 + + # Check that order is maintained: base parameters first, then new runtime parameters + param_names = [p.name for p in result.parameters] + assert param_names == ["param1", "param2", "param3", "param4"] + + # Verify that param2 was overridden with runtime version + param2 = result.parameters[1] + assert param2.name == "param2" + assert param2.label == "Runtime Param 2"