Annotation management (#1767)

Co-authored-by: jyong <jyong@dify.ai>
Authored by Jyong on 2023-12-18 13:10:05 +08:00; committed by GitHub
parent a9b942981d
commit a71f2863ac
41 changed files with 1871 additions and 67 deletions

View File

@@ -12,8 +12,10 @@ from core.callback_handler.main_chain_gather_callback_handler import MainChainGa
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.embedding.cached_embedding import CacheEmbedding
from core.external_data_tool.factory import ExternalDataToolFactory
from core.file.file_obj import FileObj
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
@@ -23,9 +25,12 @@ from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompt_transform import PromptTransform
from models.dataset import Dataset
from models.model import App, AppModelConfig, Account, Conversation, EndUser
from core.moderation.base import ModerationException, ModerationAction
from core.moderation.factory import ModerationFactory
from services.annotation_service import AppAnnotationService
from services.dataset_service import DatasetCollectionBindingService


class Completion:
@@ -33,7 +38,7 @@ class Completion:
def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
- auto_generate_name: bool = True):
+ auto_generate_name: bool = True, from_source: str = 'console'):
"""
errors: ProviderTokenNotInitError
"""
@@ -109,7 +114,10 @@ class Completion:
fake_response=str(e)
)
return

# check annotation reply
annotation_reply = cls.query_app_annotations_to_reply(conversation_message_task, from_source)
if annotation_reply:
return
# fill in variable inputs from external data tools, if any are configured
external_data_tools = app_model_config.external_data_tools_list
if external_data_tools:
@@ -166,17 +174,18 @@ class Completion:
except ChunkedEncodingError as e:
# Interrupted by the LLM provider (e.g. OpenAI); handle it gracefully.
logging.warning(f'ChunkedEncodingError: {e}')
conversation_message_task.end()
return

@classmethod
- def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str):
+ def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict,
+                           query: str):
if not app_model_config.sensitive_word_avoidance_dict['enabled']:
return inputs, query
type = app_model_config.sensitive_word_avoidance_dict['type']
- moderation = ModerationFactory(type, app_id, tenant_id, app_model_config.sensitive_word_avoidance_dict['config'])
+ moderation = ModerationFactory(type, app_id, tenant_id,
+                                app_model_config.sensitive_word_avoidance_dict['config'])
moderation_result = moderation.moderation_for_inputs(inputs, query)
if not moderation_result.flagged:
@@ -324,6 +333,76 @@ class Completion:
external_context = memory.load_memory_variables({})
return external_context[memory_key]

@classmethod
def query_app_annotations_to_reply(cls, conversation_message_task: ConversationMessageTask,
from_source: str) -> bool:
"""Get memory messages."""
app_model_config = conversation_message_task.app_model_config
app = conversation_message_task.app
annotation_reply = app_model_config.annotation_reply_dict
if annotation_reply['enabled']:
score_threshold = annotation_reply.get('score_threshold', 1)
embedding_provider_name = annotation_reply['embedding_model']['embedding_provider_name']
embedding_model_name = annotation_reply['embedding_model']['embedding_model_name']
# get embedding model
embedding_model = ModelFactory.get_embedding_model(
tenant_id=app.tenant_id,
model_provider_name=embedding_provider_name,
model_name=embedding_model_name
)
embeddings = CacheEmbedding(embedding_model)
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name,
embedding_model_name,
'annotation'
)
dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,
indexing_technique='high_quality',
embedding_model_provider=embedding_provider_name,
embedding_model=embedding_model_name,
collection_binding_id=dataset_collection_binding.id
)
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
documents = vector_index.search(
conversation_message_task.query,
search_type='similarity_score_threshold',
search_kwargs={
'k': 1,
'score_threshold': score_threshold,
'filter': {
'group_id': [dataset.id]
}
}
)
if documents:
annotation_id = documents[0].metadata['annotation_id']
score = documents[0].metadata['score']
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
if annotation:
conversation_message_task.annotation_end(annotation.content, annotation.id, annotation.account.name)
# insert annotation history
AppAnnotationService.add_annotation_history(annotation.id,
app.id,
annotation.question,
annotation.content,
conversation_message_task.query,
conversation_message_task.user.id,
conversation_message_task.message.id,
from_source,
score)
return True
return False

@classmethod
def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
conversation: Conversation,
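In essence, `query_app_annotations_to_reply` embeds the incoming query, runs a top-1 similarity search over the app's annotation collection, and short-circuits the LLM call when the best match clears `score_threshold`. A minimal self-contained sketch of that matching step, with brute-force cosine similarity standing in for the vector index (the names and the 0.8 default are illustrative, not Dify's API):

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class Annotation:
    question_embedding: np.ndarray  # embedding of the stored question
    content: str                    # the curated reply returned on a match


def annotation_reply(query_embedding: np.ndarray,
                     annotations: list[Annotation],
                     score_threshold: float = 0.8) -> str | None:
    """Return the closest annotation's content if it clears the threshold,
    else None so the caller falls through to the normal LLM completion."""
    best_score, best = 0.0, None
    for ann in annotations:
        denom = np.linalg.norm(query_embedding) * np.linalg.norm(ann.question_embedding)
        score = float(np.dot(query_embedding, ann.question_embedding) / denom)
        if score >= score_threshold and score > best_score:
            best_score, best = score, ann
    return best.content if best else None
```

On a hit, the real code returns the curated answer through `annotation_end` and records the match with `AppAnnotationService.add_annotation_history`.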

View File

@@ -319,6 +319,10 @@ class ConversationMessageTask:
self._pub_handler.pub_message_end(self.retriever_resource)
self._pub_handler.pub_end()

def annotation_end(self, text: str, annotation_id: str, annotation_author_name: str):
self._pub_handler.pub_annotation(text, annotation_id, annotation_author_name, self.start_at)
self._pub_handler.pub_end()

class PubHandler:
def __init__(self, user: Union[Account, EndUser], task_id: str,
@@ -435,7 +439,7 @@ class PubHandler:
'task_id': self._task_id,
'message_id': self._message.id,
'mode': self._conversation.mode,
- 'conversation_id': self._conversation.id
+ 'conversation_id': self._conversation.id,
}
}
if retriever_resource:
@@ -446,6 +450,30 @@ class PubHandler:
self.pub_end()
raise ConversationTaskStoppedException()

def pub_annotation(self, text: str, annotation_id: str, annotation_author_name: str, start_at: float):
content = {
'event': 'annotation',
'data': {
'task_id': self._task_id,
'message_id': self._message.id,
'mode': self._conversation.mode,
'conversation_id': self._conversation.id,
'text': text,
'annotation_id': annotation_id,
'annotation_author_name': annotation_author_name
}
}
self._message.answer = text
self._message.provider_response_latency = time.perf_counter() - start_at
db.session.commit()
redis_client.publish(self._channel, json.dumps(content))
if self._is_stopped():
self.pub_end()
raise ConversationTaskStoppedException()

def pub_end(self):
content = {
'event': 'end',
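For context on `pub_annotation`: events are published to Redis as JSON strings on the task's channel, so a consumer sees a single `annotation` event carrying the full curated answer instead of a token stream, followed by the usual `end` event. A rough subscriber sketch (the channel name is a placeholder; only the redis-py pub/sub calls are standard):

```python
import json

import redis

r = redis.Redis()
pubsub = r.pubsub()
pubsub.subscribe("generate_result:<task_id>")  # hypothetical channel name

for message in pubsub.listen():
    if message["type"] != "message":
        continue
    event = json.loads(message["data"])
    if event["event"] == "annotation":
        # The whole curated answer arrives in one event; no token streaming.
        print("annotation reply:", event["data"]["text"])
    elif event["event"] == "end":
        break
```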

View File

@@ -32,6 +32,10 @@ class BaseIndex(ABC):
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError

@abstractmethod
def delete_by_metadata_field(self, key: str, value: str) -> None:
raise NotImplementedError

@abstractmethod
def delete_by_group_id(self, group_id: str) -> None:
raise NotImplementedError
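Because `delete_by_metadata_field` is declared with `@abstractmethod`, every concrete index backend must implement it before the class can be instantiated, which is what the per-backend hunks below provide. A minimal illustration of the contract, using a made-up in-memory backend:

```python
from abc import ABC, abstractmethod


class BaseIndex(ABC):
    @abstractmethod
    def delete_by_metadata_field(self, key: str, value: str) -> None:
        raise NotImplementedError


class InMemoryIndex(BaseIndex):
    def __init__(self) -> None:
        self._docs: dict[str, dict] = {}  # doc_id -> metadata

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        # Drop every document whose metadata matches the key/value pair.
        self._docs = {i: m for i, m in self._docs.items() if m.get(key) != value}


# BaseIndex() raises TypeError; InMemoryIndex() works because the
# abstract method is implemented.
```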

View File

@@ -107,6 +107,9 @@ class KeywordTableIndex(BaseIndex):
self._save_dataset_keyword_table(keyword_table)
def delete_by_metadata_field(self, key: str, value: str):
# deliberate no-op: the keyword table index keeps no vector metadata to delete
pass

def get_retriever(self, **kwargs: Any) -> BaseRetriever:
return KeywordTableRetriever(index=self, **kwargs)

View File

@@ -121,6 +121,16 @@ class MilvusVectorIndex(BaseVectorIndex):
'filter': f'id in {ids}'
})

def delete_by_metadata_field(self, key: str, value: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
ids = vector_store.get_ids_by_metadata_field(key, value)
if ids:
vector_store.del_texts({
'filter': f'id in {ids}'
})

def delete_by_ids(self, doc_ids: list[str]) -> None:
vector_store = self._get_vector_store()

View File

@@ -138,6 +138,22 @@ class QdrantVectorIndex(BaseVectorIndex):
],
))

def delete_by_metadata_field(self, key: str, value: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
],
))

def delete_by_ids(self, ids: list[str]) -> None:
vector_store = self._get_vector_store()
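Qdrant, by contrast, can delete by payload filter directly; the wrapper's `del_texts(models.Filter(...))` amounts to a filtered point deletion. An equivalent standalone call with the raw client (collection name, field value, and endpoint are assumptions):

```python
from qdrant_client import QdrantClient
from qdrant_client.http import models

client = QdrantClient(url="http://localhost:6333")  # assumed local instance

# Delete every point whose payload field metadata.annotation_id matches.
client.delete(
    collection_name="annotations",  # illustrative collection name
    points_selector=models.FilterSelector(
        filter=models.Filter(
            must=[
                models.FieldCondition(
                    key="metadata.annotation_id",
                    match=models.MatchValue(value="some-annotation-id"),
                )
            ]
        )
    ),
)
```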

View File

@@ -141,6 +141,17 @@ class WeaviateVectorIndex(BaseVectorIndex):
"valueText": document_id
})

def delete_by_metadata_field(self, key: str, value: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({
"operator": "Equal",
"path": [key],
"valueText": value
})

def delete_by_group_id(self, group_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
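The Weaviate where-clause above has the same shape the v3 Python client accepts for batch deletion, so the equivalent standalone call looks like this (class name, property, and endpoint are assumptions):

```python
import weaviate  # weaviate-client v3 API

client = weaviate.Client("http://localhost:8080")  # assumed local instance

# Batch-delete every object whose annotation_id property equals the value.
client.batch.delete_objects(
    class_name="Annotation",  # illustrative class name
    where={
        "operator": "Equal",
        "path": ["annotation_id"],
        "valueText": "some-annotation-id",
    },
)
```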

View File

@@ -30,6 +30,16 @@ class MilvusVectorStore(Milvus):
else:
return None

def get_ids_by_metadata_field(self, key: str, value: str):
result = self.col.query(
expr=f'metadata["{key}"] == "{value}"',
output_fields=["id"]
)
if result:
return [item["id"] for item in result]
else:
return None

def get_ids_by_doc_ids(self, doc_ids: list):
result = self.col.query(
expr=f'metadata["doc_id"] in {doc_ids}',