Model Runtime (#1858)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
takatost
2024-01-02 23:42:00 +08:00
committed by GitHub
parent e91dd28a76
commit d069c668f8
807 changed files with 171310 additions and 23806 deletions

View File

@@ -2,16 +2,16 @@ import datetime
import logging
import time
import uuid
from typing import Optional, List
from typing import List, cast
import click
from celery import shared_task
from sqlalchemy import func
from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder
from core.indexing_runner import IndexingRunner
from core.model_providers.model_factory import ModelFactory
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
@@ -51,18 +51,26 @@ def batch_create_segment_to_index_task(job_id: str, content: List, dataset_id: s
document_segments = []
embedding_model = None
if dataset.indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
model_type_instance = embedding_model.model_type_instance
model_type_instance = cast(TextEmbeddingModel, model_type_instance)
for segment in content:
content = segment['content']
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content) if embedding_model else 0
tokens = model_type_instance.get_num_tokens(
model=embedding_model.model,
credentials=embedding_model.credentials,
texts=[content]
) if embedding_model else 0
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document_id == dataset_document.id
).scalar()