fix: Update prompt message content types to use Literal and add union type for content (#17136)
Co-authored-by: 朱庆超 <zhuqingchao@xiaomi.com> Co-authored-by: crazywoola <427733928@qq.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Annotated, Any, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_serializer, field_validator
|
||||
|
||||
@@ -61,11 +61,7 @@ class PromptMessageContentType(StrEnum):
|
||||
|
||||
|
||||
class PromptMessageContent(BaseModel):
|
||||
"""
|
||||
Model class for prompt message content.
|
||||
"""
|
||||
|
||||
type: PromptMessageContentType
|
||||
pass
|
||||
|
||||
|
||||
class TextPromptMessageContent(PromptMessageContent):
|
||||
@@ -73,7 +69,7 @@ class TextPromptMessageContent(PromptMessageContent):
|
||||
Model class for text prompt message content.
|
||||
"""
|
||||
|
||||
type: PromptMessageContentType = PromptMessageContentType.TEXT
|
||||
type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT
|
||||
data: str
|
||||
|
||||
|
||||
@@ -82,7 +78,6 @@ class MultiModalPromptMessageContent(PromptMessageContent):
|
||||
Model class for multi-modal prompt message content.
|
||||
"""
|
||||
|
||||
type: PromptMessageContentType
|
||||
format: str = Field(default=..., description="the format of multi-modal file")
|
||||
base64_data: str = Field(default="", description="the base64 data of multi-modal file")
|
||||
url: str = Field(default="", description="the url of multi-modal file")
|
||||
@@ -94,11 +89,11 @@ class MultiModalPromptMessageContent(PromptMessageContent):
|
||||
|
||||
|
||||
class VideoPromptMessageContent(MultiModalPromptMessageContent):
|
||||
type: PromptMessageContentType = PromptMessageContentType.VIDEO
|
||||
type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO
|
||||
|
||||
|
||||
class AudioPromptMessageContent(MultiModalPromptMessageContent):
|
||||
type: PromptMessageContentType = PromptMessageContentType.AUDIO
|
||||
type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO
|
||||
|
||||
|
||||
class ImagePromptMessageContent(MultiModalPromptMessageContent):
|
||||
@@ -110,12 +105,24 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
|
||||
LOW = "low"
|
||||
HIGH = "high"
|
||||
|
||||
type: PromptMessageContentType = PromptMessageContentType.IMAGE
|
||||
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
|
||||
detail: DETAIL = DETAIL.LOW
|
||||
|
||||
|
||||
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
|
||||
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
|
||||
type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT
|
||||
|
||||
|
||||
PromptMessageContentUnionTypes = Annotated[
|
||||
Union[
|
||||
TextPromptMessageContent,
|
||||
ImagePromptMessageContent,
|
||||
DocumentPromptMessageContent,
|
||||
AudioPromptMessageContent,
|
||||
VideoPromptMessageContent,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class PromptMessage(BaseModel):
|
||||
@@ -124,7 +131,7 @@ class PromptMessage(BaseModel):
|
||||
"""
|
||||
|
||||
role: PromptMessageRole
|
||||
content: Optional[str | Sequence[PromptMessageContent]] = None
|
||||
content: Optional[str | list[PromptMessageContentUnionTypes]] = None
|
||||
name: Optional[str] = None
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
|
Reference in New Issue
Block a user