Add APIs for Knowledge Base Tag Management and Dataset Binding (#20023)

Co-authored-by: lizb <lizb@sugon.com>
This commit is contained in:
Ganondorf
2025-05-30 14:48:00 +08:00
committed by GitHub
parent 1ea4459d9f
commit 51f64797cd
5 changed files with 1073 additions and 2 deletions

View File

@@ -1,19 +1,21 @@
from flask import request
from flask_restful import marshal, reqparse
from flask_restful import marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound
import services.dataset_service
from controllers.service_api import api
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError
from controllers.service_api.wraps import DatasetApiResource
from controllers.service_api.wraps import DatasetApiResource, validate_dataset_token
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import tag_fields
from libs.login import current_user
from models.dataset import Dataset, DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.tag_service import TagService
def _validate_name(name):
@@ -320,5 +322,134 @@ class DatasetApi(DatasetApiResource):
raise DatasetInUseError()
class DatasetTagsApi(DatasetApiResource):
@validate_dataset_token
@marshal_with(tag_fields)
def get(self, _, dataset_id):
"""Get all knowledge type tags."""
tags = TagService.get_tags("knowledge", current_user.current_tenant_id)
return tags, 200
@validate_dataset_token
def post(self, _, dataset_id):
"""Add a knowledge type tag."""
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=DatasetTagsApi._validate_tag_name,
)
args = parser.parse_args()
args["type"] = "knowledge"
tag = TagService.save_tags(args)
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
return response, 200
@validate_dataset_token
def patch(self, _, dataset_id):
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=DatasetTagsApi._validate_tag_name,
)
parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
args = parser.parse_args()
tag = TagService.update_tags(args, args.get("tag_id"))
binding_count = TagService.get_tag_binding_count(args.get("tag_id"))
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
return response, 200
@validate_dataset_token
def delete(self, _, dataset_id):
"""Delete a knowledge type tag."""
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
args = parser.parse_args()
TagService.delete_tag(args.get("tag_id"))
return 204
@staticmethod
def _validate_tag_name(name):
if not name or len(name) < 1 or len(name) > 50:
raise ValueError("Name must be between 1 to 50 characters.")
return name
class DatasetTagBindingApi(DatasetApiResource):
@validate_dataset_token
def post(self, _, dataset_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument(
"tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
)
parser.add_argument(
"target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
)
args = parser.parse_args()
args["type"] = "knowledge"
TagService.save_tag_binding(args)
return 204
class DatasetTagUnbindingApi(DatasetApiResource):
@validate_dataset_token
def post(self, _, dataset_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
args = parser.parse_args()
args["type"] = "knowledge"
TagService.delete_tag_binding(args)
return 204
class DatasetTagsBindingStatusApi(DatasetApiResource):
@validate_dataset_token
def get(self, _, *args, **kwargs):
"""Get all knowledge type tags."""
dataset_id = kwargs.get("dataset_id")
tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
response = {"data": tags_list, "total": len(tags)}
return response, 200
api.add_resource(DatasetListApi, "/datasets")
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
api.add_resource(DatasetTagsApi, "/datasets/tags")
api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding")
api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding")
api.add_resource(DatasetTagsBindingStatusApi, "/datasets/<uuid:dataset_id>/tags")

View File

@@ -44,6 +44,17 @@ class TagService:
results = [tag_binding.target_id for tag_binding in tag_bindings]
return results
@staticmethod
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str) -> list:
tags = (
db.session.query(Tag)
.filter(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
.all()
)
if not tags:
return []
return tags
@staticmethod
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
tags = (
@@ -62,6 +73,8 @@ class TagService:
@staticmethod
def save_tags(args: dict) -> Tag:
if TagService.get_tag_by_tag_name(args["type"], current_user.current_tenant_id, args["name"]):
raise ValueError("Tag name already exists")
tag = Tag(
id=str(uuid.uuid4()),
name=args["name"],
@@ -75,6 +88,8 @@ class TagService:
@staticmethod
def update_tags(args: dict, tag_id: str) -> Tag:
if TagService.get_tag_by_tag_name(args["type"], current_user.current_tenant_id, args["name"]):
raise ValueError("Tag name already exists")
tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
if not tag:
raise NotFound("Tag not found")