Feat/dify rag (#2528)

Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
Jyong
2024-02-22 23:31:57 +08:00
committed by GitHub
parent 97fe817186
commit 6c4e6bf1d6
119 changed files with 3181 additions and 5892 deletions

View File

@@ -1,13 +1,8 @@
import logging
from typing import Optional
from flask import current_app
from core.embedding.cached_embedding import CacheEmbedding
from core.entities.application_entities import InvokeFrom
from core.index.vector_index.vector_index import VectorIndex
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
@@ -45,17 +40,6 @@ class AnnotationReplyFeature:
embedding_provider_name = collection_binding_detail.provider_name
embedding_model_name = collection_binding_detail.model_name
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=app_record.tenant_id,
provider=embedding_provider_name,
model_type=ModelType.TEXT_EMBEDDING,
model=embedding_model_name
)
# get embedding model
embeddings = CacheEmbedding(model_instance)
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name,
embedding_model_name,
@@ -71,22 +55,14 @@ class AnnotationReplyFeature:
collection_binding_id=dataset_collection_binding.id
)
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings,
attributes=['doc_id', 'annotation_id', 'app_id']
)
vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
documents = vector_index.search(
documents = vector.search_by_vector(
query=query,
search_type='similarity_score_threshold',
search_kwargs={
'k': 1,
'score_threshold': score_threshold,
'filter': {
'group_id': [dataset.id]
}
k=1,
score_threshold=score_threshold,
filter={
'group_id': [dataset.id]
}
)