Model Runtime (#1858)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
takatost
2024-01-02 23:42:00 +08:00
committed by GitHub
parent e91dd28a76
commit d069c668f8
807 changed files with 171310 additions and 23806 deletions

View File

@@ -4,14 +4,16 @@ import datetime
import time
import random
import uuid
from typing import Optional, List
from typing import Optional, List, cast
from flask import current_app
from sqlalchemy import func
from core.index.index import IndexBuilder
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from extensions.ext_redis import redis_client
from flask_login import current_user
@@ -92,16 +94,18 @@ class DatasetService:
f'Dataset with name {name} already exists.')
embedding_model = None
if indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
model_manager = ModelManager()
embedding_model = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING
)
dataset = Dataset(name=name, indexing_technique=indexing_technique)
# dataset = Dataset(name=name, provider=provider, config=config)
dataset.created_by = account.id
dataset.updated_by = account.id
dataset.tenant_id = tenant_id
dataset.embedding_model_provider = embedding_model.model_provider.provider_name if embedding_model else None
dataset.embedding_model = embedding_model.name if embedding_model else None
dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
dataset.embedding_model = embedding_model.model if embedding_model else None
db.session.add(dataset)
db.session.commit()
return dataset
@@ -120,10 +124,12 @@ class DatasetService:
def check_dataset_model_setting(dataset):
if dataset.indexing_technique == 'high_quality':
try:
ModelFactory.get_embedding_model(
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except LLMBadRequestError:
raise ValueError(
@@ -150,14 +156,16 @@ class DatasetService:
action = 'add'
# get embedding model setting
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
model_manager = ModelManager()
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.TEXT_EMBEDDING
)
filtered_data['embedding_model'] = embedding_model.name
filtered_data['embedding_model_provider'] = embedding_model.model_provider.provider_name
filtered_data['embedding_model'] = embedding_model.model
filtered_data['embedding_model_provider'] = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.model_provider.provider_name,
embedding_model.name
embedding_model.provider,
embedding_model.model
)
filtered_data['collection_binding_id'] = dataset_collection_binding.id
except LLMBadRequestError:
@@ -458,14 +466,16 @@ class DocumentService:
dataset.indexing_technique = document_data["indexing_technique"]
if document_data["indexing_technique"] == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
model_manager = ModelManager()
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.TEXT_EMBEDDING
)
dataset.embedding_model = embedding_model.name
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.model_provider.provider_name,
embedding_model.name
embedding_model.provider,
embedding_model.model
)
dataset.collection_binding_id = dataset_collection_binding.id
if not dataset.retrieval_model:
@@ -737,12 +747,14 @@ class DocumentService:
dataset_collection_binding_id = None
retrieval_model = None
if document_data['indexing_technique'] == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
model_manager = ModelManager()
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.TEXT_EMBEDDING
)
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.model_provider.provider_name,
embedding_model.name
embedding_model.provider,
embedding_model.model
)
dataset_collection_binding_id = dataset_collection_binding.id
if 'retrieval_model' in document_data and document_data['retrieval_model']:
@@ -766,8 +778,8 @@ class DocumentService:
data_source_type=document_data["data_source"]["type"],
indexing_technique=document_data["indexing_technique"],
created_by=account.id,
embedding_model=embedding_model.name if embedding_model else None,
embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None,
embedding_model=embedding_model.model if embedding_model else None,
embedding_model_provider=embedding_model.provider if embedding_model else None,
collection_binding_id=dataset_collection_binding_id,
retrieval_model=retrieval_model
)
@@ -989,13 +1001,20 @@ class SegmentService:
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content)
model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)
tokens = model_type_instance.get_num_tokens(
model=embedding_model.model,
credentials=embedding_model.credentials,
texts=[content]
)
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document_id == document.id
).scalar()
@@ -1037,10 +1056,12 @@ class SegmentService:
def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
embedding_model = None
if dataset.indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document_id == document.id
@@ -1054,7 +1075,12 @@ class SegmentService:
tokens = 0
if dataset.indexing_technique == 'high_quality' and embedding_model:
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content)
model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)
tokens = model_type_instance.get_num_tokens(
model=embedding_model.model,
credentials=embedding_model.credentials,
texts=[content]
)
segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
@@ -1121,14 +1147,21 @@ class SegmentService:
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content)
model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)
tokens = model_type_instance.get_num_tokens(
model=embedding_model.model,
credentials=embedding_model.credentials,
texts=[content]
)
segment.content = content
segment.index_node_hash = segment_hash
segment.word_count = len(content)