fix: better memory usage from 800+ to 500+ (#11796)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>
This commit is contained in:
yihong
2024-12-20 14:51:43 +08:00
committed by GitHub
parent 52201d95b1
commit 7b03a0316d
5 changed files with 56 additions and 26 deletions

View File

@@ -4,11 +4,10 @@ import json
import logging
import time
from collections.abc import Generator
from typing import Optional, Union, cast
from typing import TYPE_CHECKING, Optional, Union, cast
import google.auth.transport.requests
import requests
import vertexai.generative_models as glm
from anthropic import AnthropicVertex, Stream
from anthropic.types import (
ContentBlockDeltaEvent,
@@ -19,8 +18,6 @@ from anthropic.types import (
MessageStreamEvent,
)
from google.api_core import exceptions
from google.cloud import aiplatform
from google.oauth2 import service_account
from PIL import Image
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -47,6 +44,9 @@ from core.model_runtime.errors.invoke import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
if TYPE_CHECKING:
import vertexai.generative_models as glm
logger = logging.getLogger(__name__)
@@ -102,6 +102,8 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
:param stream: is stream response
:return: full response or stream response chunk generator result
"""
from google.oauth2 import service_account
# use Anthropic official SDK references
# - https://github.com/anthropics/anthropic-sdk-python
service_account_key = credentials.get("vertex_service_account_key", "")
@@ -406,13 +408,15 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
return text.rstrip()
def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> "glm.Tool":
"""
Convert tool messages to glm tools
:param tools: tool messages
:return: glm tools
"""
import vertexai.generative_models as glm
return glm.Tool(
function_declarations=[
glm.FunctionDeclaration(
@@ -473,6 +477,10 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
:param user: unique user id
:return: full response or stream response chunk generator result
"""
import vertexai.generative_models as glm
from google.cloud import aiplatform
from google.oauth2 import service_account
config_kwargs = model_parameters.copy()
config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
@@ -522,7 +530,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(
self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
Handle llm response
@@ -554,7 +562,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
return result
def _handle_generate_stream_response(
self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage]
) -> Generator:
"""
Handle llm stream response
@@ -638,13 +646,15 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
return message_text
def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content:
def _format_message_to_glm_content(self, message: PromptMessage) -> "glm.Content":
"""
Format a single message into glm.Content for Google API
:param message: one PromptMessage
:return: glm Content representation of message
"""
import vertexai.generative_models as glm
if isinstance(message, UserPromptMessage):
glm_content = glm.Content(role="user", parts=[])

View File

@@ -2,12 +2,9 @@ import base64
import json
import time
from decimal import Decimal
from typing import Optional
from typing import TYPE_CHECKING, Optional
import tiktoken
from google.cloud import aiplatform
from google.oauth2 import service_account
from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
@@ -24,6 +21,11 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.vertex_ai._common import _CommonVertexAi
if TYPE_CHECKING:
from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
else:
VertexTextEmbeddingModel = None
class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
"""
@@ -48,6 +50,10 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
:param input_type: input type
:return: embeddings result
"""
from google.cloud import aiplatform
from google.oauth2 import service_account
from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
service_account_key = credentials.get("vertex_service_account_key", "")
project_id = credentials["vertex_project_id"]
location = credentials["vertex_location"]
@@ -100,6 +106,10 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
:param credentials: model credentials
:return:
"""
from google.cloud import aiplatform
from google.oauth2 import service_account
from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
try:
service_account_key = credentials.get("vertex_service_account_key", "")
project_id = credentials["vertex_project_id"]