feat(plugin): Add API endpoint for invoking LLM with structured output (#21624)
```diff
@@ -17,6 +17,7 @@ from core.plugin.entities.request import (
     RequestInvokeApp,
     RequestInvokeEncrypt,
     RequestInvokeLLM,
+    RequestInvokeLLMWithStructuredOutput,
     RequestInvokeModeration,
     RequestInvokeParameterExtractorNode,
     RequestInvokeQuestionClassifierNode,
```
The new `PluginInvokeLLMWithStructuredOutputApi` resource mirrors the existing `PluginInvokeLLMApi`: it delegates to the backwards invocation, converts the result to an event stream, and returns it with the same length-prefixed framing:

```diff
@@ -47,6 +48,21 @@ class PluginInvokeLLMApi(Resource):
         return length_prefixed_response(0xF, generator())
 
 
+class PluginInvokeLLMWithStructuredOutputApi(Resource):
+    @setup_required
+    @plugin_inner_api_only
+    @get_user_tenant
+    @plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput)
+    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLMWithStructuredOutput):
+        def generator():
+            response = PluginModelBackwardsInvocation.invoke_llm_with_structured_output(
+                user_model.id, tenant_model, payload
+            )
+            return PluginModelBackwardsInvocation.convert_to_event_stream(response)
+
+        return length_prefixed_response(0xF, generator())
+
+
 class PluginInvokeTextEmbeddingApi(Resource):
     @setup_required
     @plugin_inner_api_only
```
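Both LLM endpoints return their generator through `length_prefixed_response(0xF, generator())`, so the plugin daemon reads a length-prefixed event stream rather than plain SSE. The framing itself is implemented elsewhere and not shown in this diff; purely as an illustrative sketch of the general technique (the 4-byte big-endian length header here is an assumption, not the actual wire format), a frame reader might look like:

```python
import struct
from collections.abc import Generator
from typing import BinaryIO


def read_frames(stream: BinaryIO) -> Generator[bytes, None, None]:
    """Yield one payload per frame until the stream is exhausted.

    Assumes a 4-byte big-endian length header per frame; the real format
    used by length_prefixed_response may differ.
    """
    while True:
        header = stream.read(4)
        if len(header) < 4:
            return  # no more complete frames
        (size,) = struct.unpack(">I", header)
        yield stream.read(size)
```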
The new resource is then registered next to the existing invoke routes:

```diff
@@ -291,6 +307,7 @@ class PluginFetchAppInfoApi(Resource):
 
 
 api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
+api.add_resource(PluginInvokeLLMWithStructuredOutputApi, "/invoke/llm/structured-output")
 api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
 api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
 api.add_resource(PluginInvokeTTSApi, "/invoke/tts")
```
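Once registered, the endpoint lives at `/invoke/llm/structured-output` under the inner-API blueprint. The following is a hypothetical client call, not code from this PR: the base URL, port, and auth header are assumptions, while the payload keys mirror the fields of `RequestInvokeLLMWithStructuredOutput` that the handler reads (`provider`, `model_type`, `model`, `prompt_messages`, `structured_output_schema`, `stream`):

```python
import requests

payload = {
    "provider": "openai",  # example values, not defaults from the PR
    "model_type": "llm",
    "model": "gpt-4o",
    "prompt_messages": [{"role": "user", "content": "Alice is 30 years old."}],
    "structured_output_schema": {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
    },
    "stream": False,
}

resp = requests.post(
    "http://localhost:5001/inner/api/invoke/llm/structured-output",  # assumed prefix
    json=payload,
    headers={"X-Inner-Api-Key": "<key>"},  # assumed auth scheme
    stream=True,  # the response body is a length-prefixed event stream
)
for frame in read_frames(resp.raw):  # read_frames from the sketch above
    print(frame)
```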
The second file in the commit is the model backwards-invocation module. Its imports gain the structured-output helper and the two new result entities:

```diff
@@ -2,11 +2,14 @@ import tempfile
 from binascii import hexlify, unhexlify
 from collections.abc import Generator
 
+from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
 from core.model_manager import ModelManager
 from core.model_runtime.entities.llm_entities import (
     LLMResult,
     LLMResultChunk,
     LLMResultChunkDelta,
+    LLMResultChunkWithStructuredOutput,
+    LLMResultWithStructuredOutput,
 )
 from core.model_runtime.entities.message_entities import (
     PromptMessage,
```
```diff
@@ -16,6 +19,7 @@ from core.model_runtime.entities.message_entities import (
 from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
 from core.plugin.entities.request import (
     RequestInvokeLLM,
+    RequestInvokeLLMWithStructuredOutput,
     RequestInvokeModeration,
     RequestInvokeRerank,
     RequestInvokeSpeech2Text,
```
The core of the change is the new classmethod on `PluginModelBackwardsInvocation`. It resolves the model instance, verifies that a model schema exists, delegates to the `invoke_llm_with_structured_output` helper, and deducts LLM quota on both the streaming and non-streaming paths:

```diff
@@ -85,6 +89,72 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
 
             return handle_non_streaming(response)
 
+    @classmethod
+    def invoke_llm_with_structured_output(
+        cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLMWithStructuredOutput
+    ):
+        """
+        invoke llm with structured output
+        """
+        model_instance = ModelManager().get_model_instance(
+            tenant_id=tenant.id,
+            provider=payload.provider,
+            model_type=payload.model_type,
+            model=payload.model,
+        )
+
+        model_schema = model_instance.model_type_instance.get_model_schema(payload.model, model_instance.credentials)
+
+        if not model_schema:
+            raise ValueError(f"Model schema not found for {payload.model}")
+
+        response = invoke_llm_with_structured_output(
+            provider=payload.provider,
+            model_schema=model_schema,
+            model_instance=model_instance,
+            prompt_messages=payload.prompt_messages,
+            json_schema=payload.structured_output_schema,
+            tools=payload.tools,
+            stop=payload.stop,
+            stream=True if payload.stream is None else payload.stream,
+            user=user_id,
+            model_parameters=payload.completion_params,
+        )
+
+        if isinstance(response, Generator):
+
+            def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
+                for chunk in response:
+                    if chunk.delta.usage:
+                        llm_utils.deduct_llm_quota(
+                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
+                        )
+                    chunk.prompt_messages = []
+                    yield chunk
+
+            return handle()
+        else:
+            if response.usage:
+                llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
+
+            def handle_non_streaming(
+                response: LLMResultWithStructuredOutput,
+            ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
+                yield LLMResultChunkWithStructuredOutput(
+                    model=response.model,
+                    prompt_messages=[],
+                    system_fingerprint=response.system_fingerprint,
+                    structured_output=response.structured_output,
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=response.message,
+                        usage=response.usage,
+                        finish_reason="",
+                    ),
+                )
+
+            return handle_non_streaming(response)
+
     @classmethod
     def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
         """
```
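Both branches normalize to one shape: the streaming path deducts quota per chunk and clears `prompt_messages` before yielding, while the non-streaming path deducts quota once and wraps the whole `LLMResultWithStructuredOutput` in a single synthetic chunk (index 0, empty `finish_reason`). Callers can therefore iterate uniformly regardless of `stream`. A minimal consumption sketch, where the chunk generator would come from the decorated controller method shown earlier; the consumer function itself is hypothetical:

```python
from collections.abc import Generator


def print_structured_result(chunks: Generator) -> None:
    """Consume chunks yielded by invoke_llm_with_structured_output.

    Usage (fixtures hypothetical):
        print_structured_result(
            PluginModelBackwardsInvocation.invoke_llm_with_structured_output(
                user_id, tenant, payload
            )
        )
    """
    for chunk in chunks:
        if chunk.delta.message and chunk.delta.message.content:
            print(chunk.delta.message.content)  # raw text alongside the parse
        if chunk.structured_output is not None:
            # parsed object conforming to payload.structured_output_schema
            print(chunk.structured_output)
```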