fix dataset operator (#6064)

Co-authored-by: JzoNg <jzongcode@gmail.com>
Authored by Joe on 2024-07-09 17:47:54 +08:00, committed by GitHub
parent 3b14939d66
commit ce930f19b9
46 changed files with 1072 additions and 290 deletions


@@ -21,11 +21,12 @@ from events.document_event import document_was_deleted
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
from models.account import Account
from models.account import Account, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
Dataset,
DatasetCollectionBinding,
DatasetPermission,
DatasetProcessRule,
DatasetQuery,
Document,
@@ -56,22 +57,55 @@ class DatasetService:
@staticmethod
def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None, search=None, tag_ids=None):
query = Dataset.query.filter(Dataset.provider == provider, Dataset.tenant_id == tenant_id).order_by(
Dataset.created_at.desc()
)
if user:
permission_filter = db.or_(Dataset.created_by == user.id,
Dataset.permission == 'all_team_members')
# get permitted dataset ids
dataset_permission = DatasetPermission.query.filter_by(
account_id=user.id,
tenant_id=tenant_id
).all()
permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None
if user.current_role == TenantAccountRole.DATASET_OPERATOR:
# only show datasets that the user has permission to access
if permitted_dataset_ids:
query = query.filter(Dataset.id.in_(permitted_dataset_ids))
else:
return [], 0
else:
# show all datasets that the user has permission to access
if permitted_dataset_ids:
query = query.filter(
db.or_(
Dataset.permission == 'all_team_members',
db.and_(Dataset.permission == 'only_me', Dataset.created_by == user.id),
db.and_(Dataset.permission == 'partial_members', Dataset.id.in_(permitted_dataset_ids))
)
)
else:
query = query.filter(
db.or_(
Dataset.permission == 'all_team_members',
db.and_(Dataset.permission == 'only_me', Dataset.created_by == user.id)
)
)
else:
permission_filter = Dataset.permission == 'all_team_members'
query = Dataset.query.filter(
db.and_(Dataset.provider == provider, Dataset.tenant_id == tenant_id, permission_filter)) \
.order_by(Dataset.created_at.desc())
# if no user, only show datasets that are shared with all team members
query = query.filter(Dataset.permission == 'all_team_members')
if search:
query = query.filter(db.and_(Dataset.name.ilike(f'%{search}%')))
query = query.filter(Dataset.name.ilike(f'%{search}%'))
if tag_ids:
target_ids = TagService.get_target_ids_by_tag_ids('knowledge', tenant_id, tag_ids)
if target_ids:
query = query.filter(db.and_(Dataset.id.in_(target_ids)))
query = query.filter(Dataset.id.in_(target_ids))
else:
return [], 0
datasets = query.paginate(
page=page,
per_page=per_page,
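
The hunk above is the behavioral core of this commit: dataset operators only see datasets they hold an explicit DatasetPermission row for, while other members see 'all_team_members' datasets, their own 'only_me' datasets, and 'partial_members' datasets they are listed on. A standalone sketch of those visibility rules (the function and argument names are illustrative, not part of the commit):

# Minimal mirror of the query branches in get_datasets above (illustrative only).
def is_dataset_visible(permission, created_by, dataset_id, user_id,
                       is_dataset_operator, permitted_dataset_ids):
    if is_dataset_operator:
        # operators are restricted to datasets they were explicitly granted
        return dataset_id in permitted_dataset_ids
    if permission == 'all_team_members':
        return True
    if permission == 'only_me':
        return created_by == user_id
    if permission == 'partial_members':
        return dataset_id in permitted_dataset_ids
    return False

# an operator with no grants sees nothing, matching the early `return [], 0`
assert is_dataset_visible('all_team_members', 'u-1', 'ds-1', 'u-2', True, set()) is False
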
@@ -102,9 +136,12 @@ class DatasetService:
@staticmethod
def get_datasets_by_ids(ids, tenant_id):
datasets = Dataset.query.filter(Dataset.id.in_(ids),
Dataset.tenant_id == tenant_id).paginate(
page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
datasets = Dataset.query.filter(
Dataset.id.in_(ids),
Dataset.tenant_id == tenant_id
).paginate(
page=1, per_page=len(ids), max_per_page=len(ids), error_out=False
)
return datasets.items, datasets.total
@staticmethod
@@ -112,7 +149,8 @@ class DatasetService:
# check if dataset name already exists
if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(
f'Dataset with name {name} already exists.')
f'Dataset with name {name} already exists.'
)
embedding_model = None
if indexing_technique == 'high_quality':
model_manager = ModelManager()
@@ -151,13 +189,17 @@ class DatasetService:
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(f"The dataset in unavailable, due to: "
f"{ex.description}")
raise ValueError(
f"The dataset in unavailable, due to: "
f"{ex.description}"
)
@staticmethod
def update_dataset(dataset_id, data, user):
data.pop('partial_member_list', None)
filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}
dataset = DatasetService.get_dataset(dataset_id)
DatasetService.check_dataset_permission(dataset, user)
@@ -190,12 +232,13 @@ class DatasetService:
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
else:
if data['embedding_model_provider'] != dataset.embedding_model_provider or \
data['embedding_model'] != dataset.embedding_model:
action = 'update'
try:
model_manager = ModelManager()
@@ -215,7 +258,8 @@ class DatasetService:
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
@@ -259,14 +303,41 @@ class DatasetService:
def check_dataset_permission(dataset, user):
if dataset.tenant_id != user.current_tenant_id:
logging.debug(
f'User {user.id} does not have permission to access dataset {dataset.id}')
f'User {user.id} does not have permission to access dataset {dataset.id}'
)
raise NoPermissionError(
'You do not have permission to access this dataset.')
'You do not have permission to access this dataset.'
)
if dataset.permission == 'only_me' and dataset.created_by != user.id:
logging.debug(
f'User {user.id} does not have permission to access dataset {dataset.id}')
f'User {user.id} does not have permission to access dataset {dataset.id}'
)
raise NoPermissionError(
'You do not have permission to access this dataset.')
'You do not have permission to access this dataset.'
)
if dataset.permission == 'partial_members':
user_permission = DatasetPermission.query.filter_by(
dataset_id=dataset.id, account_id=user.id
).first()
if not user_permission and dataset.tenant_id != user.current_tenant_id and dataset.created_by != user.id:
logging.debug(
f'User {user.id} does not have permission to access dataset {dataset.id}'
)
raise NoPermissionError(
'You do not have permission to access this dataset.'
)
@staticmethod
def check_dataset_operator_permission(user: Account = None, dataset: Dataset = None):
if dataset.permission == 'only_me':
if dataset.created_by != user.id:
raise NoPermissionError('You do not have permission to access this dataset.')
elif dataset.permission == 'partial_members':
if not any(
dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all()
):
raise NoPermissionError('You do not have permission to access this dataset.')
@staticmethod
def get_dataset_queries(dataset_id: str, page: int, per_page: int):
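
check_dataset_operator_permission is a new, stricter gate intended for edit paths: 'all_team_members' datasets pass through, 'only_me' datasets require the creator, and 'partial_members' datasets require a DatasetPermission row for the user. A rough standalone mirror of that rule (the helper name and arguments are assumptions for illustration; the real method queries DatasetPermission and raises NoPermissionError):

# Illustrative mirror of the new check_dataset_operator_permission rule.
def operator_can_edit(permission, created_by, dataset_id, user_id, granted_dataset_ids):
    """granted_dataset_ids: ids of datasets the user holds DatasetPermission rows for."""
    if permission == 'only_me':
        return created_by == user_id
    if permission == 'partial_members':
        return dataset_id in granted_dataset_ids
    # 'all_team_members' (or any other value) is not restricted by this check
    return True
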
@@ -547,6 +618,7 @@ class DocumentService:
redis_client.setex(sync_indexing_cache_key, 600, 1)
sync_website_document_indexing_task.delay(dataset_id, document.id)
@staticmethod
def get_documents_position(dataset_id):
document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
@@ -556,9 +628,11 @@ class DocumentService:
return 1
@staticmethod
def save_document_with_dataset_id(dataset: Dataset, document_data: dict,
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = 'web'):
def save_document_with_dataset_id(
dataset: Dataset, document_data: dict,
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = 'web'
):
# check document limit
features = FeatureService.get_features(current_user.current_tenant_id)
@@ -588,7 +662,7 @@ class DocumentService:
if not dataset.indexing_technique:
if 'indexing_technique' not in document_data \
or document_data['indexing_technique'] not in Dataset.INDEXING_TECHNIQUE_LIST:
raise ValueError("Indexing technique is required")
dataset.indexing_technique = document_data["indexing_technique"]
@@ -618,7 +692,8 @@ class DocumentService:
}
dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get(
'retrieval_model') else default_retrieval_model
'retrieval_model'
) else default_retrieval_model
documents = []
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
@@ -686,12 +761,14 @@ class DocumentService:
documents.append(document)
duplicate_document_ids.append(document.id)
continue
document = DocumentService.build_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, file_name, batch)
document = DocumentService.build_document(
dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, file_name, batch
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
@@ -732,12 +809,14 @@ class DocumentService:
"notion_page_icon": page['page_icon'],
"type": page['type']
}
document = DocumentService.build_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, page['page_name'], batch)
document = DocumentService.build_document(
dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, page['page_name'], batch
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
@@ -759,12 +838,14 @@ class DocumentService:
'only_main_content': website_info.get('only_main_content', False),
'mode': 'crawl',
}
document = DocumentService.build_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, url, batch)
document = DocumentService.build_document(
dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, url, batch
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
@@ -785,13 +866,16 @@ class DocumentService:
can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size
if count > can_upload_size:
raise ValueError(
f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.')
f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.'
)
@staticmethod
def build_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
document_language: str, data_source_info: dict, created_from: str, position: int,
account: Account,
name: str, batch: str):
def build_document(
dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
document_language: str, data_source_info: dict, created_from: str, position: int,
account: Account,
name: str, batch: str
):
document = Document(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
@@ -810,16 +894,20 @@ class DocumentService:
@staticmethod
def get_tenant_documents_count():
documents_count = Document.query.filter(Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
Document.tenant_id == current_user.current_tenant_id).count()
documents_count = Document.query.filter(
Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
Document.tenant_id == current_user.current_tenant_id
).count()
return documents_count
@staticmethod
def update_document_with_dataset_id(dataset: Dataset, document_data: dict,
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = 'web'):
def update_document_with_dataset_id(
dataset: Dataset, document_data: dict,
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = 'web'
):
DatasetService.check_dataset_model_setting(dataset)
document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
if document.display_status != 'available':
@@ -1007,7 +1095,7 @@ class DocumentService:
DocumentService.process_rule_args_validate(args)
else:
if ('data_source' not in args and not args['data_source']) \
and ('process_rule' not in args and not args['process_rule']):
raise ValueError("Data source or Process rule is required")
else:
if args.get('data_source'):
@@ -1069,7 +1157,7 @@ class DocumentService:
raise ValueError("Process rule rules is invalid")
if 'pre_processing_rules' not in args['process_rule']['rules'] \
or args['process_rule']['rules']['pre_processing_rules'] is None:
raise ValueError("Process rule pre_processing_rules is required")
if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list):
@@ -1094,21 +1182,21 @@ class DocumentService:
args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values())
if 'segmentation' not in args['process_rule']['rules'] \
or args['process_rule']['rules']['segmentation'] is None:
raise ValueError("Process rule segmentation is required")
if not isinstance(args['process_rule']['rules']['segmentation'], dict):
raise ValueError("Process rule segmentation is invalid")
if 'separator' not in args['process_rule']['rules']['segmentation'] \
or not args['process_rule']['rules']['segmentation']['separator']:
raise ValueError("Process rule segmentation separator is required")
if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str):
raise ValueError("Process rule segmentation separator is invalid")
if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \
or not args['process_rule']['rules']['segmentation']['max_tokens']:
raise ValueError("Process rule segmentation max_tokens is required")
if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
@@ -1144,7 +1232,7 @@ class DocumentService:
raise ValueError("Process rule rules is invalid")
if 'pre_processing_rules' not in args['process_rule']['rules'] \
or args['process_rule']['rules']['pre_processing_rules'] is None:
raise ValueError("Process rule pre_processing_rules is required")
if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list):
@@ -1169,21 +1257,21 @@ class DocumentService:
args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values())
if 'segmentation' not in args['process_rule']['rules'] \
or args['process_rule']['rules']['segmentation'] is None:
raise ValueError("Process rule segmentation is required")
if not isinstance(args['process_rule']['rules']['segmentation'], dict):
raise ValueError("Process rule segmentation is invalid")
if 'separator' not in args['process_rule']['rules']['segmentation'] \
or not args['process_rule']['rules']['segmentation']['separator']:
raise ValueError("Process rule segmentation separator is required")
if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str):
raise ValueError("Process rule segmentation separator is invalid")
if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \
or not args['process_rule']['rules']['segmentation']['max_tokens']:
raise ValueError("Process rule segmentation max_tokens is required")
if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
@@ -1437,12 +1525,16 @@ class SegmentService:
class DatasetCollectionBindingService:
@classmethod
def get_dataset_collection_binding(cls, provider_name: str, model_name: str,
collection_type: str = 'dataset') -> DatasetCollectionBinding:
def get_dataset_collection_binding(
cls, provider_name: str, model_name: str,
collection_type: str = 'dataset'
) -> DatasetCollectionBinding:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.provider_name == provider_name,
DatasetCollectionBinding.model_name == model_name,
DatasetCollectionBinding.type == collection_type). \
filter(
DatasetCollectionBinding.provider_name == provider_name,
DatasetCollectionBinding.model_name == model_name,
DatasetCollectionBinding.type == collection_type
). \
order_by(DatasetCollectionBinding.created_at). \
first()
@@ -1458,12 +1550,77 @@ class DatasetCollectionBindingService:
return dataset_collection_binding
@classmethod
def get_dataset_collection_binding_by_id_and_type(cls, collection_binding_id: str,
collection_type: str = 'dataset') -> DatasetCollectionBinding:
def get_dataset_collection_binding_by_id_and_type(
cls, collection_binding_id: str,
collection_type: str = 'dataset'
) -> DatasetCollectionBinding:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == collection_binding_id,
DatasetCollectionBinding.type == collection_type). \
filter(
DatasetCollectionBinding.id == collection_binding_id,
DatasetCollectionBinding.type == collection_type
). \
order_by(DatasetCollectionBinding.created_at). \
first()
return dataset_collection_binding
class DatasetPermissionService:
@classmethod
def get_dataset_partial_member_list(cls, dataset_id):
user_list_query = db.session.query(
DatasetPermission.account_id,
).filter(
DatasetPermission.dataset_id == dataset_id
).all()
user_list = []
for user in user_list_query:
user_list.append(user.account_id)
return user_list
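
get_dataset_partial_member_list flattens the DatasetPermission rows into a plain list of account ids; that flat shape is what check_permission below compares against the requested member list. A small illustration with stand-in rows (SimpleNamespace only fakes the query result here):

from types import SimpleNamespace

rows = [SimpleNamespace(account_id='u-1'), SimpleNamespace(account_id='u-2')]  # hypothetical query result
user_list = [row.account_id for row in rows]
assert user_list == ['u-1', 'u-2']
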
@classmethod
def update_partial_member_list(cls, tenant_id, dataset_id, user_list):
try:
db.session.query(DatasetPermission).filter(DatasetPermission.dataset_id == dataset_id).delete()
permissions = []
for user in user_list:
permission = DatasetPermission(
tenant_id=tenant_id,
dataset_id=dataset_id,
account_id=user['user_id'],
)
permissions.append(permission)
db.session.add_all(permissions)
db.session.commit()
except Exception as e:
db.session.rollback()
raise e
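
update_partial_member_list replaces the entire member set for a dataset in one transaction: delete the existing DatasetPermission rows, bulk-insert the new ones, and roll back if anything fails. The same replace-all semantics with an in-memory store, to make the expected payload shape explicit (all names here are illustrative):

# In-memory stand-in for the delete-then-bulk-insert pattern above.
def replace_partial_members(store, tenant_id, dataset_id, user_list):
    store[dataset_id] = [
        {'tenant_id': tenant_id, 'dataset_id': dataset_id, 'account_id': user['user_id']}
        for user in user_list
    ]
    return store[dataset_id]

store = {}
# the payload is a list of dicts keyed by 'user_id', mirroring user['user_id'] above
replace_partial_members(store, 'tenant-1', 'ds-1', [{'user_id': 'u-1'}, {'user_id': 'u-2'}])
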
@classmethod
def check_permission(cls, user, dataset, requested_permission, requested_partial_member_list):
if not user.is_dataset_editor:
raise NoPermissionError('User does not have permission to edit this dataset.')
if user.is_dataset_operator and dataset.permission != requested_permission:
raise NoPermissionError('Dataset operators cannot change the dataset permissions.')
if user.is_dataset_operator and requested_permission == 'partial_members':
if not requested_partial_member_list:
raise ValueError('Partial member list is required when setting to partial members.')
local_member_list = cls.get_dataset_partial_member_list(dataset.id)
request_member_list = [user['user_id'] for user in requested_partial_member_list]
if set(local_member_list) != set(request_member_list):
raise ValueError('Dataset operators cannot change the dataset permissions.')
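
check_permission encodes the rules for changing a dataset's permission setting: only dataset editors may edit, and a dataset operator may neither switch the permission mode nor alter the partial member set. A standalone sketch of the same decision table (function and argument names are illustrative; the real method raises NoPermissionError or ValueError instead of returning a message):

def validate_permission_change(is_dataset_editor, is_dataset_operator,
                               current_permission, requested_permission,
                               current_member_ids, requested_member_ids):
    if not is_dataset_editor:
        return 'User does not have permission to edit this dataset.'
    if is_dataset_operator:
        if current_permission != requested_permission:
            return 'Dataset operators cannot change the dataset permissions.'
        if requested_permission == 'partial_members':
            if not requested_member_ids:
                return 'Partial member list is required when setting to partial members.'
            if set(current_member_ids) != set(requested_member_ids):
                return 'Dataset operators cannot change the dataset permissions.'
    return None  # change is allowed
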
@classmethod
def clear_partial_member_list(cls, dataset_id):
try:
db.session.query(DatasetPermission).filter(DatasetPermission.dataset_id == dataset_id).delete()
db.session.commit()
except Exception as e:
db.session.rollback()
raise e