feat(plugin): Add API endpoint for invoking LLM with structured output (#21624)

Yeuoly
2025-06-27 15:57:44 +08:00
committed by GitHub
parent 7236c4895b
commit 87efe45240
2 changed files with 87 additions and 0 deletions


@@ -17,6 +17,7 @@ from core.plugin.entities.request import (
     RequestInvokeApp,
     RequestInvokeEncrypt,
     RequestInvokeLLM,
+    RequestInvokeLLMWithStructuredOutput,
     RequestInvokeModeration,
     RequestInvokeParameterExtractorNode,
     RequestInvokeQuestionClassifierNode,
@@ -47,6 +48,21 @@ class PluginInvokeLLMApi(Resource):
         return length_prefixed_response(0xF, generator())


+class PluginInvokeLLMWithStructuredOutputApi(Resource):
+    @setup_required
+    @plugin_inner_api_only
+    @get_user_tenant
+    @plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput)
+    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLMWithStructuredOutput):
+        def generator():
+            response = PluginModelBackwardsInvocation.invoke_llm_with_structured_output(
+                user_model.id, tenant_model, payload
+            )
+            return PluginModelBackwardsInvocation.convert_to_event_stream(response)
+
+        return length_prefixed_response(0xF, generator())
+
+
 class PluginInvokeTextEmbeddingApi(Resource):
     @setup_required
     @plugin_inner_api_only
@@ -291,6 +307,7 @@ class PluginFetchAppInfoApi(Resource):
 api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
+api.add_resource(PluginInvokeLLMWithStructuredOutputApi, "/invoke/llm/structured-output")
 api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
 api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
 api.add_resource(PluginInvokeTTSApi, "/invoke/tts")
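For orientation, here is a minimal sketch of what a call to the new route might look like. The payload field names mirror the attributes the backwards-invocation handler reads from RequestInvokeLLMWithStructuredOutput in the second file (provider, model_type, model, prompt_messages, structured_output_schema, completion_params, stream); the base URL, authentication header, and concrete values are placeholders, since they are not part of this diff.

import requests

# Hypothetical host and credential -- the inner API's address and auth
# scheme are outside this diff.
INNER_API = "http://localhost:5001/inner/api"
HEADERS = {"X-Inner-Api-Key": "<key>"}

payload = {
    # Field names inferred from the payload.* accesses in the handler below.
    "provider": "openai",
    "model_type": "llm",
    "model": "gpt-4o",
    "prompt_messages": [{"role": "user", "content": "Extract the order details."}],
    "structured_output_schema": {
        "type": "object",
        "properties": {"order_id": {"type": "string"}, "total": {"type": "number"}},
        "required": ["order_id", "total"],
    },
    "completion_params": {"temperature": 0.1},
    "stream": True,
}

# The endpoint answers with a length-prefixed event stream (see
# length_prefixed_response(0xF, ...) above), so read the body incrementally.
with requests.post(f"{INNER_API}/invoke/llm/structured-output",
                   json=payload, headers=HEADERS, stream=True) as resp:
    for frame in resp.iter_content(chunk_size=None):
        ...  # decode each length-prefixed frame into an event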


@@ -2,11 +2,14 @@ import tempfile
 from binascii import hexlify, unhexlify
 from collections.abc import Generator

+from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
 from core.model_manager import ModelManager
 from core.model_runtime.entities.llm_entities import (
     LLMResult,
     LLMResultChunk,
     LLMResultChunkDelta,
+    LLMResultChunkWithStructuredOutput,
+    LLMResultWithStructuredOutput,
 )
 from core.model_runtime.entities.message_entities import (
     PromptMessage,
@@ -16,6 +19,7 @@ from core.model_runtime.entities.message_entities import (
 from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
 from core.plugin.entities.request import (
     RequestInvokeLLM,
+    RequestInvokeLLMWithStructuredOutput,
     RequestInvokeModeration,
     RequestInvokeRerank,
     RequestInvokeSpeech2Text,
@@ -85,6 +89,72 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
             return handle_non_streaming(response)

+    @classmethod
+    def invoke_llm_with_structured_output(
+        cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLMWithStructuredOutput
+    ):
+        """
+        invoke llm with structured output
+        """
+        model_instance = ModelManager().get_model_instance(
+            tenant_id=tenant.id,
+            provider=payload.provider,
+            model_type=payload.model_type,
+            model=payload.model,
+        )
+
+        model_schema = model_instance.model_type_instance.get_model_schema(payload.model, model_instance.credentials)
+        if not model_schema:
+            raise ValueError(f"Model schema not found for {payload.model}")
+
+        response = invoke_llm_with_structured_output(
+            provider=payload.provider,
+            model_schema=model_schema,
+            model_instance=model_instance,
+            prompt_messages=payload.prompt_messages,
+            json_schema=payload.structured_output_schema,
+            tools=payload.tools,
+            stop=payload.stop,
+            stream=True if payload.stream is None else payload.stream,
+            user=user_id,
+            model_parameters=payload.completion_params,
+        )
+
+        if isinstance(response, Generator):
+
+            def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
+                for chunk in response:
+                    if chunk.delta.usage:
+                        llm_utils.deduct_llm_quota(
+                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
+                        )
+                    chunk.prompt_messages = []
+                    yield chunk
+
+            return handle()
+        else:
+            if response.usage:
+                llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
+
+            def handle_non_streaming(
+                response: LLMResultWithStructuredOutput,
+            ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
+                yield LLMResultChunkWithStructuredOutput(
+                    model=response.model,
+                    prompt_messages=[],
+                    system_fingerprint=response.system_fingerprint,
+                    structured_output=response.structured_output,
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=response.message,
+                        usage=response.usage,
+                        finish_reason="",
+                    ),
+                )
+
+            return handle_non_streaming(response)
+
+
     @classmethod
     def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
         """