Feat/improve vector database logic (#1193)

Co-authored-by: jyong <jyong@dify.ai>
Author: Jyong
Date: 2023-09-18 18:15:41 +08:00
Committed by: GitHub
parent 60e0bbd713
commit 269a465fc4
14 changed files with 463 additions and 46 deletions
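In short: datasets indexed with the high-quality (embedding) technique now carry a collection_binding_id, and a new DatasetCollectionBindingService resolves, or lazily creates, a single binding per (embedding provider, embedding model) pair, which lets every dataset embedded with the same model share one vector-store collection. Only the diff of the dataset service module is reproduced below; the other changed files, including the DatasetCollectionBinding model definition, are omitted from this excerpt.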


@@ -20,7 +20,8 @@ from events.document_event import document_was_deleted
 from extensions.ext_database import db
 from libs import helper
 from models.account import Account
-from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment
+from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment, \
+    DatasetCollectionBinding
 from models.model import UploadFile
 from models.source import DataSourceBinding
 from services.errors.account import NoPermissionError
@@ -147,6 +148,7 @@ class DatasetService:
                 action = 'remove'
                 filtered_data['embedding_model'] = None
                 filtered_data['embedding_model_provider'] = None
+                filtered_data['collection_binding_id'] = None
             elif data['indexing_technique'] == 'high_quality':
                 action = 'add'
                 # get embedding model setting
@@ -156,6 +158,11 @@ class DatasetService:
                     )
                     filtered_data['embedding_model'] = embedding_model.name
                     filtered_data['embedding_model_provider'] = embedding_model.model_provider.provider_name
+                    dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+                        embedding_model.model_provider.provider_name,
+                        embedding_model.name
+                    )
+                    filtered_data['collection_binding_id'] = dataset_collection_binding.id
                 except LLMBadRequestError:
                     raise ValueError(
                         f"No Embedding Model available. Please configure a valid provider "
@@ -464,7 +471,11 @@ class DocumentService:
             )
             dataset.embedding_model = embedding_model.name
             dataset.embedding_model_provider = embedding_model.model_provider.provider_name
+            dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+                embedding_model.model_provider.provider_name,
+                embedding_model.name
+            )
+            dataset.collection_binding_id = dataset_collection_binding.id
 
         documents = []
         batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
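The same resolve-and-bind step recurs here on the path that attaches documents to an existing dataset: when that dataset uses 'high_quality' indexing, its collection_binding_id is set from the freshly resolved embedding model before any documents are queued.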
@@ -720,10 +731,16 @@ class DocumentService:
         if total_count > tenant_document_count:
             raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
         embedding_model = None
+        dataset_collection_binding_id = None
         if document_data['indexing_technique'] == 'high_quality':
             embedding_model = ModelFactory.get_embedding_model(
                 tenant_id=tenant_id
             )
+            dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+                embedding_model.model_provider.provider_name,
+                embedding_model.name
+            )
+            dataset_collection_binding_id = dataset_collection_binding.id
         # save dataset
         dataset = Dataset(
             tenant_id=tenant_id,
@@ -732,7 +749,8 @@ class DocumentService:
             indexing_technique=document_data["indexing_technique"],
             created_by=account.id,
             embedding_model=embedding_model.name if embedding_model else None,
-            embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None
+            embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None,
+            collection_binding_id=dataset_collection_binding_id
         )
         db.session.add(dataset)
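On the creation path the Dataset row is now born with its collection_binding_id, left as None for 'economy' datasets. A hedged illustration of resolving the physical collection through the binding; the attribute names come from the diff, while the session and the dataset variable are assumed context:

    # Hedged illustration: look up the vector collection a dataset writes to.
    # DatasetCollectionBinding attributes come from the diff; everything else
    # (session setup, the dataset variable) is assumed context.
    binding = db.session.query(DatasetCollectionBinding) \
        .filter(DatasetCollectionBinding.id == dataset.collection_binding_id) \
        .first()
    if binding:
        print(binding.collection_name)  # e.g. 'Vector_index_<uuid_with_underscores>_Node'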
@@ -1069,3 +1087,23 @@ class SegmentService:
         delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
         db.session.delete(segment)
         db.session.commit()
+
+
+class DatasetCollectionBindingService:
+    @classmethod
+    def get_dataset_collection_binding(cls, provider_name: str, model_name: str) -> DatasetCollectionBinding:
+        dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+            filter(DatasetCollectionBinding.provider_name == provider_name,
+                   DatasetCollectionBinding.model_name == model_name). \
+            order_by(DatasetCollectionBinding.created_at). \
+            first()
+        if not dataset_collection_binding:
+            dataset_collection_binding = DatasetCollectionBinding(
+                provider_name=provider_name,
+                model_name=model_name,
+                collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
+            )
+            db.session.add(dataset_collection_binding)
+            db.session.flush()
+
+        return dataset_collection_binding
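DatasetCollectionBindingService is the core of the change: bindings are keyed by (provider_name, model_name), ordered by created_at so the oldest binding wins if duplicates ever exist, and persisted with db.session.flush() rather than commit() so the new row gets an id while transaction control stays with the caller. The DatasetCollectionBinding model is imported from models.dataset in the first hunk but its definition is not part of this excerpt; below is a minimal sketch consistent with the attributes the service touches, where the table name, column types, and server defaults are all assumptions:

    # Minimal sketch of the model this service relies on. Only the attribute
    # names (id, provider_name, model_name, collection_name, created_at) are
    # confirmed by the diff; table name, column types, and defaults are guesses.
    from sqlalchemy.dialects.postgresql import UUID

    from extensions.ext_database import db


    class DatasetCollectionBinding(db.Model):
        __tablename__ = 'dataset_collection_bindings'

        id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
        provider_name = db.Column(db.String(40), nullable=False)
        model_name = db.Column(db.String(40), nullable=False)
        collection_name = db.Column(db.String(64), nullable=False)
        created_at = db.Column(db.DateTime, nullable=False,
                               server_default=db.text('CURRENT_TIMESTAMP(0)'))

Because the lookup is get-or-create, repeated calls with the same pair yield the same row, which is what lets every dataset embedded with a given model share one collection (the provider and model names below are purely illustrative):

    # Both calls resolve to the same binding row, hence the same collection.
    a = DatasetCollectionBindingService.get_dataset_collection_binding('openai', 'text-embedding-ada-002')
    b = DatasetCollectionBindingService.get_dataset_collection_binding('openai', 'text-embedding-ada-002')
    assert a.id == b.id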