Fix: correctly match http/https URLs in image upload file (#24180)
@@ -30,7 +30,7 @@ from core.rag.splitter.fixed_text_splitter import (
     FixedRecursiveCharacterTextSplitter,
 )
 from core.rag.splitter.text_splitter import TextSplitter
-from core.tools.utils.rag_web_reader import get_image_upload_file_ids
+from core.tools.utils.web_reader_tool import get_image_upload_file_ids
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
@@ -1,17 +0,0 @@
-import re
-
-
-def get_image_upload_file_ids(content):
-    pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)"
-    matches = re.findall(pattern, content)
-    image_upload_file_ids = []
-    for match in matches:
-        if match[1] == "file-preview":
-            content_pattern = r"files/([^/]+)/file-preview"
-        else:
-            content_pattern = r"files/([^/]+)/image-preview"
-        content_match = re.search(content_pattern, match[0])
-        if content_match:
-            image_upload_file_id = content_match.group(1)
-            image_upload_file_ids.append(image_upload_file_id)
-    return image_upload_file_ids
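Aside, not part of the diff: a minimal sketch of the bug this PR removes. In the old pattern, `http?` means a literal `htt` followed by an optional `p`, so https upload links never matched; the URLs below are invented examples shaped like the real preview links.

```python
import re

# Old pattern from the deleted helper: "http?" = "htt" + optional "p",
# so it accepts http:// (and even htt://) but never https://.
old_pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)"

print(re.findall(old_pattern, "![image](https://example.com/files/abc123/file-preview)"))
# [] -> https upload links were silently dropped

print(re.findall(old_pattern, "![image](http://example.com/files/abc123/file-preview)"))
# [('http://example.com/files/abc123/file-preview', 'file-preview')]
```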
@@ -80,14 +80,14 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str:
     else:
         content = response.text

-    article = extract_using_readabilipy(content)
+    article = extract_using_readability(content)

     if not article.text:
         return ""

     res = FULL_TEMPLATE.format(
         title=article.title,
-        author=article.auther,
+        author=article.author,
         text=article.text,
     )

@@ -97,15 +97,15 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str:
 @dataclass
 class Article:
     title: str
-    auther: str
+    author: str
     text: Sequence[dict]


-def extract_using_readabilipy(html: str):
+def extract_using_readability(html: str):
     json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True)
     article = Article(
         title=json_article.get("title") or "",
-        auther=json_article.get("byline") or "",
+        author=json_article.get("byline") or "",
         text=json_article.get("plain_text") or [],
     )

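Aside, not part of the diff: a rough sketch of what the wrapped readabilipy call returns and how it maps onto the renamed `Article` fields. The HTML string is an invented example, and `use_readability=True` shells out to Node's readability, so Node.js must be available.

```python
# Sketch only: assumes the readabilipy package is installed.
from readabilipy import simple_json_from_html_string

html = "<html><head><title>Sample page</title></head><body><p>Some body text.</p></body></html>"
json_article = simple_json_from_html_string(html, use_readability=True)

# The keys used by extract_using_readability map onto the Article dataclass:
#   "title"      -> Article.title
#   "byline"     -> Article.author (may be None, hence the `or ""` fallback)
#   "plain_text" -> Article.text, a list of {"text": ...} dicts
print(json_article.get("title"), json_article.get("byline"), json_article.get("plain_text"))
```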
@@ -113,7 +113,7 @@ def extract_using_readabilipy(html: str):


 def get_image_upload_file_ids(content):
-    pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)"
+    pattern = r"!\[image\]\((https?://.*?(file-preview|image-preview))\)"
     matches = re.findall(pattern, content)
     image_upload_file_ids = []
     for match in matches:
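Aside, not part of the diff: the corrected pattern against both schemes, again on invented URLs. `re.findall` returns `(full_url, preview_kind)` tuples, and the loop that follows uses the second element to pick the `files/<id>/...-preview` extraction pattern.

```python
import re

new_pattern = r"!\[image\]\((https?://.*?(file-preview|image-preview))\)"

samples = [
    "![image](https://example.com/files/abc123/file-preview)",   # https now matches
    "![image](http://example.com/files/xyz789/image-preview)",   # http still matches
    "![image](htt://example.com/files/abc123/file-preview)",     # malformed scheme: no match
]
for sample in samples:
    print(re.findall(new_pattern, sample))
# [('https://example.com/files/abc123/file-preview', 'file-preview')]
# [('http://example.com/files/xyz789/image-preview', 'image-preview')]
# []
```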
@@ -5,7 +5,7 @@ import click
 from celery import shared_task  # type: ignore

 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from core.tools.utils.rag_web_reader import get_image_upload_file_ids
+from core.tools.utils.web_reader_tool import get_image_upload_file_ids
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 from models.dataset import (
@@ -6,7 +6,7 @@ import click
 from celery import shared_task  # type: ignore

 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from core.tools.utils.rag_web_reader import get_image_upload_file_ids
+from core.tools.utils.web_reader_tool import get_image_upload_file_ids
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment
@@ -0,0 +1,25 @@
+from core.tools.utils.web_reader_tool import get_image_upload_file_ids
+
+
+def test_get_image_upload_file_ids():
+    # should extract id from https + file-preview
+    content = "![image](https://example.com/files/abc123/file-preview)"
+    assert get_image_upload_file_ids(content) == ["abc123"]
+
+    # should extract id from http + image-preview
+    content = "![image](http://example.com/files/xyz789/image-preview)"
+    assert get_image_upload_file_ids(content) == ["xyz789"]
+
+    # should not match invalid scheme 'htt://'
+    content = "![image](htt://example.com/files/abc123/file-preview)"
+    assert get_image_upload_file_ids(content) == []
+
+    # should extract multiple ids in order
+    content = """
+    some text
+    ![image](https://example.com/files/id1/image-preview)
+    middle
+    ![image](http://example.com/files/id2/file-preview)
+    end
+    """
+    assert get_image_upload_file_ids(content) == ["id1", "id2"]