Add LLaVA OneVision model support #7693

Merged (13 commits, Mar 18, 2025)
2 changes: 2 additions & 0 deletions invokeai/app/invocations/fields.py
@@ -59,6 +59,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
ControlLoRAModel = "ControlLoRAModelField"
SigLipModel = "SigLipModelField"
FluxReduxModel = "FluxReduxModelField"
LlavaOnevisionModel = "LLaVAModelField"
# endregion

# region Misc Field Types
@@ -205,6 +206,7 @@ class FieldDescriptions:
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
instantx_control_mode = "The control mode for InstantX ControlNet union models. Ignored for other ControlNet models. The standard mapping is: canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6). Negative values will be treated as 'None'."
flux_redux_conditioning = "FLUX Redux conditioning tensor"
vllm_model = "The VLLM model to use"


class ImageField(BaseModel):
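
As a quick illustrative check (not part of the PR), the new enum member's string value is the field-type name that the web UI's LLaVA model input component is keyed to, and the new FieldDescriptions entry is used as the model field's description text in the node editor:

from invokeai.app.invocations.fields import FieldDescriptions, UIType

assert UIType.LlavaOnevisionModel.value == "LLaVAModelField"
assert FieldDescriptions.vllm_model == "The VLLM model to use"
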
60 changes: 60 additions & 0 deletions invokeai/app/invocations/llava_onevision_vllm.py
@@ -0,0 +1,60 @@
from typing import Any

import torch
from PIL.Image import Image
from pydantic import field_validator

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, UIComponent, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import StringOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
from invokeai.backend.util.devices import TorchDevice


@invocation("llava_onevision_vllm", title="LLaVA OneVision VLLM", tags=["vllm"], category="vllm", version="1.0.0")
class LlavaOnevisionVllmInvocation(BaseInvocation):
"""Run a LLaVA OneVision VLLM model."""

images: list[ImageField] | ImageField | None = InputField(default=None, max_length=3, description="Input image.")
prompt: str = InputField(
default="",
description="Input text prompt.",
ui_component=UIComponent.Textarea,
)
vllm_model: ModelIdentifierField = InputField(
title="LLaVA Model Type",
description=FieldDescriptions.vllm_model,
ui_type=UIType.LlavaOnevisionModel,
)

@field_validator("images", mode="before")
def listify_images(cls, v: Any) -> list:
if v is None:
return v
if not isinstance(v, list):
return [v]
return v

def _get_images(self, context: InvocationContext) -> list[Image]:
if self.images is None:
return []

image_fields = self.images if isinstance(self.images, list) else [self.images]
return [context.images.get_pil(image_field.image_name, "RGB") for image_field in image_fields]

@torch.no_grad()
def invoke(self, context: InvocationContext) -> StringOutput:
images = self._get_images(context)

with context.models.load(self.vllm_model) as vllm_model:
assert isinstance(vllm_model, LlavaOnevisionModel)
output = vllm_model.run(
prompt=self.prompt,
images=images,
device=TorchDevice.choose_torch_device(),
dtype=TorchDevice.choose_torch_dtype(),
)

return StringOutput(value=output)
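
The images field above accepts either a single ImageField or a list because of the mode="before" validator. A minimal standalone sketch of that pattern follows (hypothetical model and field names, not InvokeAI code); note that _get_images still defends against both shapes, so the node behaves even if the validator is bypassed.

from typing import Any

from pydantic import BaseModel, field_validator


class Caption(BaseModel):
    # Accepts a bare string or a list of strings; normalized to a list before validation.
    tags: list[str] | str | None = None

    @field_validator("tags", mode="before")
    @classmethod
    def listify_tags(cls, v: Any) -> Any:
        if v is None:
            return v
        if not isinstance(v, list):
            return [v]
        return v


print(Caption(tags="cat").tags)           # ['cat']
print(Caption(tags=["cat", "dog"]).tags)  # ['cat', 'dog']
print(Caption(tags=None).tags)            # None
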
49 changes: 49 additions & 0 deletions invokeai/backend/llava_onevision_model.py
@@ -0,0 +1,49 @@
from pathlib import Path
from typing import Optional

import torch
from PIL.Image import Image
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor

from invokeai.backend.raw_model import RawModel


class LlavaOnevisionModel(RawModel):
def __init__(self, vllm_model: LlavaOnevisionForConditionalGeneration, processor: LlavaOnevisionProcessor):
self._vllm_model = vllm_model
self._processor = processor

@classmethod
def load_from_path(cls, path: str | Path):
vllm_model = LlavaOnevisionForConditionalGeneration.from_pretrained(path, local_files_only=True)
assert isinstance(vllm_model, LlavaOnevisionForConditionalGeneration)
processor = AutoProcessor.from_pretrained(path, local_files_only=True)
assert isinstance(processor, LlavaOnevisionProcessor)
return cls(vllm_model, processor)

def run(self, prompt: str, images: list[Image], device: torch.device, dtype: torch.dtype) -> str:
# TODO(ryand): Tune the max number of images that are useful for the model.
if len(images) > 3:
raise ValueError(
f"{len(images)} images were provided as input to the LLaVA OneVision model. "
"Pass <=3 images for good performance."
)

# Define a chat history and use `apply_chat_template` to get correctly formatted prompt.
# "content" is a list of dicts with types "text" or "image".
content = [{"type": "text", "text": prompt}]
# Add the correct number of images.
for _ in images:
content.append({"type": "image"})

conversation = [{"role": "user", "content": content}]
prompt = self._processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = self._processor(images=images or None, text=prompt, return_tensors="pt").to(device=device, dtype=dtype)
output = self._vllm_model.generate(**inputs, max_new_tokens=400, do_sample=False)
output_str: str = self._processor.decode(output[0][2:], skip_special_tokens=True)
# The output_str will include the prompt, so we extract the response.
response = output_str.split("assistant\n", 1)[1].strip()
return response

def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
self._vllm_model.to(device=device, dtype=dtype)
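
For reference, a hedged sketch of using this wrapper directly, outside the invocation graph. It assumes a complete local Hugging Face snapshot at a placeholder path and a CUDA device; substitute your own path, device, and dtype. In the app itself, the loader added in this PR casts the weights to the working dtype and the model manager is expected to handle device placement, so nodes only call run().

import torch
from PIL import Image

from invokeai.backend.llava_onevision_model import LlavaOnevisionModel

# Placeholder path: must be a full local snapshot, since load_from_path() calls
# from_pretrained(..., local_files_only=True) for both the model and the processor.
model = LlavaOnevisionModel.load_from_path("/models/llava-onevision-qwen2-0.5b-ov-hf")
model.to(device=torch.device("cuda"), dtype=torch.bfloat16)

image = Image.open("photo.png").convert("RGB")  # placeholder image
answer = model.run(
    prompt="Describe this image in one sentence.",
    images=[image],
    device=torch.device("cuda"),
    dtype=torch.bfloat16,
)
print(answer)
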
13 changes: 13 additions & 0 deletions invokeai/backend/model_manager/config.py
@@ -78,6 +78,7 @@ class ModelType(str, Enum):
SpandrelImageToImage = "spandrel_image_to_image"
SigLIP = "siglip"
FluxRedux = "flux_redux"
LlavaOnevision = "llava_onevision"


class SubModelType(str, Enum):
@@ -552,6 +553,17 @@ def get_tag() -> Tag:
return Tag(f"{ModelType.FluxRedux.value}.{ModelFormat.Checkpoint.value}")


class LlavaOnevisionConfig(DiffusersConfigBase):
"""Model config for Llava Onevision models."""

type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.LlavaOnevision.value}.{ModelFormat.Diffusers.value}")


def get_model_discriminator_value(v: Any) -> str:
"""
Computes the discriminator value for a model config.
@@ -601,6 +613,7 @@ def get_model_discriminator_value(v: Any) -> str:
Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()],
Annotated[SigLIPConfig, SigLIPConfig.get_tag()],
Annotated[FluxReduxConfig, FluxReduxConfig.get_tag()],
Annotated[LlavaOnevisionConfig, LlavaOnevisionConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]
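
The discriminated union resolves stored config records by a dotted "type.format" tag, so the new class is selected for records with type "llava_onevision" and format "diffusers". A small illustrative check (simplified; not the PR's code):

from invokeai.backend.model_manager.config import ModelFormat, ModelType

# LlavaOnevisionConfig.get_tag() wraps exactly this dotted string in a pydantic Tag,
# which is what the discriminator matches when parsing stored config records.
assert f"{ModelType.LlavaOnevision.value}.{ModelFormat.Diffusers.value}" == "llava_onevision.diffusers"
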
@@ -0,0 +1,32 @@
from pathlib import Path
from typing import Optional

from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LlavaOnevision, format=ModelFormat.Diffusers)
class LlavaOnevisionModelLoader(ModelLoader):
"""Class for loading LLaVA Onevision VLLM models."""

def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if submodel_type is not None:
raise ValueError("Unexpected submodel requested for LLaVA OneVision model.")

model_path = Path(config.path)
model = LlavaOnevisionModel.load_from_path(model_path)
model.to(dtype=self._torch_dtype)
return model
13 changes: 13 additions & 0 deletions invokeai/backend/model_manager/probe.py
@@ -141,6 +141,7 @@ class ModelProbe(object):
"SD3Transformer2DModel": ModelType.Main,
"CLIPTextModelWithProjection": ModelType.CLIPEmbed,
"SiglipModel": ModelType.SigLIP,
"LlavaOnevisionForConditionalGeneration": ModelType.LlavaOnevision,
}

TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type}
@@ -767,6 +768,11 @@ def get_base_type(self) -> BaseModelType:
return BaseModelType.Flux


class LlavaOnevisionCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()


########################################################
# classes for probing folders
#######################################################
@@ -1047,6 +1053,11 @@ def get_base_type(self) -> BaseModelType:
raise NotImplementedError()


class LlaveOnevisionFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
return BaseModelType.Any


class T2IAdapterFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.model_path / "config.json"
@@ -1082,6 +1093,7 @@ def get_base_type(self) -> BaseModelType:
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.SigLIP, SigLIPFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.FluxRedux, FluxReduxFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LlavaOnevision, LlaveOnevisionFolderProbe)

ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
@@ -1095,5 +1107,6 @@ def get_base_type(self) -> BaseModelType:
ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.SigLIP, SigLIPCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.FluxRedux, FluxReduxCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.LlavaOnevision, LlavaOnevisionCheckpointProbe)

ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
11 changes: 11 additions & 0 deletions invokeai/backend/model_manager/starter_models.py
@@ -614,6 +614,16 @@ class StarterModelBundles(BaseModel):
)
# endregion

# region LlavaOnevisionModel
llava_onevision = StarterModel(
name="LLaVA Onevision Qwen2 0.5B",
base=BaseModelType.Any,
source="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
description="LLaVA Onevision VLLM model",
type=ModelType.LlavaOnevision,
)
# endregion

# List of starter models, displayed on the frontend.
# The order/sort of this list is not changed by the frontend - set it how you want it here.
STARTER_MODELS: list[StarterModel] = [
@@ -683,6 +693,7 @@ class StarterModelBundles(BaseModel):
clip_l_encoder,
siglip,
flux_redux,
llava_onevision,
]

sd1_bundle: list[StarterModel] = [
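
The starter entry above points at the llava-hf/llava-onevision-qwen2-0.5b-ov-hf repository. InvokeAI's model installer is the normal way to fetch it, but as a rough sketch the same snapshot can be pulled manually with huggingface_hub and then passed to LlavaOnevisionModel.load_from_path():

from huggingface_hub import snapshot_download

# Downloads (or reuses) a local snapshot and returns its path; that path is then
# suitable for LlavaOnevisionModel.load_from_path(). Illustrative only.
local_dir = snapshot_download("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
print(local_dir)
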
1 change: 1 addition & 0 deletions invokeai/frontend/web/public/locales/en.json
@@ -846,6 +846,7 @@
"starterModels": "Starter Models",
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
"controlLora": "Control LoRA",
"llavaOnevision": "LLaVA OneVision",
"syncModels": "Sync Models",
"textualInversions": "Textual Inversions",
"triggerPhrases": "Trigger Phrases",
@@ -16,6 +16,7 @@ import {
useEmbeddingModels,
useFluxReduxModels,
useIPAdapterModels,
useLLaVAModels,
useLoRAModels,
useMainModels,
useRefinerModels,
@@ -126,6 +127,12 @@ const ModelList = () => {
[fluxReduxModels, searchTerm, filteredModelType]
);

const [llavaOneVisionModels, { isLoading: isLoadingLlavaOneVisionModels }] = useLLaVAModels();
const filteredLlavaOneVisionModels = useMemo(
() => modelsFilter(llavaOneVisionModels, searchTerm, filteredModelType),
[llavaOneVisionModels, searchTerm, filteredModelType]
);

const totalFilteredModels = useMemo(() => {
return (
filteredMainModels.length +
@@ -236,6 +243,17 @@ const ModelList = () => {
{!isLoadingClipEmbedModels && filteredClipEmbedModels.length > 0 && (
<ModelListWrapper title={t('modelManager.clipEmbed')} modelList={filteredClipEmbedModels} key="clip-embed" />
)}

{/* LLaVA OneVision List */}
{isLoadingLlavaOneVisionModels && <FetchingModelsLoader loadingMessage="Loading LLaVA OneVision Models..." />}
{!isLoadingLlavaOneVisionModels && filteredLlavaOneVisionModels.length > 0 && (
<ModelListWrapper
title={t('modelManager.llavaOnevision')}
modelList={filteredLlavaOneVisionModels}
key="llava-onevision"
/>
)}

{/* Spandrel Image to Image List */}
{isLoadingSpandrelImageToImageModels && (
<FetchingModelsLoader loadingMessage="Loading Image-to-Image Models..." />
@@ -27,6 +27,7 @@ export const ModelTypeFilter = memo(() => {
control_lora: t('modelManager.controlLora'),
siglip: t('modelManager.siglip'),
flux_redux: t('modelManager.fluxRedux'),
llava_onevision: t('modelManager.llavaOnevision'),
}),
[t]
);
@@ -62,6 +62,8 @@ import {
isIntegerGeneratorFieldInputTemplate,
isIPAdapterModelFieldInputInstance,
isIPAdapterModelFieldInputTemplate,
isLLaVAModelFieldInputInstance,
isLLaVAModelFieldInputTemplate,
isLoRAModelFieldInputInstance,
isLoRAModelFieldInputTemplate,
isMainModelFieldInputInstance,
@@ -112,6 +114,7 @@ import FluxReduxModelFieldInputComponent from './inputs/FluxReduxModelFieldInput
import FluxVAEModelFieldInputComponent from './inputs/FluxVAEModelFieldInputComponent';
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
import LLaVAModelFieldInputComponent from './inputs/LLaVAModelFieldInputComponent';
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent';
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
@@ -322,6 +325,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <ControlLoRAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}

if (isLLaVAModelFieldInputTemplate(template)) {
if (!isLLaVAModelFieldInputInstance(field)) {
return null;
}
return <LLaVAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}

if (isFluxVAEModelFieldInputTemplate(template)) {
if (!isFluxVAEModelFieldInputInstance(field)) {
return null;