Feat/improve vector database logic (#1193)
Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
@@ -20,7 +20,8 @@ from events.document_event import document_was_deleted
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment
|
||||
from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment, \
|
||||
DatasetCollectionBinding
|
||||
from models.model import UploadFile
|
||||
from models.source import DataSourceBinding
|
||||
from services.errors.account import NoPermissionError
|
||||
@@ -147,6 +148,7 @@ class DatasetService:
|
||||
action = 'remove'
|
||||
filtered_data['embedding_model'] = None
|
||||
filtered_data['embedding_model_provider'] = None
|
||||
filtered_data['collection_binding_id'] = None
|
||||
elif data['indexing_technique'] == 'high_quality':
|
||||
action = 'add'
|
||||
# get embedding model setting
|
||||
@@ -156,6 +158,11 @@ class DatasetService:
|
||||
)
|
||||
filtered_data['embedding_model'] = embedding_model.name
|
||||
filtered_data['embedding_model_provider'] = embedding_model.model_provider.provider_name
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_model.model_provider.provider_name,
|
||||
embedding_model.name
|
||||
)
|
||||
filtered_data['collection_binding_id'] = dataset_collection_binding.id
|
||||
except LLMBadRequestError:
|
||||
raise ValueError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
@@ -464,7 +471,11 @@ class DocumentService:
|
||||
)
|
||||
dataset.embedding_model = embedding_model.name
|
||||
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
|
||||
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_model.model_provider.provider_name,
|
||||
embedding_model.name
|
||||
)
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
|
||||
documents = []
|
||||
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
|
||||
@@ -720,10 +731,16 @@ class DocumentService:
|
||||
if total_count > tenant_document_count:
|
||||
raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
|
||||
embedding_model = None
|
||||
dataset_collection_binding_id = None
|
||||
if document_data['indexing_technique'] == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_model.model_provider.provider_name,
|
||||
embedding_model.name
|
||||
)
|
||||
dataset_collection_binding_id = dataset_collection_binding.id
|
||||
# save dataset
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
@@ -732,7 +749,8 @@ class DocumentService:
|
||||
indexing_technique=document_data["indexing_technique"],
|
||||
created_by=account.id,
|
||||
embedding_model=embedding_model.name if embedding_model else None,
|
||||
embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None
|
||||
embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None,
|
||||
collection_binding_id=dataset_collection_binding_id
|
||||
)
|
||||
|
||||
db.session.add(dataset)
|
||||
@@ -1069,3 +1087,23 @@ class SegmentService:
|
||||
delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
|
||||
db.session.delete(segment)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
class DatasetCollectionBindingService:
    """Resolve the collection binding that belongs to an embedding model."""

    @classmethod
    def get_dataset_collection_binding(cls, provider_name: str, model_name: str) -> DatasetCollectionBinding:
        """Return the collection binding for a provider/model pair, creating it if absent.

        The oldest matching binding (ordered by ``created_at``) is reused when one
        exists. Otherwise a new binding with a freshly generated collection name
        is added to the session and flushed; this method does not commit.

        :param provider_name: name of the embedding model provider
        :param model_name: name of the embedding model
        :return: the existing or newly created binding
        """
        matching_bindings = db.session.query(DatasetCollectionBinding).filter(
            DatasetCollectionBinding.provider_name == provider_name,
            DatasetCollectionBinding.model_name == model_name
        ).order_by(DatasetCollectionBinding.created_at)

        existing_binding = matching_bindings.first()
        if existing_binding:
            return existing_binding

        # No binding yet for this provider/model: build a unique collection name
        # from a random UUID, with dashes swapped for underscores.
        unique_suffix = str(uuid.uuid4()).replace("-", "_")
        new_binding = DatasetCollectionBinding(
            provider_name=provider_name,
            model_name=model_name,
            collection_name="Vector_index_" + unique_suffix + '_Node'
        )
        db.session.add(new_binding)
        # Flush only (no commit) so the pending row is sent to the database while
        # the surrounding transaction stays open.
        db.session.flush()
        return new_binding
|
||||
|
Reference in New Issue
Block a user