From 24151d855d66b7b494fbeb80c9e2f65fa5f1ab31 Mon Sep 17 00:00:00 2001 From: haotongl Date: Mon, 23 Dec 2024 20:54:41 +0800 Subject: [PATCH 01/58] add prompt depth anything model by modular transformer --- .../models/prompt_depth_anything/__init__.py | 52 ++ .../configuration_prompt_depth_anything.py | 156 ++++++ .../convert_prompt_depth_anything_to_hf.py | 322 +++++++++++ .../modeling_prompt_depth_anything.py | 523 ++++++++++++++++++ .../modular_prompt_depth_anything.py | 299 ++++++++++ 5 files changed, 1352 insertions(+) create mode 100644 src/transformers/models/prompt_depth_anything/__init__.py create mode 100644 src/transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py create mode 100644 src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py create mode 100644 src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py create mode 100644 src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py diff --git a/src/transformers/models/prompt_depth_anything/__init__.py b/src/transformers/models/prompt_depth_anything/__init__.py new file mode 100644 index 000000000000..5674d28175c5 --- /dev/null +++ b/src/transformers/models/prompt_depth_anything/__init__.py @@ -0,0 +1,52 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...file_utils import _LazyModule, is_torch_available +from ...utils import OptionalDependencyNotAvailable + + +_import_structure = {"configuration_prompt_depth_anything": ["PromptDepthAnythingConfig"]} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_prompt_depth_anything"] = [ + "PromptDepthAnythingForDepthEstimation", + "PromptDepthAnythingPreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_prompt_depth_anything import PromptDepthAnythingConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_prompt_depth_anything import ( + PromptDepthAnythingForDepthEstimation, + PromptDepthAnythingPreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py new file mode 100644 index 000000000000..15e25015d2a3 --- /dev/null +++ b/src/transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py @@ -0,0 +1,156 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_prompt_depth_anything.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 + +import copy + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import verify_backbone_config_arguments +from ..auto.configuration_auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + + +class PromptDepthAnythingConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PromptDepthAnythingModel`]. It is used to instantiate a PromptDepthAnything + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the PromptDepthAnything + [LiheYoung/depth-anything-small-hf](https://huggingface.co/LiheYoung/depth-anything-small-hf) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*): + The configuration of the backbone model. Only used in case `is_hybrid` is `True` or in case you want to + leverage the [`AutoBackbone`] API. + backbone (`str`, *optional*): + Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this + will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone` + is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. + use_pretrained_backbone (`bool`, *optional*, defaults to `False`): + Whether to use pretrained weights for the backbone. + use_timm_backbone (`bool`, *optional*, defaults to `False`): + Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`] + API. + backbone_kwargs (`dict`, *optional*): + Keyword arguments to be passed to AutoBackbone when loading from a checkpoint + e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set. + patch_size (`int`, *optional*, defaults to 14): + The size of the patches to extract from the backbone features. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + reassemble_hidden_size (`int`, *optional*, defaults to 384): + The number of input channels of the reassemble layers. + reassemble_factors (`List[int]`, *optional*, defaults to `[4, 2, 1, 0.5]`): + The up/downsampling factors of the reassemble layers. + neck_hidden_sizes (`List[str]`, *optional*, defaults to `[48, 96, 192, 384]`): + The hidden sizes to project to for the feature maps of the backbone. + fusion_hidden_size (`int`, *optional*, defaults to 64): + The number of channels before fusion. + head_in_index (`int`, *optional*, defaults to -1): + The index of the features to use in the depth estimation head. + head_hidden_size (`int`, *optional*, defaults to 32): + The number of output channels in the second convolution of the depth estimation head. + depth_estimation_type (`str`, *optional*, defaults to `"relative"`): + The type of depth estimation to use. Can be one of `["relative", "metric"]`. + max_depth (`float`, *optional*): + The maximum depth to use for the "metric" depth estimation head. 20 should be used for indoor models + and 80 for outdoor models. For "relative" depth estimation, this value is ignored. + + Example: + + ```python + >>> from transformers import PromptDepthAnythingConfig, PromptDepthAnythingForDepthEstimation + + >>> # Initializing a PromptDepthAnything small style configuration + >>> configuration = PromptDepthAnythingConfig() + + >>> # Initializing a model from the PromptDepthAnything small style configuration + >>> model = PromptDepthAnythingForDepthEstimation(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "prompt_depth_anything" + + def __init__( + self, + backbone_config=None, + backbone=None, + use_pretrained_backbone=False, + use_timm_backbone=False, + backbone_kwargs=None, + patch_size=14, + initializer_range=0.02, + reassemble_hidden_size=384, + reassemble_factors=[4, 2, 1, 0.5], + neck_hidden_sizes=[48, 96, 192, 384], + fusion_hidden_size=64, + head_in_index=-1, + head_hidden_size=32, + depth_estimation_type="relative", + max_depth=None, + **kwargs, + ): + super().__init__(**kwargs) + if backbone_config is None and backbone is None: + logger.info("`backbone_config` is `None`. Initializing the config with the default `Dinov2` backbone.") + backbone_config = CONFIG_MAPPING["dinov2"]( + image_size=518, + hidden_size=384, + num_attention_heads=6, + out_indices=[9, 10, 11, 12], + apply_layernorm=True, + reshape_hidden_states=False, + ) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.get("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + self.backbone_config = backbone_config + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone + self.backbone_kwargs = backbone_kwargs + self.reassemble_hidden_size = reassemble_hidden_size + self.patch_size = patch_size + self.initializer_range = initializer_range + self.reassemble_factors = reassemble_factors + self.neck_hidden_sizes = neck_hidden_sizes + self.fusion_hidden_size = fusion_hidden_size + self.head_in_index = head_in_index + self.head_hidden_size = head_hidden_size + if depth_estimation_type not in ["relative", "metric"]: + raise ValueError("depth_estimation_type must be one of ['relative', 'metric']") + self.depth_estimation_type = depth_estimation_type + self.max_depth = max_depth if max_depth else 1 + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + + if output["backbone_config"] is not None: + output["backbone_config"] = self.backbone_config.to_dict() + + output["model_type"] = self.__class__.model_type + return output diff --git a/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py b/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py new file mode 100644 index 000000000000..9d9f87e3a492 --- /dev/null +++ b/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py @@ -0,0 +1,322 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Depth Anything checkpoints from the original repository. URL: +https://github.com/LiheYoung/Depth-Anything""" + +import argparse +from pathlib import Path + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image +import numpy as np + +from transformers import PromptDepthAnythingConfig, PromptDepthAnythingForDepthEstimation, Dinov2Config, DPTImageProcessor +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def get_dpt_config(model_name): + if "small" in model_name or 'vits' in model_name: + out_indices = [3, 6, 9, 12] + backbone_config = Dinov2Config.from_pretrained( + "facebook/dinov2-small", out_indices=out_indices, apply_layernorm=True, reshape_hidden_states=False + ) + fusion_hidden_size = 64 + neck_hidden_sizes = [48, 96, 192, 384] + elif "base" in model_name or 'vitb' in model_name: + out_indices = [3, 6, 9, 12] + backbone_config = Dinov2Config.from_pretrained( + "facebook/dinov2-base", out_indices=out_indices, apply_layernorm=True, reshape_hidden_states=False + ) + fusion_hidden_size = 128 + neck_hidden_sizes = [96, 192, 384, 768] + elif "large" in model_name or 'vitl' in model_name: + out_indices = [5, 12, 18, 24] + backbone_config = Dinov2Config.from_pretrained( + "facebook/dinov2-large", out_indices=out_indices, apply_layernorm=True, reshape_hidden_states=False + ) + fusion_hidden_size = 256 + neck_hidden_sizes = [256, 512, 1024, 1024] + else: + raise NotImplementedError(f"Model not supported: {model_name}") + + depth_estimation_type = "metric" + max_depth = None + + config = PromptDepthAnythingConfig( + reassemble_hidden_size=backbone_config.hidden_size, + patch_size=backbone_config.patch_size, + backbone_config=backbone_config, + fusion_hidden_size=fusion_hidden_size, + neck_hidden_sizes=neck_hidden_sizes, + depth_estimation_type=depth_estimation_type, + max_depth=max_depth, + ) + + return config + + +def create_rename_keys(config): + rename_keys = [] + + # fmt: off + # stem + rename_keys.append(("pretrained.cls_token", "backbone.embeddings.cls_token")) + rename_keys.append(("pretrained.mask_token", "backbone.embeddings.mask_token")) + rename_keys.append(("pretrained.pos_embed", "backbone.embeddings.position_embeddings")) + rename_keys.append(("pretrained.patch_embed.proj.weight", "backbone.embeddings.patch_embeddings.projection.weight")) + rename_keys.append(("pretrained.patch_embed.proj.bias", "backbone.embeddings.patch_embeddings.projection.bias")) + + # Transfomer encoder + for i in range(config.backbone_config.num_hidden_layers): + rename_keys.append((f"pretrained.blocks.{i}.ls1.gamma", f"backbone.encoder.layer.{i}.layer_scale1.lambda1")) + rename_keys.append((f"pretrained.blocks.{i}.ls2.gamma", f"backbone.encoder.layer.{i}.layer_scale2.lambda1")) + rename_keys.append((f"pretrained.blocks.{i}.norm1.weight", f"backbone.encoder.layer.{i}.norm1.weight")) + rename_keys.append((f"pretrained.blocks.{i}.norm1.bias", f"backbone.encoder.layer.{i}.norm1.bias")) + rename_keys.append((f"pretrained.blocks.{i}.norm2.weight", f"backbone.encoder.layer.{i}.norm2.weight")) + rename_keys.append((f"pretrained.blocks.{i}.norm2.bias", f"backbone.encoder.layer.{i}.norm2.bias")) + rename_keys.append((f"pretrained.blocks.{i}.mlp.fc1.weight", f"backbone.encoder.layer.{i}.mlp.fc1.weight")) + rename_keys.append((f"pretrained.blocks.{i}.mlp.fc1.bias", f"backbone.encoder.layer.{i}.mlp.fc1.bias")) + rename_keys.append((f"pretrained.blocks.{i}.mlp.fc2.weight", f"backbone.encoder.layer.{i}.mlp.fc2.weight")) + rename_keys.append((f"pretrained.blocks.{i}.mlp.fc2.bias", f"backbone.encoder.layer.{i}.mlp.fc2.bias")) + rename_keys.append((f"pretrained.blocks.{i}.attn.proj.weight", f"backbone.encoder.layer.{i}.attention.output.dense.weight")) + rename_keys.append((f"pretrained.blocks.{i}.attn.proj.bias", f"backbone.encoder.layer.{i}.attention.output.dense.bias")) + + # Head + rename_keys.append(("pretrained.norm.weight", "backbone.layernorm.weight")) + rename_keys.append(("pretrained.norm.bias", "backbone.layernorm.bias")) + + # activation postprocessing (readout projections + resize blocks) + # Depth Anything does not use CLS token => readout_projects not required + + for i in range(4): + rename_keys.append((f"depth_head.projects.{i}.weight", f"neck.reassemble_stage.layers.{i}.projection.weight")) + rename_keys.append((f"depth_head.projects.{i}.bias", f"neck.reassemble_stage.layers.{i}.projection.bias")) + + if i != 2: + rename_keys.append((f"depth_head.resize_layers.{i}.weight", f"neck.reassemble_stage.layers.{i}.resize.weight")) + rename_keys.append((f"depth_head.resize_layers.{i}.bias", f"neck.reassemble_stage.layers.{i}.resize.bias")) + + # refinenet (tricky here) + mapping = {1:3, 2:2, 3:1, 4:0} + + for i in range(1, 5): + j = mapping[i] + rename_keys.append((f"depth_head.scratch.refinenet{i}.out_conv.weight", f"neck.fusion_stage.layers.{j}.projection.weight")) + rename_keys.append((f"depth_head.scratch.refinenet{i}.out_conv.bias", f"neck.fusion_stage.layers.{j}.projection.bias")) + rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit1.conv1.weight", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution1.weight")) + rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit1.conv1.bias", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution1.bias")) + rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit1.conv2.weight", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution2.weight")) + rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit1.conv2.bias", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution2.bias")) + rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit2.conv1.weight", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution1.weight")) + rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit2.conv1.bias", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution1.bias")) + rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit2.conv2.weight", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution2.weight")) + rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit2.conv2.bias", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution2.bias")) + rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit_depth.0.weight", f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution1.weight")) + rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit_depth.0.bias", f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution1.bias")) + rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit_depth.2.weight", f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution2.weight")) + rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit_depth.2.bias", f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution2.bias")) + rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit_depth.4.weight", f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution3.weight")) + rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit_depth.4.bias", f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution3.bias")) + + # scratch convolutions + for i in range(4): + rename_keys.append((f"depth_head.scratch.layer{i+1}_rn.weight", f"neck.convs.{i}.weight")) + + # head + rename_keys.append(("depth_head.scratch.output_conv1.weight", "head.conv1.weight")) + rename_keys.append(("depth_head.scratch.output_conv1.bias", "head.conv1.bias")) + rename_keys.append(("depth_head.scratch.output_conv2.0.weight", "head.conv2.weight")) + rename_keys.append(("depth_head.scratch.output_conv2.0.bias", "head.conv2.bias")) + rename_keys.append(("depth_head.scratch.output_conv2.2.weight", "head.conv3.weight")) + rename_keys.append(("depth_head.scratch.output_conv2.2.bias", "head.conv3.bias")) + + return rename_keys + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, config): + hidden_size = config.backbone_config.hidden_size + for i in range(config.backbone_config.num_hidden_layers): + # read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"pretrained.blocks.{i}.attn.qkv.weight") + in_proj_bias = state_dict.pop(f"pretrained.blocks.{i}.attn.qkv.bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"backbone.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[:hidden_size, :] + state_dict[f"backbone.encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[:hidden_size] + state_dict[f"backbone.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + hidden_size : hidden_size * 2, : + ] + state_dict[f"backbone.encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[ + hidden_size : hidden_size * 2 + ] + state_dict[f"backbone.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[-hidden_size:, :] + state_dict[f"backbone.encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-hidden_size:] + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +name_to_checkpoint = { + "depth-anything-small": "pytorch_model.bin", + "depth-anything-base": "pytorch_model.bin", + "depth-anything-large": "pytorch_model.bin", + "depth-anything-v2-small": "depth_anything_v2_vits.pth", + "depth-anything-v2-base": "depth_anything_v2_vitb.pth", + "depth-anything-v2-large": "depth_anything_v2_vitl.pth", + "depth-anything-v2-metric-indoor-small": "depth_anything_v2_metric_hypersim_vits.pth", + "depth-anything-v2-metric-indoor-base": "depth_anything_v2_metric_hypersim_vitb.pth", + "depth-anything-v2-metric-indoor-large": "depth_anything_v2_metric_hypersim_vitl.pth", + "depth-anything-v2-metric-outdoor-small": "depth_anything_v2_metric_vkitti_vits.pth", + "depth-anything-v2-metric-outdoor-base": "depth_anything_v2_metric_vkitti_vitb.pth", + "depth-anything-v2-metric-outdoor-large": "depth_anything_v2_metric_vkitti_vitl.pth", + # v2-giant pending + "promptda_vits": "model.ckpt" +} + + +@torch.no_grad() +def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, verify_logits): + """ + Copy/paste/tweak model's weights to our DPT structure. + """ + + # define DPT configuration + config = get_dpt_config(model_name) + + model_name_to_repo = { + "promptda_vits": "depth-anything/promptda_vits" + } + + # load original state_dict + repo_id = model_name_to_repo[model_name] + filename = name_to_checkpoint[model_name] + filepath = hf_hub_download( + repo_id=repo_id, + filename=f"{filename}", + ) + + state_dict = torch.load(filepath, map_location="cpu")['state_dict'] + state_dict = {key[9:]:state_dict[key] for key in state_dict} + # rename keys + rename_keys = create_rename_keys(config) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + # read in qkv matrices + read_in_q_k_v(state_dict, config) + + # load HuggingFace model + model = PromptDepthAnythingForDepthEstimation(config) + model.load_state_dict(state_dict, strict=False) + model.eval() + + processor = DPTImageProcessor( + do_resize=True, + size=756, + ensure_multiple_of=14, + keep_aspect_ratio=True, + do_rescale=True, + do_normalize=True, + image_mean=[0.485, 0.456, 0.406], + image_std=[0.229, 0.224, 0.225], + ) + url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true" + image = Image.open(requests.get(url, stream=True).raw) + + pixel_values = processor(image, return_tensors="pt").pixel_values + + + prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true" + prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw) + prompt_depth = torch.tensor((np.asarray(prompt_depth) / 1000.0).astype(np.float32)) + prompt_depth = prompt_depth.unsqueeze(0).unsqueeze(0) + + # Verify forward pass + with torch.no_grad(): + outputs = model(pixel_values, prompt_depth=prompt_depth) + predicted_depth = outputs.predicted_depth + + print("Shape of predicted depth:", predicted_depth.shape) + print("First values:", predicted_depth[0, 0, :3, :3]) + + # assert logits + if verify_logits: + expected_shape = torch.Size([1, 1, 756, 1008]) + if model_name == 'promptda_vits': + expected_slice = torch.tensor( + [[3.0100, 3.0016, 3.0219], [3.0046, 3.0137, 3.0275], [3.0083, 3.0191, 3.0292]] + ) + else: + raise ValueError("Not supported") + assert predicted_depth.shape == torch.Size(expected_shape) + assert torch.allclose(predicted_depth[0, 0, :3, :3], expected_slice, atol=1e-3) # 1mm tolerance + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model and processor to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + print("Pushing model and processor to hub...") + model.push_to_hub(repo_id=f"{model_name.title()}-hf") + processor.push_to_hub(repo_id=f"{model_name.title()}-hf") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="promptda_vits", + type=str, + choices=name_to_checkpoint.keys(), + help="Name of the model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether to push the model to the hub after conversion.", + ) + parser.add_argument( + "--verify_logits", + action="store_false", + required=False, + help="Whether to verify the logits after conversion.", + ) + + args = parser.parse_args() + convert_dpt_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.verify_logits) diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py new file mode 100644 index 000000000000..d19b42a28372 --- /dev/null +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -0,0 +1,523 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_prompt_depth_anything.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings +from ...modeling_outputs import DepthEstimatorOutput +from ...modeling_utils import PreTrainedModel +from ...utils.backbone_utils import load_backbone +from .configuration_prompt_depth_anything import PromptDepthAnythingConfig + + +# General docstring +_CONFIG_FOR_DOC = "PromptDepthAnythingConfig" + + +class PromptDepthAnythingResidualLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.convolution1 = nn.Conv2d( + 1, + config.fusion_hidden_size, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + self.activation1 = nn.ReLU(False) + + self.convolution2 = nn.Conv2d( + config.fusion_hidden_size, + config.fusion_hidden_size, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + self.activation2 = nn.ReLU(False) + + self.convolution3 = nn.Conv2d( + config.fusion_hidden_size, + config.fusion_hidden_size, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + + def forward(self, prompt_depth: torch.Tensor) -> torch.Tensor: + residual = prompt_depth + residual = self.convolution1(residual) + residual = self.activation1(residual) + residual = self.convolution2(residual) + residual = self.activation2(residual) + residual = self.convolution3(residual) + return residual + + +class PromptDepthAnythingPreActResidualLayer(nn.Module): + """ + ResidualConvUnit, pre-activate residual unit. + + Args: + config (`[PromptDepthAnythingConfig]`): + Model configuration class defining the model architecture. + """ + + def __init__(self, config): + super().__init__() + + self.activation1 = nn.ReLU() + self.convolution1 = nn.Conv2d( + config.fusion_hidden_size, + config.fusion_hidden_size, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + + self.activation2 = nn.ReLU() + self.convolution2 = nn.Conv2d( + config.fusion_hidden_size, + config.fusion_hidden_size, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + residual = hidden_state + hidden_state = self.activation1(hidden_state) + hidden_state = self.convolution1(hidden_state) + hidden_state = self.activation2(hidden_state) + hidden_state = self.convolution2(hidden_state) + + return hidden_state + residual + + +class PromptDepthAnythingFeatureFusionLayer(nn.Module): + """Feature fusion layer, merges feature maps from different stages. + + Args: + config (`[PromptDepthAnythingConfig]`): + Model configuration class defining the model architecture. + """ + + def __init__(self, config): + super().__init__() + + self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True) + + self.residual_layer1 = PromptDepthAnythingPreActResidualLayer(config) + self.residual_layer2 = PromptDepthAnythingPreActResidualLayer(config) + self.residual_layer_depth = PromptDepthAnythingResidualLayer(config) + + def forward(self, hidden_state, residual=None, size=None, prompt_depth=None): + if residual is not None: + if hidden_state.shape != residual.shape: + residual = nn.functional.interpolate( + residual, size=(hidden_state.shape[2], hidden_state.shape[3]), mode="bilinear", align_corners=False + ) + hidden_state = hidden_state + self.residual_layer1(residual) + + hidden_state = self.residual_layer2(hidden_state) + + if prompt_depth is not None: + prompt_depth = nn.functional.interpolate( + prompt_depth, hidden_state.shape[2:], mode="bilinear", align_corners=False + ) + res = self.residual_layer_depth(prompt_depth) + hidden_state = hidden_state + res + + modifier = {"scale_factor": 2} if size is None else {"size": size} + + hidden_state = nn.functional.interpolate( + hidden_state, + **modifier, + mode="bilinear", + align_corners=True, + ) + hidden_state = self.projection(hidden_state) + + return hidden_state + + +class PromptDepthAnythingFeatureFusionStage(nn.Module): + def __init__(self, config): + super().__init__() + self.layers = nn.ModuleList() + for _ in range(len(config.neck_hidden_sizes)): + self.layers.append(PromptDepthAnythingFeatureFusionLayer(config)) + + def forward(self, hidden_states, size=None, prompt_depth=None): + # reversing the hidden_states, we start from the last + hidden_states = hidden_states[::-1] + + fused_hidden_states = [] + fused_hidden_state = None + + for idx, (hidden_state, layer) in enumerate(zip(hidden_states, self.layers)): + size = hidden_states[idx + 1].shape[2:] if idx != (len(hidden_states) - 1) else None + + if fused_hidden_state is None: + # first layer only uses the last hidden_state + fused_hidden_state = layer(hidden_state, size=size, prompt_depth=prompt_depth) + else: + fused_hidden_state = layer(fused_hidden_state, hidden_state, size=size, prompt_depth=prompt_depth) + + fused_hidden_states.append(fused_hidden_state) + + return fused_hidden_states + + +class PromptDepthAnythingReassembleLayer(nn.Module): + def __init__(self, config, channels, factor): + super().__init__() + self.projection = nn.Conv2d(in_channels=config.reassemble_hidden_size, out_channels=channels, kernel_size=1) + + # up/down sampling depending on factor + if factor > 1: + self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0) + elif factor == 1: + self.resize = nn.Identity() + elif factor < 1: + # so should downsample + self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=int(1 / factor), padding=1) + + def forward(self, hidden_state): + hidden_state = self.projection(hidden_state) + hidden_state = self.resize(hidden_state) + + return hidden_state + + +class PromptDepthAnythingReassembleStage(nn.Module): + """ + This class reassembles the hidden states of the backbone into image-like feature representations at various + resolutions. + + This happens in 3 stages: + 1. Take the patch embeddings and reshape them to image-like feature representations. + 2. Project the channel dimension of the hidden states according to `config.neck_hidden_sizes`. + 3. Resizing the spatial dimensions (height, width). + + Args: + config (`[PromptDepthAnythingConfig]`): + Model configuration class defining the model architecture. + """ + + def __init__(self, config): + super().__init__() + + self.config = config + self.layers = nn.ModuleList() + for channels, factor in zip(config.neck_hidden_sizes, config.reassemble_factors): + self.layers.append(PromptDepthAnythingReassembleLayer(config, channels=channels, factor=factor)) + + def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]: + """ + Args: + hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`): + List of hidden states from the backbone. + """ + out = [] + + for i, hidden_state in enumerate(hidden_states): + # reshape to (batch_size, num_channels, height, width) + hidden_state = hidden_state[:, 1:] + batch_size, _, num_channels = hidden_state.shape + hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + hidden_state = self.layers[i](hidden_state) + out.append(hidden_state) + + return out + + +class PromptDepthAnythingNeck(nn.Module): + """ + PromptDepthAnythingNeck. A neck is a module that is normally used between the backbone and the head. It takes a list of tensors as + input and produces another list of tensors as output. For PromptDepthAnything, it includes 2 stages: + + * PromptDepthAnythingReassembleStage + * PromptDepthAnythingFeatureFusionStage. + + Args: + config (dict): config dict. + """ + + def __init__(self, config): + super().__init__() + self.config = config + + self.reassemble_stage = PromptDepthAnythingReassembleStage(config) + + self.convs = nn.ModuleList() + for channel in config.neck_hidden_sizes: + self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False)) + self.fusion_stage = PromptDepthAnythingFeatureFusionStage(config) + + def forward( + self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None, prompt_depth=None + ) -> List[torch.Tensor]: + """ + Args: + hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`): + List of hidden states from the backbone. + """ + if not isinstance(hidden_states, (tuple, list)): + raise TypeError("hidden_states should be a tuple or list of tensors") + + if len(hidden_states) != len(self.config.neck_hidden_sizes): + raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.") + + # postprocess hidden states + hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width) + + features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)] + + # fusion blocks + output = self.fusion_stage(features, prompt_depth=prompt_depth) + + return output + + +class PromptDepthAnythingPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = PromptDepthAnythingConfig + base_model_prefix = "prompt_depth_anything" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class PromptDepthAnythingDepthEstimationHead(nn.Module): + """ + Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples + the predictions to the input resolution after the first convolutional layer (details can be found in the DPT paper's + supplementary material). The final activation function is either ReLU or Sigmoid, depending on the depth estimation + type (relative or metric). For metric depth estimation, the output is scaled by the maximum depth used during pretraining. + """ + + def __init__(self, config): + super().__init__() + + self.head_in_index = config.head_in_index + self.patch_size = config.patch_size + + features = config.fusion_hidden_size + self.conv1 = nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1) + self.conv2 = nn.Conv2d(features // 2, config.head_hidden_size, kernel_size=3, stride=1, padding=1) + self.activation1 = nn.ReLU() + self.conv3 = nn.Conv2d(config.head_hidden_size, 1, kernel_size=1, stride=1, padding=0) + if config.depth_estimation_type == "relative": + self.activation2 = nn.ReLU() + elif config.depth_estimation_type == "metric": + self.activation2 = nn.Sigmoid() + else: + raise ValueError(f"Unknown depth estimation type: {config.depth_estimation_type}") + self.max_depth = config.max_depth + + def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) -> torch.Tensor: + hidden_states = hidden_states[self.head_in_index] + + predicted_depth = self.conv1(hidden_states) + predicted_depth = nn.functional.interpolate( + predicted_depth, + (int(patch_height * self.patch_size), int(patch_width * self.patch_size)), + mode="bilinear", + align_corners=True, + ) + predicted_depth = self.conv2(predicted_depth) + predicted_depth = self.activation1(predicted_depth) + predicted_depth = self.conv3(predicted_depth) + predicted_depth = self.activation2(predicted_depth) * self.max_depth + predicted_depth = predicted_depth.squeeze(dim=1) # shape (batch_size, height, width) + + return predicted_depth + + +PROMPT_DEPTH_ANYTHING_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`PromptDepthAnythingConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +PROMPT_DEPTH_ANYTHING_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`] + for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + """ + Prompt Depth Anything Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2. + """, + PROMPT_DEPTH_ANYTHING_START_DOCSTRING, +) +class PromptDepthAnythingForDepthEstimation(PromptDepthAnythingPreTrainedModel): + _no_split_modules = ["DPTViTEmbeddings"] + + def __init__(self, config): + super().__init__(config) + + self.backbone = load_backbone(config) + self.neck = PromptDepthAnythingNeck(config) + self.head = PromptDepthAnythingDepthEstimationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(PROMPT_DEPTH_ANYTHING_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=DepthEstimatorOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.FloatTensor, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + prompt_depth: Optional[torch.FloatTensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth depth estimation maps for computing the loss. + + Returns: + + Examples: + labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth depth estimation maps for computing the loss. + + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoModelForDepthEstimation + >>> import torch + >>> import numpy as np + >>> from PIL import Image + >>> import requests + + >>> url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("depth-anything/promptda_vits_hf") + >>> model = AutoModelForDepthEstimation.from_pretrained("depth-anything/promptda_vits_hf") + + >>> prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true" + >>> prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw) + >>> prompt_depth = torch.tensor((np.asarray(prompt_depth) / 1000.0).astype(np.float32)) + >>> prompt_depth = prompt_depth.unsqueeze(0).unsqueeze(0) + + >>> # prepare image for the model + >>> inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth) + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> # interpolate to original size + >>> post_processed_output = image_processor.post_process_depth_estimation( + ... outputs, + ... target_sizes=[(image.height, image.width)], + ... ) + + >>> # visualize the prediction + >>> predicted_depth = post_processed_output[0]["predicted_depth"] + >>> depth = predicted_depth * 1000 + >>> depth = depth.detach().cpu().numpy() + >>> depth = Image.fromarray(depth.astype("uint16")) # mm + ```""" + loss = None + if labels is not None: + raise NotImplementedError("Training is not implemented yet") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + outputs = self.backbone.forward_with_filtered_kwargs( + pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions + ) + hidden_states = outputs.feature_maps + + _, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + patch_height = height // patch_size + patch_width = width // patch_size + + # normalize prompt depth + B = len(prompt_depth) + depth_min, depth_max = ( + torch.min(prompt_depth.reshape(B, -1), dim=1).values, + torch.max(prompt_depth.reshape(B, -1), dim=1).values, + ) + invalid_mask = (depth_max - depth_min) <= 0 + if invalid_mask.any(): + depth_max[invalid_mask] = depth_min[invalid_mask] + 1e-6 + depth_min, depth_max = depth_min.view(B, 1, 1, 1), depth_max.view(B, 1, 1, 1) + prompt_depth = (prompt_depth - depth_min) / (depth_max - depth_min) + # normalize done + + hidden_states = self.neck(hidden_states, patch_height, patch_width, prompt_depth=prompt_depth) + + predicted_depth = self.head(hidden_states, patch_height, patch_width) + # denormalize predicted depth + predicted_depth = predicted_depth * (depth_max - depth_min) + depth_min + # denormalize done + + if not return_dict: + if output_hidden_states: + output = (predicted_depth,) + outputs[1:] + else: + output = (predicted_depth,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return DepthEstimatorOutput( + loss=loss, + predicted_depth=predicted_depth, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py new file mode 100644 index 000000000000..b2f75ea14492 --- /dev/null +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -0,0 +1,299 @@ +from typing import List, Optional, Tuple, Union +import torch +import torch.nn as nn +from transformers.models.depth_anything.configuration_depth_anything import DepthAnythingConfig +from transformers.models.depth_anything.modeling_depth_anything import DepthAnythingFeatureFusionLayer, DepthAnythingForDepthEstimation, DepthAnythingNeck +from ...file_utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...modeling_outputs import DepthEstimatorOutput + +_CONFIG_FOR_DOC = "PromptDepthAnythingConfig" + +PROMPT_DEPTH_ANYTHING_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`PromptDepthAnythingConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +PROMPT_DEPTH_ANYTHING_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`] + for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" + + +class PromptDepthAnythingConfig(DepthAnythingConfig): + model_type = "prompt_depth_anything" + + +class PromptDepthAnythingResidualLayer(nn.Module): + + def __init__(self, config): + super().__init__() + self.convolution1 = nn.Conv2d( + 1, + config.fusion_hidden_size, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + self.activation1 = nn.ReLU(False) + + self.convolution2 = nn.Conv2d( + config.fusion_hidden_size, + config.fusion_hidden_size, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + self.activation2 = nn.ReLU(False) + + self.convolution3 = nn.Conv2d( + config.fusion_hidden_size, + config.fusion_hidden_size, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + + + def forward(self, prompt_depth: torch.Tensor) -> torch.Tensor: + residual = prompt_depth + residual = self.convolution1(residual) + residual = self.activation1(residual) + residual = self.convolution2(residual) + residual = self.activation2(residual) + residual = self.convolution3(residual) + return residual + +class PromptDepthAnythingFeatureFusionLayer(DepthAnythingFeatureFusionLayer): + def __init__(self, config): + super().__init__(config) + self.residual_layer_depth = PromptDepthAnythingResidualLayer(config) + + def forward(self, hidden_state, residual=None, size=None, prompt_depth=None): + if residual is not None: + if hidden_state.shape != residual.shape: + residual = nn.functional.interpolate( + residual, size=(hidden_state.shape[2], hidden_state.shape[3]), mode="bilinear", align_corners=False + ) + hidden_state = hidden_state + self.residual_layer1(residual) + + hidden_state = self.residual_layer2(hidden_state) + + if prompt_depth is not None: + prompt_depth = nn.functional.interpolate( + prompt_depth, + hidden_state.shape[2:], + mode='bilinear', + align_corners=False + ) + res = self.residual_layer_depth(prompt_depth) + hidden_state = hidden_state + res + + modifier = {"scale_factor": 2} if size is None else {"size": size} + + hidden_state = nn.functional.interpolate( + hidden_state, + **modifier, + mode="bilinear", + align_corners=True, + ) + hidden_state = self.projection(hidden_state) + + return hidden_state + + +class PromptDepthAnythingFeatureFusionStage(nn.Module): + def __init__(self, config): + super().__init__() + self.layers = nn.ModuleList() + for _ in range(len(config.neck_hidden_sizes)): + self.layers.append(PromptDepthAnythingFeatureFusionLayer(config)) + + def forward(self, hidden_states, size=None, prompt_depth=None): + # reversing the hidden_states, we start from the last + hidden_states = hidden_states[::-1] + + fused_hidden_states = [] + fused_hidden_state = None + + for idx, (hidden_state, layer) in enumerate(zip(hidden_states, self.layers)): + size = hidden_states[idx + 1].shape[2:] if idx != (len(hidden_states) - 1) else None + + if fused_hidden_state is None: + # first layer only uses the last hidden_state + fused_hidden_state = layer(hidden_state, size=size, prompt_depth=prompt_depth) + else: + fused_hidden_state = layer(fused_hidden_state, hidden_state, size=size, prompt_depth=prompt_depth) + + fused_hidden_states.append(fused_hidden_state) + + return fused_hidden_states + + +class PromptDepthAnythingNeck(DepthAnythingNeck): + def __init__(self, config): + super().__init__(config) + self.fusion_stage = PromptDepthAnythingFeatureFusionStage(config) + + def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None, prompt_depth=None) -> List[torch.Tensor]: + """ + Args: + hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`): + List of hidden states from the backbone. + """ + if not isinstance(hidden_states, (tuple, list)): + raise TypeError("hidden_states should be a tuple or list of tensors") + + if len(hidden_states) != len(self.config.neck_hidden_sizes): + raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.") + + # postprocess hidden states + hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width) + + features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)] + + # fusion blocks + output = self.fusion_stage(features, prompt_depth=prompt_depth) + + return output + +@add_start_docstrings( + """ + Prompt Depth Anything Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2. + """, + PROMPT_DEPTH_ANYTHING_START_DOCSTRING, +) +class PromptDepthAnythingForDepthEstimation(DepthAnythingForDepthEstimation): + def __init__(self, config): + super().__init__(config) + self.neck = PromptDepthAnythingNeck(config) + self.post_init() + + @add_start_docstrings_to_model_forward(PROMPT_DEPTH_ANYTHING_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=DepthEstimatorOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.FloatTensor, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + prompt_depth: Optional[torch.FloatTensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth depth estimation maps for computing the loss. + + Returns: + + Examples: + ```python + >>> from transformers import AutoImageProcessor, AutoModelForDepthEstimation + >>> import torch + >>> import numpy as np + >>> from PIL import Image + >>> import requests + + >>> url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("depth-anything/promptda_vits_hf") + >>> model = AutoModelForDepthEstimation.from_pretrained("depth-anything/promptda_vits_hf") + + >>> prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true" + >>> prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw) + >>> prompt_depth = torch.tensor((np.asarray(prompt_depth) / 1000.0).astype(np.float32)) + >>> prompt_depth = prompt_depth.unsqueeze(0).unsqueeze(0) + + >>> # prepare image for the model + >>> inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth) + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> # interpolate to original size + >>> post_processed_output = image_processor.post_process_depth_estimation( + ... outputs, + ... target_sizes=[(image.height, image.width)], + ... ) + + >>> # visualize the prediction + >>> predicted_depth = post_processed_output[0]["predicted_depth"] + >>> depth = predicted_depth * 1000 + >>> depth = depth.detach().cpu().numpy() + >>> depth = Image.fromarray(depth.astype("uint16")) # mm + ```""" + loss = None + if labels is not None: + raise NotImplementedError("Training is not implemented yet") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + outputs = self.backbone.forward_with_filtered_kwargs( + pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions + ) + hidden_states = outputs.feature_maps + + _, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + patch_height = height // patch_size + patch_width = width // patch_size + + + # normalize prompt depth + B = len(prompt_depth) + depth_min, depth_max = torch.min(prompt_depth.reshape(B, -1), dim=1).values, torch.max(prompt_depth.reshape(B, -1), dim=1).values + invalid_mask = (depth_max - depth_min) <= 0 + if invalid_mask.any(): + depth_max[invalid_mask] = depth_min[invalid_mask] + 1e-6 + depth_min, depth_max = depth_min.view(B, 1, 1, 1), depth_max.view(B, 1, 1, 1) + prompt_depth = (prompt_depth - depth_min) / (depth_max - depth_min) + # normalize done + + + hidden_states = self.neck(hidden_states, patch_height, patch_width, prompt_depth=prompt_depth) + + predicted_depth = self.head(hidden_states, patch_height, patch_width) + # denormalize predicted depth + predicted_depth = predicted_depth * (depth_max - depth_min) + depth_min + # denormalize done + + if not return_dict: + if output_hidden_states: + output = (predicted_depth,) + outputs[1:] + else: + output = (predicted_depth,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return DepthEstimatorOutput( + loss=loss, + predicted_depth=predicted_depth, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) \ No newline at end of file From 7e6dcaabcc3ef59b63989922fd3765bb30d56426 Mon Sep 17 00:00:00 2001 From: haotongl Date: Mon, 23 Dec 2024 20:56:51 +0800 Subject: [PATCH 02/58] add prompt depth anything docs and imports --- docs/source/en/_toctree.yml | 2 + docs/source/en/index.md | 1 + .../en/model_doc/prompt_depth_anything.md | 99 ++++++ src/transformers/__init__.py | 14 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 1 + src/transformers/utils/dummy_pt_objects.py | 14 + .../models/prompt_depth_anything/__init__.py | 0 .../test_modeling_prompt_depth_anything.py | 289 ++++++++++++++++++ 10 files changed, 423 insertions(+) create mode 100644 docs/source/en/model_doc/prompt_depth_anything.md create mode 100644 tests/models/prompt_depth_anything/__init__.py create mode 100644 tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 18de03e1df80..a72ee3b17dde 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -550,6 +550,8 @@ title: PhoBERT - local: model_doc/plbart title: PLBart + - local: model_doc/prompt_depth_anything + title: Prompt Depth Anything - local: model_doc/prophetnet title: ProphetNet - local: model_doc/qdqbert diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 967049d89cbe..a96d7eee4f2f 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -275,6 +275,7 @@ Flax), PyTorch, and/or TensorFlow. | [PLBart](model_doc/plbart) | ✅ | ❌ | ❌ | | [PoolFormer](model_doc/poolformer) | ✅ | ❌ | ❌ | | [Pop2Piano](model_doc/pop2piano) | ✅ | ❌ | ❌ | +| [PromptDepthAnything](model_doc/prompt_depth_anything) | ✅ | ❌ | ❌ | | [ProphetNet](model_doc/prophetnet) | ✅ | ❌ | ❌ | | [PVT](model_doc/pvt) | ✅ | ❌ | ❌ | | [PVTv2](model_doc/pvt_v2) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/prompt_depth_anything.md b/docs/source/en/model_doc/prompt_depth_anything.md new file mode 100644 index 000000000000..bc7b75133c8e --- /dev/null +++ b/docs/source/en/model_doc/prompt_depth_anything.md @@ -0,0 +1,99 @@ + + +# Prompt Depth Anything + +## Overview + +The Prompt Depth Anything model was introduced in [Prompting Depth Anything for 4K Resolution Accurate Metric Depth Estimation](https://promptda.github.io/) by Haotong Lin, Sida Peng, Jingxiao Chen, Songyou Peng, Jiaming Sun, Minghuan Liu, Hujun Bao, Jiashi Feng, Xiaowei Zhou, Bingyi Kang. + + + +The abstract from the paper is as follows: + +*Prompts play a critical role in unleashing the power of language and vision foundation models for specific tasks. For the first time, we introduce prompting into depth foundation models, creating a new paradigm for metric depth estimation termed Prompt Depth Anything. Specifically, we use a low-cost LiDAR as the prompt to guide the Depth Anything model for accurate metric depth output, achieving up to 4K resolution. Our approach centers on a concise prompt fusion design that integrates the LiDAR at multiple scales within the depth decoder. To address training challenges posed by limited datasets containing both LiDAR depth and precise GT depth, we propose a scalable data pipeline that includes synthetic data LiDAR simulation and real data pseudo GT depth generation. Our approach sets new state-of-the-arts on the ARKitScenes and ScanNet++ datasets and benefits downstream applications, including 3D reconstruction and generalized robotic grasping.* + + + + Prompt Depth Anything overview. Taken from the original paper. + + + +## Usage example + +The transformers library allows you to use the model with just a few lines of code: + +```python +>>> from transformers import AutoImageProcessor, AutoModelForDepthEstimation +>>> import torch +>>> import numpy as np +>>> from PIL import Image +>>> import requests + +>>> url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true" +>>> image = Image.open(requests.get(url, stream=True).raw) + +>>> image_processor = AutoImageProcessor.from_pretrained("depth-anything/promptda_vits_hf") +>>> model = AutoModelForDepthEstimation.from_pretrained("depth-anything/promptda_vits_hf") + +>>> prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true" +>>> prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw) +>>> prompt_depth = torch.tensor((np.asarray(prompt_depth) / 1000.0).astype(np.float32)) +>>> prompt_depth = prompt_depth.unsqueeze(0).unsqueeze(0) + +>>> # prepare image for the model +>>> inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth) + +>>> with torch.no_grad(): +... outputs = model(**inputs) + +>>> # interpolate to original size +>>> post_processed_output = image_processor.post_process_depth_estimation( +... outputs, +... target_sizes=[(image.height, image.width)], +... ) + +>>> # visualize the prediction +>>> predicted_depth = post_processed_output[0]["predicted_depth"] +>>> depth = predicted_depth * 1000 +>>> depth = depth.detach().cpu().numpy() +>>> depth = Image.fromarray(depth.astype("uint16")) # mm +``` + +## Resources + +A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Prompt Depth Anything. + +- [Prompt Depth Anything Demo](https://huggingface.co/spaces/depth-anything/PromptDA) +- [Prompt Depth Anything Interactive Results](https://promptda.github.io/interactive.html) + +If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource. + +## PromptDepthAnythingConfig + +[[autodoc]] PromptDepthAnythingConfig + +## PromptDepthAnythingForDepthEstimation + +[[autodoc]] PromptDepthAnythingForDepthEstimation + - forward \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 5510ac6c8ad5..ba434ff3247c 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -693,6 +693,9 @@ "ProphetNetConfig", "ProphetNetTokenizer", ], + "models.prompt_depth_anything": [ + "PromptDepthAnythingConfig", + ], "models.pvt": ["PvtConfig"], "models.pvt_v2": ["PvtV2Config"], "models.qwen2": [ @@ -3191,6 +3194,12 @@ "ProphetNetPreTrainedModel", ] ) + _import_structure["models.prompt_depth_anything"].extend( + [ + "PromptDepthAnythingForDepthEstimation", + "PromptDepthAnythingPreTrainedModel", + ] + ) _import_structure["models.pvt"].extend( [ "PvtForImageClassification", @@ -5682,6 +5691,7 @@ from .models.pop2piano import ( Pop2PianoConfig, ) + from .models.prompt_depth_anything import PromptDepthAnythingConfig from .models.prophetnet import ( ProphetNetConfig, ProphetNetTokenizer, @@ -7819,6 +7829,10 @@ Pop2PianoForConditionalGeneration, Pop2PianoPreTrainedModel, ) + from .models.prompt_depth_anything import ( + PromptDepthAnythingForDepthEstimation, + PromptDepthAnythingPreTrainedModel, + ) from .models.prophetnet import ( ProphetNetDecoder, ProphetNetEncoder, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 7fcaddde704c..fc59e75971ef 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -207,6 +207,7 @@ plbart, poolformer, pop2piano, + prompt_depth_anything, prophetnet, pvt, pvt_v2, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 69ce8efa10c7..e46fd0d18227 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -227,6 +227,7 @@ ("plbart", "PLBartConfig"), ("poolformer", "PoolFormerConfig"), ("pop2piano", "Pop2PianoConfig"), + ("prompt_depth_anything", "PromptDepthAnythingConfig"), ("prophetnet", "ProphetNetConfig"), ("pvt", "PvtConfig"), ("pvt_v2", "PvtV2Config"), @@ -554,6 +555,7 @@ ("plbart", "PLBart"), ("poolformer", "PoolFormer"), ("pop2piano", "Pop2Piano"), + ("prompt_depth_anything", "PromptDepthAnything"), ("prophetnet", "ProphetNet"), ("pvt", "PVT"), ("pvt_v2", "PVTv2"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index e8a2dece4324..52cb7223923b 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -893,6 +893,7 @@ ("depth_anything", "DepthAnythingForDepthEstimation"), ("dpt", "DPTForDepthEstimation"), ("glpn", "GLPNForDepthEstimation"), + ("prompt_depth_anything", "PromptDepthAnythingForDepthEstimation"), ("zoedepth", "ZoeDepthForDepthEstimation"), ] ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index e3463461ea07..11f5aa1c86a2 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -7605,6 +7605,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class PromptDepthAnythingForDepthEstimation(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class PromptDepthAnythingPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class ProphetNetDecoder(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/prompt_depth_anything/__init__.py b/tests/models/prompt_depth_anything/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py new file mode 100644 index 000000000000..f43a078edd66 --- /dev/null +++ b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py @@ -0,0 +1,289 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Prompt Depth Anything model.""" + +import unittest + +from transformers import PromptDepthAnythingConfig, Dinov2Config +from transformers.file_utils import is_torch_available, is_vision_available +from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4 +from transformers.testing_utils import require_torch, require_vision, slow, torch_device + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + +import requests +import numpy as np + +if is_torch_available(): + import torch + + from transformers import PromptDepthAnythingForDepthEstimation + + +if is_vision_available(): + from PIL import Image + + from transformers import DPTImageProcessor + + +class PromptDepthAnythingModelTester: + def __init__( + self, + parent, + batch_size=2, + num_channels=3, + image_size=32, + patch_size=16, + use_labels=True, + num_labels=3, + is_training=True, + hidden_size=4, + num_hidden_layers=2, + num_attention_heads=2, + intermediate_size=8, + out_features=["stage1", "stage2"], + apply_layernorm=False, + reshape_hidden_states=False, + neck_hidden_sizes=[2, 2], + fusion_hidden_size=6, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.out_features = out_features + self.apply_layernorm = apply_layernorm + self.reshape_hidden_states = reshape_hidden_states + self.use_labels = use_labels + self.num_labels = num_labels + self.is_training = is_training + self.neck_hidden_sizes = neck_hidden_sizes + self.fusion_hidden_size = fusion_hidden_size + self.seq_length = (self.image_size // self.patch_size) ** 2 + 1 + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels) + + config = self.get_config() + + return config, pixel_values, labels + + def get_config(self): + return PromptDepthAnythingConfig( + backbone_config=self.get_backbone_config(), + reassemble_hidden_size=self.hidden_size, + patch_size=self.patch_size, + neck_hidden_sizes=self.neck_hidden_sizes, + fusion_hidden_size=self.fusion_hidden_size, + ) + + def get_backbone_config(self): + return Dinov2Config( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + is_training=self.is_training, + out_features=self.out_features, + reshape_hidden_states=self.reshape_hidden_states, + ) + + def create_and_check_for_depth_estimation(self, config, pixel_values, labels): + config.num_labels = self.num_labels + model = PromptDepthAnythingForDepthEstimation(config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + self.parent.assertEqual(result.predicted_depth.shape, (self.batch_size, self.image_size, self.image_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values, labels = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class PromptDepthAnythingModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as Prompt Depth Anything does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = (PromptDepthAnythingForDepthEstimation,) if is_torch_available() else () + pipeline_model_mapping = {"depth-estimation": PromptDepthAnythingForDepthEstimation} if is_torch_available() else {} + + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + + def setUp(self): + self.model_tester = PromptDepthAnythingModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=PromptDepthAnythingConfig, + has_text_modality=False, + hidden_size=37, + common_properties=["patch_size"], + ) + + def test_config(self): + self.config_tester.run_common_tests() + + @unittest.skip(reason="Prompt Depth Anything with AutoBackbone does not have a base model and hence no input_embeddings") + def test_inputs_embeds(self): + pass + + def test_for_depth_estimation(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_depth_estimation(*config_and_inputs) + + @unittest.skip(reason="Prompt Depth Anything does not support training yet") + def test_training(self): + pass + + @unittest.skip(reason="Prompt Depth Anything does not support training yet") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="Prompt Depth Anything with AutoBackbone does not have a base model and hence no input_embeddings") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="Prompt Depth Anything with AutoBackbone does not have a base model") + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip(reason="Prompt Depth Anything with AutoBackbone does not have a base model") + def test_save_load_fast_init_to_base(self): + pass + + @unittest.skip( + reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip( + reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @slow + def test_model_from_pretrained(self): + model_name = "depth-anything/promptda_vits_hf" + model = PromptDepthAnythingForDepthEstimation.from_pretrained(model_name) + self.assertIsNotNone(model) + + def test_backbone_selection(self): + def _validate_backbone_init(): + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + self.assertEqual(len(model.backbone.out_indices), 2) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + config.backbone = "resnet18" + config.use_pretrained_backbone = True + config.use_timm_backbone = True + config.backbone_config = None + config.backbone_kwargs = {"out_indices": (-2, -1)} + _validate_backbone_init() + + config.backbone = "facebook/dinov2-small" + config.use_pretrained_backbone = True + config.use_timm_backbone = False + config.backbone_config = None + config.backbone_kwargs = {"out_indices": [-2, -1]} + _validate_backbone_init() + + +def prepare_img(): + url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true" + image = Image.open(requests.get(url, stream=True).raw) + return image + +@require_torch +@require_vision +@slow +class PromptDepthAnythingModelIntegrationTest(unittest.TestCase): + def test_inference(self): + image_processor = DPTImageProcessor.from_pretrained("depth-anything/promptda_vits_hf") + model = PromptDepthAnythingForDepthEstimation.from_pretrained("depth-anything/promptda_vits_hf").to(torch_device) + + image = prepare_img() + prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true" + prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw) + prompt_depth = torch.tensor((np.asarray(prompt_depth) / 1000.0).astype(np.float32)) + prompt_depth = prompt_depth.unsqueeze(0).unsqueeze(0) + inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + predicted_depth = outputs.predicted_depth + + expected_shape = torch.Size([1, 1, 756, 1008]) + self.assertEqual(predicted_depth.shape, expected_shape) + + expected_slice = torch.tensor( + [[[3.0100, 3.0016, 3.0219], [3.0046, 3.0137, 3.0275], [3.0083, 3.0191, 3.0292]]] + ).to(torch_device) + + self.assertTrue(torch.allclose(predicted_depth[0, 0, :3, :3], expected_slice, atol=1e-3)) + + def test_export(self): + for strict in [True, False]: + with self.subTest(strict=strict): + if not is_torch_greater_or_equal_than_2_4: + self.skipTest(reason="This test requires torch >= 2.4 to run.") + model = ( + PromptDepthAnythingForDepthEstimation.from_pretrained("depth-anything/promptda_vits_hf") + .to(torch_device) + .eval() + ) + image_processor = DPTImageProcessor.from_pretrained("depth-anything/promptda_vits_hf") + image = prepare_img() + inputs = image_processor(images=image, return_tensors="pt").to(torch_device) + + exported_program = torch.export.export( + model, + args=(inputs["pixel_values"],), + strict=strict, + ) + with torch.no_grad(): + eager_outputs = model(**inputs) + exported_outputs = exported_program.module().forward(inputs["pixel_values"]) + self.assertEqual(eager_outputs.predicted_depth.shape, exported_outputs.predicted_depth.shape) + self.assertTrue( + torch.allclose(eager_outputs.predicted_depth, exported_outputs.predicted_depth, atol=1e-4) + ) From dfa7d67b0edb396cf49bea7b949567590ba3c424 Mon Sep 17 00:00:00 2001 From: haotongl Date: Mon, 23 Dec 2024 21:55:28 +0800 Subject: [PATCH 03/58] update code style according transformers doc --- .../convert_prompt_depth_anything_to_hf.py | 64 ++++++++++--------- .../modular_prompt_depth_anything.py | 34 ++++++---- .../test_modeling_prompt_depth_anything.py | 28 +++++--- 3 files changed, 76 insertions(+), 50 deletions(-) diff --git a/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py b/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py index 9d9f87e3a492..08f3d525695c 100644 --- a/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py +++ b/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py @@ -12,19 +12,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Convert Depth Anything checkpoints from the original repository. URL: -https://github.com/LiheYoung/Depth-Anything""" +"""Convert Prompt Depth Anything checkpoints from the original repository. URL: +https://github.com/DepthAnything/PromptDA""" import argparse from pathlib import Path +import numpy as np import requests import torch from huggingface_hub import hf_hub_download from PIL import Image -import numpy as np -from transformers import PromptDepthAnythingConfig, PromptDepthAnythingForDepthEstimation, Dinov2Config, DPTImageProcessor +from transformers import ( + Dinov2Config, + DPTImageProcessor, + PromptDepthAnythingConfig, + PromptDepthAnythingForDepthEstimation, +) from transformers.utils import logging @@ -33,22 +38,22 @@ def get_dpt_config(model_name): - if "small" in model_name or 'vits' in model_name: - out_indices = [3, 6, 9, 12] + if "small" in model_name or "vits" in model_name: + out_indices = [3, 6, 9, 12] backbone_config = Dinov2Config.from_pretrained( "facebook/dinov2-small", out_indices=out_indices, apply_layernorm=True, reshape_hidden_states=False ) fusion_hidden_size = 64 neck_hidden_sizes = [48, 96, 192, 384] - elif "base" in model_name or 'vitb' in model_name: + elif "base" in model_name or "vitb" in model_name: out_indices = [3, 6, 9, 12] backbone_config = Dinov2Config.from_pretrained( "facebook/dinov2-base", out_indices=out_indices, apply_layernorm=True, reshape_hidden_states=False ) fusion_hidden_size = 128 neck_hidden_sizes = [96, 192, 384, 768] - elif "large" in model_name or 'vitl' in model_name: - out_indices = [5, 12, 18, 24] + elif "large" in model_name or "vitl" in model_name: + out_indices = [5, 12, 18, 24] backbone_config = Dinov2Config.from_pretrained( "facebook/dinov2-large", out_indices=out_indices, apply_layernorm=True, reshape_hidden_states=False ) @@ -184,20 +189,10 @@ def prepare_img(): name_to_checkpoint = { - "depth-anything-small": "pytorch_model.bin", - "depth-anything-base": "pytorch_model.bin", - "depth-anything-large": "pytorch_model.bin", - "depth-anything-v2-small": "depth_anything_v2_vits.pth", - "depth-anything-v2-base": "depth_anything_v2_vitb.pth", - "depth-anything-v2-large": "depth_anything_v2_vitl.pth", - "depth-anything-v2-metric-indoor-small": "depth_anything_v2_metric_hypersim_vits.pth", - "depth-anything-v2-metric-indoor-base": "depth_anything_v2_metric_hypersim_vitb.pth", - "depth-anything-v2-metric-indoor-large": "depth_anything_v2_metric_hypersim_vitl.pth", - "depth-anything-v2-metric-outdoor-small": "depth_anything_v2_metric_vkitti_vits.pth", - "depth-anything-v2-metric-outdoor-base": "depth_anything_v2_metric_vkitti_vitb.pth", - "depth-anything-v2-metric-outdoor-large": "depth_anything_v2_metric_vkitti_vitl.pth", # v2-giant pending - "promptda_vits": "model.ckpt" + "promptda_vits": "model.ckpt", + "promptda_vits_transparent": "model.ckpt", + "promptda_vitl": "model.ckpt", } @@ -211,7 +206,9 @@ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, ve config = get_dpt_config(model_name) model_name_to_repo = { - "promptda_vits": "depth-anything/promptda_vits" + "promptda_vits": "depth-anything/promptda_vits", + "promptda_vits_transparent": "depth-anything/promptda_vits_transparent", + "promptda_vitl": "depth-anything/promptda_vitl", } # load original state_dict @@ -222,8 +219,8 @@ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, ve filename=f"{filename}", ) - state_dict = torch.load(filepath, map_location="cpu")['state_dict'] - state_dict = {key[9:]:state_dict[key] for key in state_dict} + state_dict = torch.load(filepath, map_location="cpu")["state_dict"] + state_dict = {key[9:]: state_dict[key] for key in state_dict} # rename keys rename_keys = create_rename_keys(config) for src, dest in rename_keys: @@ -251,8 +248,9 @@ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, ve pixel_values = processor(image, return_tensors="pt").pixel_values - - prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true" + prompt_depth_url = ( + "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true" + ) prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw) prompt_depth = torch.tensor((np.asarray(prompt_depth) / 1000.0).astype(np.float32)) prompt_depth = prompt_depth.unsqueeze(0).unsqueeze(0) @@ -268,14 +266,22 @@ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, ve # assert logits if verify_logits: expected_shape = torch.Size([1, 1, 756, 1008]) - if model_name == 'promptda_vits': + if model_name == "promptda_vits": expected_slice = torch.tensor( [[3.0100, 3.0016, 3.0219], [3.0046, 3.0137, 3.0275], [3.0083, 3.0191, 3.0292]] ) + elif model_name == "promptda_vits_transparent": + expected_slice = torch.tensor( + [[3.0058, 3.0397, 3.0460], [3.0314, 3.0393, 3.0504], [3.0326, 3.0465, 3.0545]] + ) + elif model_name == "promptda_vitl": + expected_slice = torch.tensor( + [[3.1336, 3.1358, 3.1363], [3.1368, 3.1267, 3.1414], [3.1397, 3.1385, 3.1448]] + ) else: raise ValueError("Not supported") assert predicted_depth.shape == torch.Size(expected_shape) - assert torch.allclose(predicted_depth[0, 0, :3, :3], expected_slice, atol=1e-3) # 1mm tolerance + assert torch.allclose(predicted_depth[0, 0, :3, :3], expected_slice, atol=5e-3) # 5mm tolerance print("Looks ok!") if pytorch_dump_folder_path is not None: diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index b2f75ea14492..c570dd4303bc 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -1,8 +1,15 @@ from typing import List, Optional, Tuple, Union + import torch import torch.nn as nn + from transformers.models.depth_anything.configuration_depth_anything import DepthAnythingConfig -from transformers.models.depth_anything.modeling_depth_anything import DepthAnythingFeatureFusionLayer, DepthAnythingForDepthEstimation, DepthAnythingNeck +from transformers.models.depth_anything.modeling_depth_anything import ( + DepthAnythingFeatureFusionLayer, + DepthAnythingForDepthEstimation, + DepthAnythingNeck, +) + from ...file_utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -10,6 +17,7 @@ ) from ...modeling_outputs import DepthEstimatorOutput + _CONFIG_FOR_DOC = "PromptDepthAnythingConfig" PROMPT_DEPTH_ANYTHING_START_DOCSTRING = r""" @@ -44,7 +52,6 @@ class PromptDepthAnythingConfig(DepthAnythingConfig): class PromptDepthAnythingResidualLayer(nn.Module): - def __init__(self, config): super().__init__() self.convolution1 = nn.Conv2d( @@ -76,7 +83,6 @@ def __init__(self, config): bias=True, ) - def forward(self, prompt_depth: torch.Tensor) -> torch.Tensor: residual = prompt_depth residual = self.convolution1(residual) @@ -86,6 +92,7 @@ def forward(self, prompt_depth: torch.Tensor) -> torch.Tensor: residual = self.convolution3(residual) return residual + class PromptDepthAnythingFeatureFusionLayer(DepthAnythingFeatureFusionLayer): def __init__(self, config): super().__init__(config) @@ -103,10 +110,7 @@ def forward(self, hidden_state, residual=None, size=None, prompt_depth=None): if prompt_depth is not None: prompt_depth = nn.functional.interpolate( - prompt_depth, - hidden_state.shape[2:], - mode='bilinear', - align_corners=False + prompt_depth, hidden_state.shape[2:], mode="bilinear", align_corners=False ) res = self.residual_layer_depth(prompt_depth) hidden_state = hidden_state + res @@ -157,7 +161,9 @@ def __init__(self, config): super().__init__(config) self.fusion_stage = PromptDepthAnythingFeatureFusionStage(config) - def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None, prompt_depth=None) -> List[torch.Tensor]: + def forward( + self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None, prompt_depth=None + ) -> List[torch.Tensor]: """ Args: hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`): @@ -179,6 +185,7 @@ def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_wi return output + @add_start_docstrings( """ Prompt Depth Anything Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2. @@ -241,7 +248,7 @@ def forward( >>> # visualize the prediction >>> predicted_depth = post_processed_output[0]["predicted_depth"] - >>> depth = predicted_depth * 1000 + >>> depth = predicted_depth * 1000. >>> depth = depth.detach().cpu().numpy() >>> depth = Image.fromarray(depth.astype("uint16")) # mm ```""" @@ -265,10 +272,12 @@ def forward( patch_height = height // patch_size patch_width = width // patch_size - # normalize prompt depth B = len(prompt_depth) - depth_min, depth_max = torch.min(prompt_depth.reshape(B, -1), dim=1).values, torch.max(prompt_depth.reshape(B, -1), dim=1).values + depth_min, depth_max = ( + torch.min(prompt_depth.reshape(B, -1), dim=1).values, + torch.max(prompt_depth.reshape(B, -1), dim=1).values, + ) invalid_mask = (depth_max - depth_min) <= 0 if invalid_mask.any(): depth_max[invalid_mask] = depth_min[invalid_mask] + 1e-6 @@ -276,7 +285,6 @@ def forward( prompt_depth = (prompt_depth - depth_min) / (depth_max - depth_min) # normalize done - hidden_states = self.neck(hidden_states, patch_height, patch_width, prompt_depth=prompt_depth) predicted_depth = self.head(hidden_states, patch_height, patch_width) @@ -296,4 +304,4 @@ def forward( predicted_depth=predicted_depth, hidden_states=outputs.hidden_states if output_hidden_states else None, attentions=outputs.attentions, - ) \ No newline at end of file + ) diff --git a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py index f43a078edd66..2ce7db6b208a 100644 --- a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py +++ b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py @@ -16,7 +16,10 @@ import unittest -from transformers import PromptDepthAnythingConfig, Dinov2Config +import numpy as np +import requests + +from transformers import Dinov2Config, PromptDepthAnythingConfig from transformers.file_utils import is_torch_available, is_vision_available from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4 from transformers.testing_utils import require_torch, require_vision, slow, torch_device @@ -25,8 +28,6 @@ from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from ...test_pipeline_mixin import PipelineTesterMixin -import requests -import numpy as np if is_torch_available(): import torch @@ -137,7 +138,9 @@ class PromptDepthAnythingModelTest(ModelTesterMixin, PipelineTesterMixin, unitte """ all_model_classes = (PromptDepthAnythingForDepthEstimation,) if is_torch_available() else () - pipeline_model_mapping = {"depth-estimation": PromptDepthAnythingForDepthEstimation} if is_torch_available() else {} + pipeline_model_mapping = ( + {"depth-estimation": PromptDepthAnythingForDepthEstimation} if is_torch_available() else {} + ) test_pruning = False test_resize_embeddings = False @@ -156,7 +159,9 @@ def setUp(self): def test_config(self): self.config_tester.run_common_tests() - @unittest.skip(reason="Prompt Depth Anything with AutoBackbone does not have a base model and hence no input_embeddings") + @unittest.skip( + reason="Prompt Depth Anything with AutoBackbone does not have a base model and hence no input_embeddings" + ) def test_inputs_embeds(self): pass @@ -172,7 +177,9 @@ def test_training(self): def test_training_gradient_checkpointing(self): pass - @unittest.skip(reason="Prompt Depth Anything with AutoBackbone does not have a base model and hence no input_embeddings") + @unittest.skip( + reason="Prompt Depth Anything with AutoBackbone does not have a base model and hence no input_embeddings" + ) def test_model_get_set_embeddings(self): pass @@ -233,16 +240,21 @@ def prepare_img(): image = Image.open(requests.get(url, stream=True).raw) return image + @require_torch @require_vision @slow class PromptDepthAnythingModelIntegrationTest(unittest.TestCase): def test_inference(self): image_processor = DPTImageProcessor.from_pretrained("depth-anything/promptda_vits_hf") - model = PromptDepthAnythingForDepthEstimation.from_pretrained("depth-anything/promptda_vits_hf").to(torch_device) + model = PromptDepthAnythingForDepthEstimation.from_pretrained("depth-anything/promptda_vits_hf").to( + torch_device + ) image = prepare_img() - prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true" + prompt_depth_url = ( + "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true" + ) prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw) prompt_depth = torch.tensor((np.asarray(prompt_depth) / 1000.0).astype(np.float32)) prompt_depth = prompt_depth.unsqueeze(0).unsqueeze(0) From 8509440bcf1a075659e7ecc48474687aa97be43f Mon Sep 17 00:00:00 2001 From: haotongl Date: Mon, 23 Dec 2024 22:01:33 +0800 Subject: [PATCH 04/58] update code style: import order issue is fixed by custom_init_isort --- src/transformers/__init__.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index ba434ff3247c..afbbbf90820e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -689,13 +689,13 @@ "models.plbart": ["PLBartConfig"], "models.poolformer": ["PoolFormerConfig"], "models.pop2piano": ["Pop2PianoConfig"], + "models.prompt_depth_anything": [ + "PromptDepthAnythingConfig" + ], "models.prophetnet": [ "ProphetNetConfig", "ProphetNetTokenizer", ], - "models.prompt_depth_anything": [ - "PromptDepthAnythingConfig", - ], "models.pvt": ["PvtConfig"], "models.pvt_v2": ["PvtV2Config"], "models.qwen2": [ @@ -3184,6 +3184,12 @@ "Pop2PianoPreTrainedModel", ] ) + _import_structure["models.prompt_depth_anything"].extend( + [ + "PromptDepthAnythingForDepthEstimation", + "PromptDepthAnythingPreTrainedModel", + ] + ) _import_structure["models.prophetnet"].extend( [ "ProphetNetDecoder", @@ -3194,12 +3200,6 @@ "ProphetNetPreTrainedModel", ] ) - _import_structure["models.prompt_depth_anything"].extend( - [ - "PromptDepthAnythingForDepthEstimation", - "PromptDepthAnythingPreTrainedModel", - ] - ) _import_structure["models.pvt"].extend( [ "PvtForImageClassification", From 2fa72ef78a964960961cd26b753ae34d26a1b32f Mon Sep 17 00:00:00 2001 From: haotongl Date: Mon, 23 Dec 2024 23:03:45 +0800 Subject: [PATCH 05/58] fix depth shape from B,1,H,W to B,H,W which is as the same as Depth Anything --- .../en/model_doc/prompt_depth_anything.md | 4 +-- src/transformers/__init__.py | 4 +-- .../convert_prompt_depth_anything_to_hf.py | 4 +-- .../modeling_prompt_depth_anything.py | 7 ++-- .../modular_prompt_depth_anything.py | 5 +-- .../test_modeling_prompt_depth_anything.py | 34 ++++++++++--------- 6 files changed, 30 insertions(+), 28 deletions(-) diff --git a/docs/source/en/model_doc/prompt_depth_anything.md b/docs/source/en/model_doc/prompt_depth_anything.md index bc7b75133c8e..1a60fe6242ad 100644 --- a/docs/source/en/model_doc/prompt_depth_anything.md +++ b/docs/source/en/model_doc/prompt_depth_anything.md @@ -62,10 +62,10 @@ The transformers library allows you to use the model with just a few lines of co >>> prompt_depth = prompt_depth.unsqueeze(0).unsqueeze(0) >>> # prepare image for the model ->>> inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth) +>>> inputs = image_processor(images=image, return_tensors="pt") >>> with torch.no_grad(): -... outputs = model(**inputs) +... outputs = model(pixel_values=inputs.pixel_values, prompt_depth=prompt_depth) >>> # interpolate to original size >>> post_processed_output = image_processor.post_process_depth_estimation( diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index afbbbf90820e..1bd6db463b3f 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -689,9 +689,7 @@ "models.plbart": ["PLBartConfig"], "models.poolformer": ["PoolFormerConfig"], "models.pop2piano": ["Pop2PianoConfig"], - "models.prompt_depth_anything": [ - "PromptDepthAnythingConfig" - ], + "models.prompt_depth_anything": ["PromptDepthAnythingConfig"], "models.prophetnet": [ "ProphetNetConfig", "ProphetNetTokenizer", diff --git a/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py b/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py index 08f3d525695c..f300e7ed3a08 100644 --- a/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py +++ b/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py @@ -265,7 +265,7 @@ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, ve # assert logits if verify_logits: - expected_shape = torch.Size([1, 1, 756, 1008]) + expected_shape = torch.Size([1, 756, 1008]) if model_name == "promptda_vits": expected_slice = torch.tensor( [[3.0100, 3.0016, 3.0219], [3.0046, 3.0137, 3.0275], [3.0083, 3.0191, 3.0292]] @@ -281,7 +281,7 @@ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, ve else: raise ValueError("Not supported") assert predicted_depth.shape == torch.Size(expected_shape) - assert torch.allclose(predicted_depth[0, 0, :3, :3], expected_slice, atol=5e-3) # 5mm tolerance + assert torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=5e-3) # 5mm tolerance print("Looks ok!") if pytorch_dump_folder_path is not None: diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index d19b42a28372..0f0051b6d009 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -451,10 +451,10 @@ def forward( >>> prompt_depth = prompt_depth.unsqueeze(0).unsqueeze(0) >>> # prepare image for the model - >>> inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth) + >>> inputs = image_processor(images=image, return_tensors="pt") >>> with torch.no_grad(): - ... outputs = model(**inputs) + ... outputs = model(pixel_values=inputs.pixel_values, prompt_depth=prompt_depth) >>> # interpolate to original size >>> post_processed_output = image_processor.post_process_depth_estimation( @@ -464,7 +464,7 @@ def forward( >>> # visualize the prediction >>> predicted_depth = post_processed_output[0]["predicted_depth"] - >>> depth = predicted_depth * 1000 + >>> depth = predicted_depth * 1000. >>> depth = depth.detach().cpu().numpy() >>> depth = Image.fromarray(depth.astype("uint16")) # mm ```""" @@ -505,6 +505,7 @@ def forward( predicted_depth = self.head(hidden_states, patch_height, patch_width) # denormalize predicted depth + depth_min, depth_max = depth_min.squeeze(1), depth_max.squeeze(1) predicted_depth = predicted_depth * (depth_max - depth_min) + depth_min # denormalize done diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index c570dd4303bc..825a37ba4acd 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -235,10 +235,10 @@ def forward( >>> prompt_depth = prompt_depth.unsqueeze(0).unsqueeze(0) >>> # prepare image for the model - >>> inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth) + >>> inputs = image_processor(images=image, return_tensors="pt") >>> with torch.no_grad(): - ... outputs = model(**inputs) + ... outputs = model(pixel_values=inputs.pixel_values, prompt_depth=prompt_depth) >>> # interpolate to original size >>> post_processed_output = image_processor.post_process_depth_estimation( @@ -289,6 +289,7 @@ def forward( predicted_depth = self.head(hidden_states, patch_height, patch_width) # denormalize predicted depth + depth_min, depth_max = depth_min.squeeze(1), depth_max.squeeze(1) predicted_depth = predicted_depth * (depth_max - depth_min) + depth_min # denormalize done diff --git a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py index 2ce7db6b208a..b1fb81c166d5 100644 --- a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py +++ b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py @@ -88,9 +88,11 @@ def prepare_config_and_inputs(self): if self.use_labels: labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels) + prompt_depth = floats_tensor([self.batch_size, 1, self.image_size // 4, self.image_size // 4]) + config = self.get_config() - return config, pixel_values, labels + return config, pixel_values, labels, prompt_depth def get_config(self): return PromptDepthAnythingConfig( @@ -115,18 +117,18 @@ def get_backbone_config(self): reshape_hidden_states=self.reshape_hidden_states, ) - def create_and_check_for_depth_estimation(self, config, pixel_values, labels): + def create_and_check_for_depth_estimation(self, config, pixel_values, labels, prompt_depth): config.num_labels = self.num_labels model = PromptDepthAnythingForDepthEstimation(config) model.to(torch_device) model.eval() - result = model(pixel_values) + result = model(pixel_values, prompt_depth=prompt_depth) self.parent.assertEqual(result.predicted_depth.shape, (self.batch_size, self.image_size, self.image_size)) def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() - config, pixel_values, labels = config_and_inputs - inputs_dict = {"pixel_values": pixel_values} + config, pixel_values, labels, prompt_depth = config_and_inputs + inputs_dict = {"pixel_values": pixel_values, "prompt_depth": prompt_depth} return config, inputs_dict @@ -220,12 +222,12 @@ def _validate_backbone_init(): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.backbone = "resnet18" - config.use_pretrained_backbone = True - config.use_timm_backbone = True - config.backbone_config = None - config.backbone_kwargs = {"out_indices": (-2, -1)} - _validate_backbone_init() + # config.backbone = "resnet18" + # config.use_pretrained_backbone = True + # config.use_timm_backbone = True + # config.backbone_config = None + # config.backbone_kwargs = {"out_indices": (-2, -1)} + # _validate_backbone_init() config.backbone = "facebook/dinov2-small" config.use_pretrained_backbone = True @@ -258,20 +260,20 @@ def test_inference(self): prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw) prompt_depth = torch.tensor((np.asarray(prompt_depth) / 1000.0).astype(np.float32)) prompt_depth = prompt_depth.unsqueeze(0).unsqueeze(0) - inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth).to(torch_device) + inputs = image_processor(images=image, return_tensors="pt").to(torch_device) with torch.no_grad(): - outputs = model(**inputs) + outputs = model(pixel_values=inputs.pixel_values, prompt_depth=prompt_depth) predicted_depth = outputs.predicted_depth - expected_shape = torch.Size([1, 1, 756, 1008]) + expected_shape = torch.Size([1, 756, 1008]) self.assertEqual(predicted_depth.shape, expected_shape) expected_slice = torch.tensor( - [[[3.0100, 3.0016, 3.0219], [3.0046, 3.0137, 3.0275], [3.0083, 3.0191, 3.0292]]] + [[3.0100, 3.0016, 3.0219], [3.0046, 3.0137, 3.0275], [3.0083, 3.0191, 3.0292]] ).to(torch_device) - self.assertTrue(torch.allclose(predicted_depth[0, 0, :3, :3], expected_slice, atol=1e-3)) + self.assertTrue(torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=1e-3)) def test_export(self): for strict in [True, False]: From d13a55f91d989a03523c7bf4c464f378ab94f01f Mon Sep 17 00:00:00 2001 From: haotongl Date: Tue, 24 Dec 2024 20:37:05 +0800 Subject: [PATCH 06/58] move prompt depth anything to vision models in _toctree.yml --- docs/source/en/_toctree.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index a72ee3b17dde..211193979db6 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -550,8 +550,6 @@ title: PhoBERT - local: model_doc/plbart title: PLBart - - local: model_doc/prompt_depth_anything - title: Prompt Depth Anything - local: model_doc/prophetnet title: ProphetNet - local: model_doc/qdqbert @@ -691,6 +689,8 @@ title: NAT - local: model_doc/poolformer title: PoolFormer + - local: model_doc/prompt_depth_anything + title: Prompt Depth Anything - local: model_doc/pvt title: Pyramid Vision Transformer (PVT) - local: model_doc/pvt_v2 From 6cd1bbf2078a29c60b2eedb47429e4d157d40b74 Mon Sep 17 00:00:00 2001 From: haotongl Date: Tue, 24 Dec 2024 20:39:57 +0800 Subject: [PATCH 07/58] update backbone test; there is no need for resnet18 backbone test --- .../test_modeling_prompt_depth_anything.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py index b1fb81c166d5..3bea42d04090 100644 --- a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py +++ b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py @@ -222,13 +222,6 @@ def _validate_backbone_init(): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - # config.backbone = "resnet18" - # config.use_pretrained_backbone = True - # config.use_timm_backbone = True - # config.backbone_config = None - # config.backbone_kwargs = {"out_indices": (-2, -1)} - # _validate_backbone_init() - config.backbone = "facebook/dinov2-small" config.use_pretrained_backbone = True config.use_timm_backbone = False From 76299f43c3d4256adee9dd920d91de3a76f2a874 Mon Sep 17 00:00:00 2001 From: haotongl Date: Tue, 24 Dec 2024 21:38:51 +0800 Subject: [PATCH 08/58] update init file & pass RUN_SLOW tests --- .../models/prompt_depth_anything/__init__.py | 37 ++------ .../configuration_prompt_depth_anything.py | 3 + .../modeling_prompt_depth_anything.py | 86 ++++++++++--------- .../modular_prompt_depth_anything.py | 69 +++++++++++---- 4 files changed, 108 insertions(+), 87 deletions(-) diff --git a/src/transformers/models/prompt_depth_anything/__init__.py b/src/transformers/models/prompt_depth_anything/__init__.py index 5674d28175c5..dbbfe38f8f84 100644 --- a/src/transformers/models/prompt_depth_anything/__init__.py +++ b/src/transformers/models/prompt_depth_anything/__init__.py @@ -13,40 +13,15 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...file_utils import _LazyModule, is_torch_available -from ...utils import OptionalDependencyNotAvailable - - -_import_structure = {"configuration_prompt_depth_anything": ["PromptDepthAnythingConfig"]} - -try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["modeling_prompt_depth_anything"] = [ - "PromptDepthAnythingForDepthEstimation", - "PromptDepthAnythingPreTrainedModel", - ] +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure if TYPE_CHECKING: - from .configuration_prompt_depth_anything import PromptDepthAnythingConfig - - try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .modeling_prompt_depth_anything import ( - PromptDepthAnythingForDepthEstimation, - PromptDepthAnythingPreTrainedModel, - ) - - + from .configuration_prompt_depth_anything import * + from .modeling_prompt_depth_anything import * else: import sys - sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py index 15e25015d2a3..4852afb9c84f 100644 --- a/src/transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py @@ -154,3 +154,6 @@ def to_dict(self): output["model_type"] = self.__class__.model_type return output + + +__all__ = ["PromptDepthAnythingConfig"] diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index 0f0051b6d009..cf1c1fca04bb 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -179,6 +179,31 @@ def forward(self, hidden_states, size=None, prompt_depth=None): return fused_hidden_states +# Copied from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->PromptDepthAnything,dpt->prompt_depth_anything +class PromptDepthAnythingPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = PromptDepthAnythingConfig + base_model_prefix = "prompt_depth_anything" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + class PromptDepthAnythingReassembleLayer(nn.Module): def __init__(self, config, channels, factor): super().__init__() @@ -291,30 +316,6 @@ def forward( return output -class PromptDepthAnythingPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = PromptDepthAnythingConfig - base_model_prefix = "prompt_depth_anything" - main_input_name = "pixel_values" - supports_gradient_checkpointing = True - - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - class PromptDepthAnythingDepthEstimationHead(nn.Module): """ Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples @@ -488,26 +489,28 @@ def forward( patch_height = height // patch_size patch_width = width // patch_size - # normalize prompt depth - B = len(prompt_depth) - depth_min, depth_max = ( - torch.min(prompt_depth.reshape(B, -1), dim=1).values, - torch.max(prompt_depth.reshape(B, -1), dim=1).values, - ) - invalid_mask = (depth_max - depth_min) <= 0 - if invalid_mask.any(): - depth_max[invalid_mask] = depth_min[invalid_mask] + 1e-6 - depth_min, depth_max = depth_min.view(B, 1, 1, 1), depth_max.view(B, 1, 1, 1) - prompt_depth = (prompt_depth - depth_min) / (depth_max - depth_min) - # normalize done + if prompt_depth is not None: + # normalize prompt depth + B = len(prompt_depth) + depth_min, depth_max = ( + torch.min(prompt_depth.reshape(B, -1), dim=1).values, + torch.max(prompt_depth.reshape(B, -1), dim=1).values, + ) + invalid_mask = (depth_max - depth_min) <= 0 + if invalid_mask.any(): + depth_max[invalid_mask] = depth_min[invalid_mask] + 1e-6 + depth_min, depth_max = depth_min.view(B, 1, 1, 1), depth_max.view(B, 1, 1, 1) + prompt_depth = (prompt_depth - depth_min) / (depth_max - depth_min) + # normalize done hidden_states = self.neck(hidden_states, patch_height, patch_width, prompt_depth=prompt_depth) predicted_depth = self.head(hidden_states, patch_height, patch_width) - # denormalize predicted depth - depth_min, depth_max = depth_min.squeeze(1), depth_max.squeeze(1) - predicted_depth = predicted_depth * (depth_max - depth_min) + depth_min - # denormalize done + if prompt_depth is not None: + # denormalize predicted depth + depth_min, depth_max = depth_min.squeeze(1), depth_max.squeeze(1) + predicted_depth = predicted_depth * (depth_max - depth_min) + depth_min + # denormalize done if not return_dict: if output_hidden_states: @@ -522,3 +525,6 @@ def forward( hidden_states=outputs.hidden_states if output_hidden_states else None, attentions=outputs.attentions, ) + + +__all__ = ["PromptDepthAnythingForDepthEstimation", "PromptDepthAnythingPreTrainedModel"] diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index 825a37ba4acd..847fa35075f0 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -16,6 +16,7 @@ replace_return_docstrings, ) from ...modeling_outputs import DepthEstimatorOutput +from ...modeling_utils import PreTrainedModel _CONFIG_FOR_DOC = "PromptDepthAnythingConfig" @@ -42,6 +43,8 @@ output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + prompt_depth (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*): + Prompt depth. return_dict (`bool`, *optional*): Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. """ @@ -156,6 +159,31 @@ def forward(self, hidden_states, size=None, prompt_depth=None): return fused_hidden_states +# Copied from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->PromptDepthAnything,dpt->prompt_depth_anything +class PromptDepthAnythingPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = PromptDepthAnythingConfig + base_model_prefix = "prompt_depth_anything" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + class PromptDepthAnythingNeck(DepthAnythingNeck): def __init__(self, config): super().__init__(config) @@ -272,26 +300,28 @@ def forward( patch_height = height // patch_size patch_width = width // patch_size - # normalize prompt depth - B = len(prompt_depth) - depth_min, depth_max = ( - torch.min(prompt_depth.reshape(B, -1), dim=1).values, - torch.max(prompt_depth.reshape(B, -1), dim=1).values, - ) - invalid_mask = (depth_max - depth_min) <= 0 - if invalid_mask.any(): - depth_max[invalid_mask] = depth_min[invalid_mask] + 1e-6 - depth_min, depth_max = depth_min.view(B, 1, 1, 1), depth_max.view(B, 1, 1, 1) - prompt_depth = (prompt_depth - depth_min) / (depth_max - depth_min) - # normalize done + if prompt_depth is not None: + # normalize prompt depth + B = len(prompt_depth) + depth_min, depth_max = ( + torch.min(prompt_depth.reshape(B, -1), dim=1).values, + torch.max(prompt_depth.reshape(B, -1), dim=1).values, + ) + invalid_mask = (depth_max - depth_min) <= 0 + if invalid_mask.any(): + depth_max[invalid_mask] = depth_min[invalid_mask] + 1e-6 + depth_min, depth_max = depth_min.view(B, 1, 1, 1), depth_max.view(B, 1, 1, 1) + prompt_depth = (prompt_depth - depth_min) / (depth_max - depth_min) + # normalize done hidden_states = self.neck(hidden_states, patch_height, patch_width, prompt_depth=prompt_depth) predicted_depth = self.head(hidden_states, patch_height, patch_width) - # denormalize predicted depth - depth_min, depth_max = depth_min.squeeze(1), depth_max.squeeze(1) - predicted_depth = predicted_depth * (depth_max - depth_min) + depth_min - # denormalize done + if prompt_depth is not None: + # denormalize predicted depth + depth_min, depth_max = depth_min.squeeze(1), depth_max.squeeze(1) + predicted_depth = predicted_depth * (depth_max - depth_min) + depth_min + # denormalize done if not return_dict: if output_hidden_states: @@ -306,3 +336,10 @@ def forward( hidden_states=outputs.hidden_states if output_hidden_states else None, attentions=outputs.attentions, ) + + +__all__ = [ + "PromptDepthAnythingConfig", + "PromptDepthAnythingForDepthEstimation", + "PromptDepthAnythingPreTrainedModel", +] From 2315dd154abb1f75b6f11c098cd1e30601331842 Mon Sep 17 00:00:00 2001 From: Haotong LIN Date: Wed, 25 Dec 2024 14:38:43 +0800 Subject: [PATCH 09/58] update len(prompt_depth) to prompt_depth.shape[0] Co-authored-by: Joshua Lochner --- .../prompt_depth_anything/modeling_prompt_depth_anything.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index cf1c1fca04bb..6bc38669304c 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -491,7 +491,7 @@ def forward( if prompt_depth is not None: # normalize prompt depth - B = len(prompt_depth) + B = prompt_depth.shape[0] depth_min, depth_max = ( torch.min(prompt_depth.reshape(B, -1), dim=1).values, torch.max(prompt_depth.reshape(B, -1), dim=1).values, From c423e91e058581f3500d0a1d61c28700ca9894ef Mon Sep 17 00:00:00 2001 From: haotongl Date: Wed, 25 Dec 2024 15:29:54 +0800 Subject: [PATCH 10/58] fix torch_int/model_doc --- .../en/model_doc/prompt_depth_anything.md | 6 - .../convert_prompt_depth_anything_to_hf.py | 2 +- .../modeling_prompt_depth_anything.py | 106 ++++++++++-------- .../modular_prompt_depth_anything.py | 70 +++++++++++- 4 files changed, 128 insertions(+), 56 deletions(-) diff --git a/docs/source/en/model_doc/prompt_depth_anything.md b/docs/source/en/model_doc/prompt_depth_anything.md index 1a60fe6242ad..202e9d1d7b10 100644 --- a/docs/source/en/model_doc/prompt_depth_anything.md +++ b/docs/source/en/model_doc/prompt_depth_anything.md @@ -20,12 +20,6 @@ rendered properly in your Markdown viewer. The Prompt Depth Anything model was introduced in [Prompting Depth Anything for 4K Resolution Accurate Metric Depth Estimation](https://promptda.github.io/) by Haotong Lin, Sida Peng, Jingxiao Chen, Songyou Peng, Jiaming Sun, Minghuan Liu, Hujun Bao, Jiashi Feng, Xiaowei Zhou, Bingyi Kang. - The abstract from the paper is as follows: diff --git a/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py b/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py index f300e7ed3a08..3594d6bc6e91 100644 --- a/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py +++ b/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py @@ -261,7 +261,7 @@ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, ve predicted_depth = outputs.predicted_depth print("Shape of predicted depth:", predicted_depth.shape) - print("First values:", predicted_depth[0, 0, :3, :3]) + print("First values:", predicted_depth[0, :3, :3]) # assert logits if verify_logits: diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index 6bc38669304c..291cde8ff8f5 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -9,6 +9,8 @@ import torch import torch.nn as nn +from transformers.utils.generic import torch_int + from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings from ...modeling_outputs import DepthEstimatorOutput from ...modeling_utils import PreTrainedModel @@ -179,6 +181,52 @@ def forward(self, hidden_states, size=None, prompt_depth=None): return fused_hidden_states +class PromptDepthAnythingDepthEstimationHead(nn.Module): + """ + Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples + the predictions to the input resolution after the first convolutional layer (details can be found in the DPT paper's + supplementary material). The final activation function is either ReLU or Sigmoid, depending on the depth estimation + type (relative or metric). For metric depth estimation, the output is scaled by the maximum depth used during pretraining. + """ + + def __init__(self, config): + super().__init__() + + self.head_in_index = config.head_in_index + self.patch_size = config.patch_size + + features = config.fusion_hidden_size + self.conv1 = nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1) + self.conv2 = nn.Conv2d(features // 2, config.head_hidden_size, kernel_size=3, stride=1, padding=1) + self.activation1 = nn.ReLU() + self.conv3 = nn.Conv2d(config.head_hidden_size, 1, kernel_size=1, stride=1, padding=0) + if config.depth_estimation_type == "relative": + self.activation2 = nn.ReLU() + elif config.depth_estimation_type == "metric": + self.activation2 = nn.Sigmoid() + else: + raise ValueError(f"Unknown depth estimation type: {config.depth_estimation_type}") + self.max_depth = config.max_depth + + def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) -> torch.Tensor: + hidden_states = hidden_states[self.head_in_index] + + predicted_depth = self.conv1(hidden_states) + predicted_depth = nn.functional.interpolate( + predicted_depth, + (torch_int(patch_height * self.patch_size), torch_int(patch_width * self.patch_size)), + mode="bilinear", + align_corners=True, + ) + predicted_depth = self.conv2(predicted_depth) + predicted_depth = self.activation1(predicted_depth) + predicted_depth = self.conv3(predicted_depth) + predicted_depth = self.activation2(predicted_depth) * self.max_depth + predicted_depth = predicted_depth.squeeze(dim=1) # shape (batch_size, height, width) + + return predicted_depth + + # Copied from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->PromptDepthAnything,dpt->prompt_depth_anything class PromptDepthAnythingPreTrainedModel(PreTrainedModel): """ @@ -218,6 +266,15 @@ def __init__(self, config, channels, factor): # so should downsample self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=int(1 / factor), padding=1) + # up/down sampling depending on factor + if factor > 1: + self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0) + elif factor == 1: + self.resize = nn.Identity() + elif factor < 1: + # so should downsample + self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=torch_int(1 / factor), padding=1) + def forward(self, hidden_state): hidden_state = self.projection(hidden_state) hidden_state = self.resize(hidden_state) @@ -236,7 +293,7 @@ class PromptDepthAnythingReassembleStage(nn.Module): 3. Resizing the spatial dimensions (height, width). Args: - config (`[PromptDepthAnythingConfig]`): + config (`[DepthAnythingConfig]`): Model configuration class defining the model architecture. """ @@ -283,7 +340,6 @@ class PromptDepthAnythingNeck(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.reassemble_stage = PromptDepthAnythingReassembleStage(config) self.convs = nn.ModuleList() @@ -316,52 +372,6 @@ def forward( return output -class PromptDepthAnythingDepthEstimationHead(nn.Module): - """ - Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples - the predictions to the input resolution after the first convolutional layer (details can be found in the DPT paper's - supplementary material). The final activation function is either ReLU or Sigmoid, depending on the depth estimation - type (relative or metric). For metric depth estimation, the output is scaled by the maximum depth used during pretraining. - """ - - def __init__(self, config): - super().__init__() - - self.head_in_index = config.head_in_index - self.patch_size = config.patch_size - - features = config.fusion_hidden_size - self.conv1 = nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1) - self.conv2 = nn.Conv2d(features // 2, config.head_hidden_size, kernel_size=3, stride=1, padding=1) - self.activation1 = nn.ReLU() - self.conv3 = nn.Conv2d(config.head_hidden_size, 1, kernel_size=1, stride=1, padding=0) - if config.depth_estimation_type == "relative": - self.activation2 = nn.ReLU() - elif config.depth_estimation_type == "metric": - self.activation2 = nn.Sigmoid() - else: - raise ValueError(f"Unknown depth estimation type: {config.depth_estimation_type}") - self.max_depth = config.max_depth - - def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) -> torch.Tensor: - hidden_states = hidden_states[self.head_in_index] - - predicted_depth = self.conv1(hidden_states) - predicted_depth = nn.functional.interpolate( - predicted_depth, - (int(patch_height * self.patch_size), int(patch_width * self.patch_size)), - mode="bilinear", - align_corners=True, - ) - predicted_depth = self.conv2(predicted_depth) - predicted_depth = self.activation1(predicted_depth) - predicted_depth = self.conv3(predicted_depth) - predicted_depth = self.activation2(predicted_depth) * self.max_depth - predicted_depth = predicted_depth.squeeze(dim=1) # shape (batch_size, height, width) - - return predicted_depth - - PROMPT_DEPTH_ANYTHING_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index 847fa35075f0..b433c538c015 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -5,10 +5,14 @@ from transformers.models.depth_anything.configuration_depth_anything import DepthAnythingConfig from transformers.models.depth_anything.modeling_depth_anything import ( + DepthAnythingDepthEstimationHead, DepthAnythingFeatureFusionLayer, DepthAnythingForDepthEstimation, DepthAnythingNeck, + DepthAnythingReassembleLayer, + DepthAnythingReassembleStage, ) +from transformers.utils.generic import torch_int from ...file_utils import ( add_start_docstrings, @@ -159,6 +163,29 @@ def forward(self, hidden_states, size=None, prompt_depth=None): return fused_hidden_states +class PromptDepthAnythingDepthEstimationHead(DepthAnythingDepthEstimationHead): + def __init__(self, config): + super().__init__(config) + + def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) -> torch.Tensor: + hidden_states = hidden_states[self.head_in_index] + + predicted_depth = self.conv1(hidden_states) + predicted_depth = nn.functional.interpolate( + predicted_depth, + (torch_int(patch_height * self.patch_size), torch_int(patch_width * self.patch_size)), + mode="bilinear", + align_corners=True, + ) + predicted_depth = self.conv2(predicted_depth) + predicted_depth = self.activation1(predicted_depth) + predicted_depth = self.conv3(predicted_depth) + predicted_depth = self.activation2(predicted_depth) * self.max_depth + predicted_depth = predicted_depth.squeeze(dim=1) # shape (batch_size, height, width) + + return predicted_depth + + # Copied from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->PromptDepthAnything,dpt->prompt_depth_anything class PromptDepthAnythingPreTrainedModel(PreTrainedModel): """ @@ -184,9 +211,49 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) +class PromptDepthAnythingReassembleLayer(DepthAnythingReassembleLayer): + def __init__(self, config, channels, factor): + super().__init__(config, channels, factor) + self.projection = nn.Conv2d(in_channels=config.reassemble_hidden_size, out_channels=channels, kernel_size=1) + + # up/down sampling depending on factor + if factor > 1: + self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0) + elif factor == 1: + self.resize = nn.Identity() + elif factor < 1: + # so should downsample + self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=torch_int(1 / factor), padding=1) + + +class PromptDepthAnythingReassembleStage(DepthAnythingReassembleStage): + """ + This class reassembles the hidden states of the backbone into image-like feature representations at various + resolutions. + + This happens in 3 stages: + 1. Take the patch embeddings and reshape them to image-like feature representations. + 2. Project the channel dimension of the hidden states according to `config.neck_hidden_sizes`. + 3. Resizing the spatial dimensions (height, width). + + Args: + config (`[DepthAnythingConfig]`): + Model configuration class defining the model architecture. + """ + + def __init__(self, config): + super().__init__() + + self.config = config + self.layers = nn.ModuleList() + for channels, factor in zip(config.neck_hidden_sizes, config.reassemble_factors): + self.layers.append(PromptDepthAnythingReassembleLayer(config, channels=channels, factor=factor)) + + class PromptDepthAnythingNeck(DepthAnythingNeck): def __init__(self, config): super().__init__(config) + self.reassemble_stage = PromptDepthAnythingReassembleStage(config) self.fusion_stage = PromptDepthAnythingFeatureFusionStage(config) def forward( @@ -224,6 +291,7 @@ class PromptDepthAnythingForDepthEstimation(DepthAnythingForDepthEstimation): def __init__(self, config): super().__init__(config) self.neck = PromptDepthAnythingNeck(config) + self.head = PromptDepthAnythingDepthEstimationHead(config) self.post_init() @add_start_docstrings_to_model_forward(PROMPT_DEPTH_ANYTHING_INPUTS_DOCSTRING) @@ -302,7 +370,7 @@ def forward( if prompt_depth is not None: # normalize prompt depth - B = len(prompt_depth) + B = prompt_depth.shape[0] depth_min, depth_max = ( torch.min(prompt_depth.reshape(B, -1), dim=1).values, torch.max(prompt_depth.reshape(B, -1), dim=1).values, From 739c07f37dfd19ee81e8e01708cb9c22af1b9acb Mon Sep 17 00:00:00 2001 From: haotongl Date: Wed, 25 Dec 2024 15:32:53 +0800 Subject: [PATCH 11/58] fix typo --- .../prompt_depth_anything/modeling_prompt_depth_anything.py | 2 +- .../prompt_depth_anything/modular_prompt_depth_anything.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index 291cde8ff8f5..fa85a3292377 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -293,7 +293,7 @@ class PromptDepthAnythingReassembleStage(nn.Module): 3. Resizing the spatial dimensions (height, width). Args: - config (`[DepthAnythingConfig]`): + config (`[PromptDepthAnythingConfig]`): Model configuration class defining the model architecture. """ diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index b433c538c015..aeac6bce4cba 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -237,7 +237,7 @@ class PromptDepthAnythingReassembleStage(DepthAnythingReassembleStage): 3. Resizing the spatial dimensions (height, width). Args: - config (`[DepthAnythingConfig]`): + config (`[PromptDepthAnythingConfig]`): Model configuration class defining the model architecture. """ From 5c046e825c8972223d3d6dcaf01febcbc76dd42d Mon Sep 17 00:00:00 2001 From: haotongl Date: Wed, 25 Dec 2024 19:20:27 +0800 Subject: [PATCH 12/58] update PromptDepthAnythingImageProcessor --- .../en/model_doc/prompt_depth_anything.md | 13 +- src/transformers/__init__.py | 2 + .../models/auto/image_processing_auto.py | 1 + .../models/prompt_depth_anything/__init__.py | 8 +- .../convert_prompt_depth_anything_to_hf.py | 21 +- .../image_processing_prompt_depth_anything.py | 534 ++++++++++++++++++ .../modeling_prompt_depth_anything.py | 6 +- .../modular_prompt_depth_anything.py | 6 +- .../utils/dummy_vision_objects.py | 7 + 9 files changed, 567 insertions(+), 31 deletions(-) create mode 100644 src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py diff --git a/docs/source/en/model_doc/prompt_depth_anything.md b/docs/source/en/model_doc/prompt_depth_anything.md index 202e9d1d7b10..b66bdd91b1af 100644 --- a/docs/source/en/model_doc/prompt_depth_anything.md +++ b/docs/source/en/model_doc/prompt_depth_anything.md @@ -52,14 +52,12 @@ The transformers library allows you to use the model with just a few lines of co >>> prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true" >>> prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw) ->>> prompt_depth = torch.tensor((np.asarray(prompt_depth) / 1000.0).astype(np.float32)) ->>> prompt_depth = prompt_depth.unsqueeze(0).unsqueeze(0) >>> # prepare image for the model ->>> inputs = image_processor(images=image, return_tensors="pt") +>>> inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth) >>> with torch.no_grad(): -... outputs = model(pixel_values=inputs.pixel_values, prompt_depth=prompt_depth) +... outputs = model(**inputs) >>> # interpolate to original size >>> post_processed_output = image_processor.post_process_depth_estimation( @@ -90,4 +88,9 @@ If you're interested in submitting a resource to be included here, please feel f ## PromptDepthAnythingForDepthEstimation [[autodoc]] PromptDepthAnythingForDepthEstimation - - forward \ No newline at end of file + - forward + +## PromptDepthAnythingImageProcessor + +[[autodoc]] PromptDepthAnythingImageProcessor + - preprocess \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 1bd6db463b3f..e808276bce6a 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1247,6 +1247,7 @@ _import_structure["models.pix2struct"].extend(["Pix2StructImageProcessor"]) _import_structure["models.pixtral"].append("PixtralImageProcessor") _import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"]) + _import_structure["models.prompt_depth_anything"].extend(["PromptDepthAnythingImageProcessor"]) _import_structure["models.pvt"].extend(["PvtImageProcessor"]) _import_structure["models.qwen2_vl"].extend(["Qwen2VLImageProcessor"]) _import_structure["models.rt_detr"].extend(["RTDetrImageProcessor"]) @@ -6268,6 +6269,7 @@ PoolFormerFeatureExtractor, PoolFormerImageProcessor, ) + from .models.prompt_depth_anything import PromptDepthAnythingImageProcessor from .models.pvt import PvtImageProcessor from .models.qwen2_vl import Qwen2VLImageProcessor from .models.rt_detr import RTDetrImageProcessor diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index db25591eaa35..d08b22721ada 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -123,6 +123,7 @@ ("pix2struct", ("Pix2StructImageProcessor",)), ("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")), ("poolformer", ("PoolFormerImageProcessor",)), + ("prompt_depth_anything", ("PromptDepthAnythingImageProcessor",)), ("pvt", ("PvtImageProcessor",)), ("pvt_v2", ("PvtImageProcessor",)), ("qwen2_vl", ("Qwen2VLImageProcessor",)), diff --git a/src/transformers/models/prompt_depth_anything/__init__.py b/src/transformers/models/prompt_depth_anything/__init__.py index dbbfe38f8f84..3cb05f8e3788 100644 --- a/src/transformers/models/prompt_depth_anything/__init__.py +++ b/src/transformers/models/prompt_depth_anything/__init__.py @@ -18,8 +18,12 @@ if TYPE_CHECKING: - from .configuration_prompt_depth_anything import * - from .modeling_prompt_depth_anything import * + from .configuration_prompt_depth_anything import PromptDepthAnythingConfig + from .image_processing_prompt_depth_anything import PromptDepthAnythingImageProcessor + from .modeling_prompt_depth_anything import ( + PromptDepthAnythingForDepthEstimation, + PromptDepthAnythingPreTrainedModel, + ) else: import sys diff --git a/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py b/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py index 3594d6bc6e91..4ff9612750e0 100644 --- a/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py +++ b/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py @@ -18,7 +18,6 @@ import argparse from pathlib import Path -import numpy as np import requests import torch from huggingface_hub import hf_hub_download @@ -26,9 +25,9 @@ from transformers import ( Dinov2Config, - DPTImageProcessor, PromptDepthAnythingConfig, PromptDepthAnythingForDepthEstimation, + PromptDepthAnythingImageProcessor, ) from transformers.utils import logging @@ -181,15 +180,7 @@ def rename_key(dct, old, new): dct[new] = val -# We will verify our results on an image of cute cats -def prepare_img(): - url = "http://images.cocodataset.org/val2017/000000039769.jpg" - im = Image.open(requests.get(url, stream=True).raw) - return im - - name_to_checkpoint = { - # v2-giant pending "promptda_vits": "model.ckpt", "promptda_vits_transparent": "model.ckpt", "promptda_vitl": "model.ckpt", @@ -233,7 +224,7 @@ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, ve model.load_state_dict(state_dict, strict=False) model.eval() - processor = DPTImageProcessor( + processor = PromptDepthAnythingImageProcessor( do_resize=True, size=756, ensure_multiple_of=14, @@ -246,18 +237,16 @@ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, ve url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true" image = Image.open(requests.get(url, stream=True).raw) - pixel_values = processor(image, return_tensors="pt").pixel_values - prompt_depth_url = ( "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true" ) prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw) - prompt_depth = torch.tensor((np.asarray(prompt_depth) / 1000.0).astype(np.float32)) - prompt_depth = prompt_depth.unsqueeze(0).unsqueeze(0) + + inputs = processor(image, return_tensors="pt", prompt_depth=prompt_depth) # Verify forward pass with torch.no_grad(): - outputs = model(pixel_values, prompt_depth=prompt_depth) + outputs = model(**inputs) predicted_depth = outputs.predicted_depth print("Shape of predicted depth:", predicted_depth.shape) diff --git a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py new file mode 100644 index 000000000000..cd78edcd16c8 --- /dev/null +++ b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py @@ -0,0 +1,534 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for PromptDepthAnything.""" + +import math +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union + + +if TYPE_CHECKING: + from ...modeling_outputs import DepthEstimatorOutput + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import pad, resize, to_channel_dimension_format +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + is_torch_available, + is_torch_tensor, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import ( + TensorType, + filter_out_non_signature_kwargs, + is_vision_available, + logging, + requires_backends, +) + + +if is_torch_available(): + import torch + +if is_vision_available(): + import PIL + + +logger = logging.get_logger(__name__) + + +def get_resize_output_image_size( + input_image: np.ndarray, + output_size: Union[int, Iterable[int]], + keep_aspect_ratio: bool, + multiple: int, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> Tuple[int, int]: + def constrain_to_multiple_of(val, multiple, min_val=0, max_val=None): + x = round(val / multiple) * multiple + + if max_val is not None and x > max_val: + x = math.floor(val / multiple) * multiple + + if x < min_val: + x = math.ceil(val / multiple) * multiple + + return x + + output_size = (output_size, output_size) if isinstance(output_size, int) else output_size + + input_height, input_width = get_image_size(input_image, input_data_format) + output_height, output_width = output_size + + # determine new height and width + scale_height = output_height / input_height + scale_width = output_width / input_width + + if keep_aspect_ratio: + # scale as little as possible + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + + new_height = constrain_to_multiple_of(scale_height * input_height, multiple=multiple) + new_width = constrain_to_multiple_of(scale_width * input_width, multiple=multiple) + + return (new_height, new_width) + + +class PromptDepthAnythingImageProcessor(BaseImageProcessor): + r""" + Constructs a PromptDepthAnything image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions. Can be overidden by `do_resize` in `preprocess`. + size (`Dict[str, int]` *optional*, defaults to `{"height": 384, "width": 384}`): + Size of the image after resizing. Can be overidden by `size` in `preprocess`. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Defines the resampling filter to use if resizing the image. Can be overidden by `resample` in `preprocess`. + keep_aspect_ratio (`bool`, *optional*, defaults to `False`): + If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can + be overidden by `keep_aspect_ratio` in `preprocess`. + ensure_multiple_of (`int`, *optional*, defaults to 1): + If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be overidden + by `ensure_multiple_of` in `preprocess`. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overidden by `do_rescale` in + `preprocess`. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overidden by `rescale_factor` in `preprocess`. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + do_pad (`bool`, *optional*, defaults to `False`): + Whether to apply center padding. This was introduced in the DINOv2 paper, which uses the model in + combination with DPT. + size_divisor (`int`, *optional*): + If `do_pad` is `True`, pads the image dimensions to be divisible by this value. This was introduced in the + DINOv2 paper, which uses the model in combination with DPT. + prompt_scale_to_meter (`float`, *optional*, defaults to `0.001`): + Scale factor to convert the prompt depth to meters. + """ + + model_input_names = ["pixel_values", "prompt_depth"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + keep_aspect_ratio: bool = False, + ensure_multiple_of: int = 1, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: bool = False, + size_divisor: int = None, + prompt_scale_to_meter: float = 0.001, # default unit is mm + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 384, "width": 384} + size = get_size_dict(size) + self.do_resize = do_resize + self.size = size + self.keep_aspect_ratio = keep_aspect_ratio + self.ensure_multiple_of = ensure_multiple_of + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + self.do_pad = do_pad + self.size_divisor = size_divisor + self.prompt_scale_to_meter = prompt_scale_to_meter + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + keep_aspect_ratio: bool = False, + ensure_multiple_of: int = 1, + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to target size `(size["height"], size["width"])`. If `keep_aspect_ratio` is `True`, the image + is resized to the largest possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is + set, the image is resized to a size that is a multiple of this value. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Target size of the output image. + keep_aspect_ratio (`bool`, *optional*, defaults to `False`): + If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. + ensure_multiple_of (`int`, *optional*, defaults to 1): + The image is resized to a size that is a multiple of this value. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Defines the resampling filter to use if resizing the image. Otherwise, the image is resized to size + specified in `size`. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + size = get_size_dict(size) + if "height" not in size or "width" not in size: + raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}") + + output_size = get_resize_output_image_size( + image, + output_size=(size["height"], size["width"]), + keep_aspect_ratio=keep_aspect_ratio, + multiple=ensure_multiple_of, + input_data_format=input_data_format, + ) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def pad_image( + self, + image: np.array, + size_divisor: int, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Center pad an image to be a multiple of `multiple`. + + Args: + image (`np.ndarray`): + Image to pad. + size_divisor (`int`): + The width and height of the image will be padded to a multiple of this number. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + + def _get_pad(size, size_divisor): + new_size = math.ceil(size / size_divisor) * size_divisor + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + + height, width = get_image_size(image, input_data_format) + + pad_size_left, pad_size_right = _get_pad(height, size_divisor) + pad_size_top, pad_size_bottom = _get_pad(width, size_divisor) + + return pad(image, ((pad_size_left, pad_size_right), (pad_size_top, pad_size_bottom)), data_format=data_format) + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: bool = None, + size: int = None, + keep_aspect_ratio: bool = None, + ensure_multiple_of: int = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: bool = None, + size_divisor: int = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + prompt_depth: ImageInput = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after reszing. If `keep_aspect_ratio` is `True`, the image is resized to the largest + possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is set, the image is + resized to a size that is a multiple of this value. + keep_aspect_ratio (`bool`, *optional*, defaults to `self.keep_aspect_ratio`): + Whether to keep the aspect ratio of the image. If False, the image will be resized to (size, size). If + True, the image will be resized to keep the aspect ratio and the size will be the maximum possible. + ensure_multiple_of (`int`, *optional*, defaults to `self.ensure_multiple_of`): + Ensure that the image size is a multiple of this value. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + prompt_depth (`ImageInput`, *optional*): + Prompt depth to preprocess. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size) + keep_aspect_ratio = keep_aspect_ratio if keep_aspect_ratio is not None else self.keep_aspect_ratio + ensure_multiple_of = ensure_multiple_of if ensure_multiple_of is not None else self.ensure_multiple_of + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_pad = do_pad if do_pad is not None else self.do_pad + size_divisor = size_divisor if size_divisor is not None else self.size_divisor + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + size_divisibility=size_divisor, + do_resize=do_resize, + size=size, + resample=resample, + ) + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [ + self.resize( + image=image, + size=size, + resample=resample, + keep_aspect_ratio=keep_aspect_ratio, + ensure_multiple_of=ensure_multiple_of, + input_data_format=input_data_format, + ) + for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + if do_pad: + images = [ + self.pad_image(image=image, size_divisor=size_divisor, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + if prompt_depth is not None: + # prompt_depth is a list of images with shape (height, width) + # we need to convert it to a list of images with shape (1, height, width) + prompt_depths = make_list_of_images(prompt_depth) + prompt_depths = [to_numpy_array(depth) for depth in prompt_depths] + prompt_depths = [depth * self.prompt_scale_to_meter for depth in prompt_depths] + prompt_depths = [prompt_depth[..., None].astype(np.float32) for prompt_depth in prompt_depths] + prompt_depths = [ + to_channel_dimension_format(depth, data_format, input_channel_dim=input_data_format) + for depth in prompt_depths + ] + data["prompt_depth"] = prompt_depths + return BatchFeature(data=data, tensor_type=return_tensors) + + # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->PromptDepthAnything + def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None): + """ + Converts the output of [`PromptDepthAnythingForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + + Args: + outputs ([`PromptDepthAnythingForSemanticSegmentation`]): + Raw outputs of the model. + target_sizes (`List[Tuple]` of length `batch_size`, *optional*): + List of tuples corresponding to the requested final size (height, width) of each prediction. If unset, + predictions will not be resized. + + Returns: + semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic + segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is + specified). Each entry of each `torch.Tensor` correspond to a semantic class id. + """ + # TODO: add support for other frameworks + logits = outputs.logits + + # Resize logits and compute semantic segmentation maps + if target_sizes is not None: + if len(logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + if is_torch_tensor(target_sizes): + target_sizes = target_sizes.numpy() + + semantic_segmentation = [] + + for idx in range(len(logits)): + resized_logits = torch.nn.functional.interpolate( + logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False + ) + semantic_map = resized_logits[0].argmax(dim=0) + semantic_segmentation.append(semantic_map) + else: + semantic_segmentation = logits.argmax(dim=1) + semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] + + return semantic_segmentation + + def post_process_depth_estimation( + self, + outputs: "DepthEstimatorOutput", + target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None, + ) -> List[Dict[str, TensorType]]: + """ + Converts the raw output of [`DepthEstimatorOutput`] into final depth predictions and depth PIL images. + Only supports PyTorch. + + Args: + outputs ([`DepthEstimatorOutput`]): + Raw outputs of the model. + target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + (height, width) of each image in the batch. If left to None, predictions will not be resized. + + Returns: + `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth + predictions. + """ + requires_backends(self, "torch") + + predicted_depth = outputs.predicted_depth + + if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the predicted depth" + ) + + results = [] + target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes + for depth, target_size in zip(predicted_depth, target_sizes): + if target_size is not None: + depth = torch.nn.functional.interpolate( + depth.unsqueeze(0).unsqueeze(1), size=target_size, mode="bicubic", align_corners=False + ).squeeze() + + results.append({"predicted_depth": depth}) + + return results + + +__all__ = ["PromptDepthAnythingImageProcessor"] diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index fa85a3292377..8432c8a4aa23 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -458,14 +458,12 @@ def forward( >>> prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true" >>> prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw) - >>> prompt_depth = torch.tensor((np.asarray(prompt_depth) / 1000.0).astype(np.float32)) - >>> prompt_depth = prompt_depth.unsqueeze(0).unsqueeze(0) >>> # prepare image for the model - >>> inputs = image_processor(images=image, return_tensors="pt") + >>> inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth) >>> with torch.no_grad(): - ... outputs = model(pixel_values=inputs.pixel_values, prompt_depth=prompt_depth) + ... outputs = model(**inputs) >>> # interpolate to original size >>> post_processed_output = image_processor.post_process_depth_estimation( diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index aeac6bce4cba..54224d5ae568 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -327,14 +327,12 @@ def forward( >>> prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true" >>> prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw) - >>> prompt_depth = torch.tensor((np.asarray(prompt_depth) / 1000.0).astype(np.float32)) - >>> prompt_depth = prompt_depth.unsqueeze(0).unsqueeze(0) >>> # prepare image for the model - >>> inputs = image_processor(images=image, return_tensors="pt") + >>> inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth) >>> with torch.no_grad(): - ... outputs = model(pixel_values=inputs.pixel_values, prompt_depth=prompt_depth) + ... outputs = model(**inputs) >>> # interpolate to original size >>> post_processed_output = image_processor.post_process_depth_estimation( diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 3ebda4404aae..8d5dee958bd1 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -548,6 +548,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class PromptDepthAnythingImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class PvtImageProcessor(metaclass=DummyObject): _backends = ["vision"] From f3a8aa48d17d537afca4c84dcb9c31987b4bb0cd Mon Sep 17 00:00:00 2001 From: haotongl Date: Wed, 25 Dec 2024 19:37:39 +0800 Subject: [PATCH 13/58] fix typo --- .../image_processing_prompt_depth_anything.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py index cd78edcd16c8..7262e54ef2db 100644 --- a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py @@ -137,7 +137,7 @@ class PromptDepthAnythingImageProcessor(BaseImageProcessor): size_divisor (`int`, *optional*): If `do_pad` is `True`, pads the image dimensions to be divisible by this value. This was introduced in the DINOv2 paper, which uses the model in combination with DPT. - prompt_scale_to_meter (`float`, *optional*, defaults to `0.001`): + prompt_scale_to_meter (`float`, *optional*, defaults to 0.001): Scale factor to convert the prompt depth to meters. """ From c2647ca3e0f6a4357b74f9bd32b018355e1c84dc Mon Sep 17 00:00:00 2001 From: haotongl Date: Thu, 26 Dec 2024 12:09:44 +0800 Subject: [PATCH 14/58] fix typo for prompt depth anything doc --- docs/source/en/model_doc/prompt_depth_anything.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/docs/source/en/model_doc/prompt_depth_anything.md b/docs/source/en/model_doc/prompt_depth_anything.md index b66bdd91b1af..72392ec6f32a 100644 --- a/docs/source/en/model_doc/prompt_depth_anything.md +++ b/docs/source/en/model_doc/prompt_depth_anything.md @@ -30,9 +30,6 @@ alt="drawing" width="600"/> Prompt Depth Anything overview. Taken from the original paper. - - ## Usage example The transformers library allows you to use the model with just a few lines of code: From ea67b900a343738158553e65d72d5359ee398e53 Mon Sep 17 00:00:00 2001 From: haotongl Date: Fri, 3 Jan 2025 00:44:09 +0800 Subject: [PATCH 15/58] update promptda overview image link of huggingface repo --- docs/source/en/model_doc/prompt_depth_anything.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/prompt_depth_anything.md b/docs/source/en/model_doc/prompt_depth_anything.md index 72392ec6f32a..dc97eeb35cc6 100644 --- a/docs/source/en/model_doc/prompt_depth_anything.md +++ b/docs/source/en/model_doc/prompt_depth_anything.md @@ -25,7 +25,7 @@ The abstract from the paper is as follows: *Prompts play a critical role in unleashing the power of language and vision foundation models for specific tasks. For the first time, we introduce prompting into depth foundation models, creating a new paradigm for metric depth estimation termed Prompt Depth Anything. Specifically, we use a low-cost LiDAR as the prompt to guide the Depth Anything model for accurate metric depth output, achieving up to 4K resolution. Our approach centers on a concise prompt fusion design that integrates the LiDAR at multiple scales within the depth decoder. To address training challenges posed by limited datasets containing both LiDAR depth and precise GT depth, we propose a scalable data pipeline that includes synthetic data LiDAR simulation and real data pseudo GT depth generation. Our approach sets new state-of-the-arts on the ARKitScenes and ScanNet++ datasets and benefits downstream applications, including 3D reconstruction and generalized robotic grasping.* -drawing Prompt Depth Anything overview. Taken from the original paper. From b2379d613ccee47577df7688c5a267f6346cb7ab Mon Sep 17 00:00:00 2001 From: haotongl Date: Mon, 6 Jan 2025 17:14:26 +0800 Subject: [PATCH 16/58] fix some typos in promptda doc --- docs/source/en/model_doc/prompt_depth_anything.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/en/model_doc/prompt_depth_anything.md b/docs/source/en/model_doc/prompt_depth_anything.md index dc97eeb35cc6..6b2acfe7f42c 100644 --- a/docs/source/en/model_doc/prompt_depth_anything.md +++ b/docs/source/en/model_doc/prompt_depth_anything.md @@ -18,7 +18,7 @@ rendered properly in your Markdown viewer. ## Overview -The Prompt Depth Anything model was introduced in [Prompting Depth Anything for 4K Resolution Accurate Metric Depth Estimation](https://promptda.github.io/) by Haotong Lin, Sida Peng, Jingxiao Chen, Songyou Peng, Jiaming Sun, Minghuan Liu, Hujun Bao, Jiashi Feng, Xiaowei Zhou, Bingyi Kang. +The Prompt Depth Anything model was introduced in [Prompting Depth Anything for 4K Resolution Accurate Metric Depth Estimation](https://arxiv.org/abs/2412.14015) by Haotong Lin, Sida Peng, Jingxiao Chen, Songyou Peng, Jiaming Sun, Minghuan Liu, Hujun Bao, Jiashi Feng, Xiaowei Zhou, Bingyi Kang. The abstract from the paper is as follows: @@ -28,11 +28,11 @@ The abstract from the paper is as follows: drawing - Prompt Depth Anything overview. Taken from the original paper. + Prompt Depth Anything overview. Taken from the original paper. ## Usage example -The transformers library allows you to use the model with just a few lines of code: +The Transformers library allows you to use the model with just a few lines of code: ```python >>> from transformers import AutoImageProcessor, AutoModelForDepthEstimation From b9a44fb40a83ca82948d70f953275c9213d9d175 Mon Sep 17 00:00:00 2001 From: haotongl Date: Tue, 7 Jan 2025 20:01:19 +0800 Subject: [PATCH 17/58] Update image processing to include pad_image, prompt depth position, and related explanations for better clarity and functionality. --- .../image_processing_prompt_depth_anything.py | 54 +++---------------- 1 file changed, 7 insertions(+), 47 deletions(-) diff --git a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py index 7262e54ef2db..5fe42c2bb5a4 100644 --- a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py @@ -275,12 +275,14 @@ def _get_pad(size, size_divisor): pad_size_left, pad_size_right = _get_pad(height, size_divisor) pad_size_top, pad_size_bottom = _get_pad(width, size_divisor) - return pad(image, ((pad_size_left, pad_size_right), (pad_size_top, pad_size_bottom)), data_format=data_format) + padded_image = pad(image, ((pad_size_left, pad_size_right), (pad_size_top, pad_size_bottom)), data_format=data_format) + return padded_image @filter_out_non_signature_kwargs() def preprocess( self, images: ImageInput, + prompt_depth: ImageInput = None, do_resize: bool = None, size: int = None, keep_aspect_ratio: bool = None, @@ -296,7 +298,6 @@ def preprocess( return_tensors: Optional[Union[str, TensorType]] = None, data_format: ChannelDimension = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, - prompt_depth: ImageInput = None, ) -> PIL.Image.Image: """ Preprocess an image or batch of images. @@ -305,6 +306,10 @@ def preprocess( images (`ImageInput`): Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`. + prompt_depth (`ImageInput`, *optional*): + Prompt depth to preprocess, which can be sparse depth obtained from multi-view geometry or + low-resolution depth from a depth sensor. Generally has shape (height, width), where height + and width can be smaller than the images. do_resize (`bool`, *optional*, defaults to `self.do_resize`): Whether to resize the image. size (`Dict[str, int]`, *optional*, defaults to `self.size`): @@ -346,8 +351,6 @@ def preprocess( - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - prompt_depth (`ImageInput`, *optional*): - Prompt depth to preprocess. """ do_resize = do_resize if do_resize is not None else self.do_resize size = size if size is not None else self.size @@ -445,49 +448,6 @@ def preprocess( data["prompt_depth"] = prompt_depths return BatchFeature(data=data, tensor_type=return_tensors) - # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->PromptDepthAnything - def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None): - """ - Converts the output of [`PromptDepthAnythingForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. - - Args: - outputs ([`PromptDepthAnythingForSemanticSegmentation`]): - Raw outputs of the model. - target_sizes (`List[Tuple]` of length `batch_size`, *optional*): - List of tuples corresponding to the requested final size (height, width) of each prediction. If unset, - predictions will not be resized. - - Returns: - semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic - segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is - specified). Each entry of each `torch.Tensor` correspond to a semantic class id. - """ - # TODO: add support for other frameworks - logits = outputs.logits - - # Resize logits and compute semantic segmentation maps - if target_sizes is not None: - if len(logits) != len(target_sizes): - raise ValueError( - "Make sure that you pass in as many target sizes as the batch dimension of the logits" - ) - - if is_torch_tensor(target_sizes): - target_sizes = target_sizes.numpy() - - semantic_segmentation = [] - - for idx in range(len(logits)): - resized_logits = torch.nn.functional.interpolate( - logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False - ) - semantic_map = resized_logits[0].argmax(dim=0) - semantic_segmentation.append(semantic_map) - else: - semantic_segmentation = logits.argmax(dim=1) - semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] - - return semantic_segmentation def post_process_depth_estimation( self, From dfee43faad6b5d47b96b89e3e42dc1cfd2739614 Mon Sep 17 00:00:00 2001 From: haotongl Date: Tue, 7 Jan 2025 20:05:34 +0800 Subject: [PATCH 18/58] add copy disclaimer for prompt depth anything image processing --- .../image_processing_prompt_depth_anything.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py index 5fe42c2bb5a4..e5ffdad3db04 100644 --- a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py @@ -448,7 +448,7 @@ def preprocess( data["prompt_depth"] = prompt_depths return BatchFeature(data=data, tensor_type=return_tensors) - + # Copied from transformers.models.dpt.image_processing_dpt.DPTImageProcessor.post_process_depth_estimation with DPT->PromptDepthAnything def post_process_depth_estimation( self, outputs: "DepthEstimatorOutput", From db9f301b4338225c29ebf59f4add902997b4eba9 Mon Sep 17 00:00:00 2001 From: haotongl Date: Tue, 7 Jan 2025 20:30:25 +0800 Subject: [PATCH 19/58] fix some format typos in image processing and conversion scripts --- .../convert_prompt_depth_anything_to_hf.py | 209 ++++++++++-------- .../image_processing_prompt_depth_anything.py | 9 +- 2 files changed, 119 insertions(+), 99 deletions(-) diff --git a/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py b/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py index 4ff9612750e0..1d80d48602d9 100644 --- a/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py +++ b/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py @@ -77,107 +77,117 @@ def get_dpt_config(model_name): return config -def create_rename_keys(config): - rename_keys = [] - - # fmt: off - # stem - rename_keys.append(("pretrained.cls_token", "backbone.embeddings.cls_token")) - rename_keys.append(("pretrained.mask_token", "backbone.embeddings.mask_token")) - rename_keys.append(("pretrained.pos_embed", "backbone.embeddings.position_embeddings")) - rename_keys.append(("pretrained.patch_embed.proj.weight", "backbone.embeddings.patch_embeddings.projection.weight")) - rename_keys.append(("pretrained.patch_embed.proj.bias", "backbone.embeddings.patch_embeddings.projection.bias")) - - # Transfomer encoder - for i in range(config.backbone_config.num_hidden_layers): - rename_keys.append((f"pretrained.blocks.{i}.ls1.gamma", f"backbone.encoder.layer.{i}.layer_scale1.lambda1")) - rename_keys.append((f"pretrained.blocks.{i}.ls2.gamma", f"backbone.encoder.layer.{i}.layer_scale2.lambda1")) - rename_keys.append((f"pretrained.blocks.{i}.norm1.weight", f"backbone.encoder.layer.{i}.norm1.weight")) - rename_keys.append((f"pretrained.blocks.{i}.norm1.bias", f"backbone.encoder.layer.{i}.norm1.bias")) - rename_keys.append((f"pretrained.blocks.{i}.norm2.weight", f"backbone.encoder.layer.{i}.norm2.weight")) - rename_keys.append((f"pretrained.blocks.{i}.norm2.bias", f"backbone.encoder.layer.{i}.norm2.bias")) - rename_keys.append((f"pretrained.blocks.{i}.mlp.fc1.weight", f"backbone.encoder.layer.{i}.mlp.fc1.weight")) - rename_keys.append((f"pretrained.blocks.{i}.mlp.fc1.bias", f"backbone.encoder.layer.{i}.mlp.fc1.bias")) - rename_keys.append((f"pretrained.blocks.{i}.mlp.fc2.weight", f"backbone.encoder.layer.{i}.mlp.fc2.weight")) - rename_keys.append((f"pretrained.blocks.{i}.mlp.fc2.bias", f"backbone.encoder.layer.{i}.mlp.fc2.bias")) - rename_keys.append((f"pretrained.blocks.{i}.attn.proj.weight", f"backbone.encoder.layer.{i}.attention.output.dense.weight")) - rename_keys.append((f"pretrained.blocks.{i}.attn.proj.bias", f"backbone.encoder.layer.{i}.attention.output.dense.bias")) - +KEY_MAPPING = { + # Stem + "pretrained.cls_token": "backbone.embeddings.cls_token", + "pretrained.mask_token": "backbone.embeddings.mask_token", + "pretrained.pos_embed": "backbone.embeddings.position_embeddings", + "pretrained.patch_embed.proj.weight": "backbone.embeddings.patch_embeddings.projection.weight", + "pretrained.patch_embed.proj.bias": "backbone.embeddings.patch_embeddings.projection.bias", + # Head + "pretrained.norm.weight": "backbone.layernorm.weight", + "pretrained.norm.bias": "backbone.layernorm.bias", # Head - rename_keys.append(("pretrained.norm.weight", "backbone.layernorm.weight")) - rename_keys.append(("pretrained.norm.bias", "backbone.layernorm.bias")) + "depth_head.scratch.output_conv1.weight": "head.conv1.weight", + "depth_head.scratch.output_conv1.bias": "head.conv1.bias", + "depth_head.scratch.output_conv2.0.weight": "head.conv2.weight", + "depth_head.scratch.output_conv2.0.bias": "head.conv2.bias", + "depth_head.scratch.output_conv2.2.weight": "head.conv3.weight", + "depth_head.scratch.output_conv2.2.bias": "head.conv3.bias", +} + + +def add_transformer_mappings(config): + # Transformer encoder mappings + for i in range(config.backbone_config.num_hidden_layers): + KEY_MAPPING.update( + { + f"pretrained.blocks.{i}.ls1.gamma": f"backbone.encoder.layer.{i}.layer_scale1.lambda1", + f"pretrained.blocks.{i}.ls2.gamma": f"backbone.encoder.layer.{i}.layer_scale2.lambda1", + f"pretrained.blocks.{i}.norm1.weight": f"backbone.encoder.layer.{i}.norm1.weight", + f"pretrained.blocks.{i}.norm1.bias": f"backbone.encoder.layer.{i}.norm1.bias", + f"pretrained.blocks.{i}.norm2.weight": f"backbone.encoder.layer.{i}.norm2.weight", + f"pretrained.blocks.{i}.norm2.bias": f"backbone.encoder.layer.{i}.norm2.bias", + f"pretrained.blocks.{i}.mlp.fc1.weight": f"backbone.encoder.layer.{i}.mlp.fc1.weight", + f"pretrained.blocks.{i}.mlp.fc1.bias": f"backbone.encoder.layer.{i}.mlp.fc1.bias", + f"pretrained.blocks.{i}.mlp.fc2.weight": f"backbone.encoder.layer.{i}.mlp.fc2.weight", + f"pretrained.blocks.{i}.mlp.fc2.bias": f"backbone.encoder.layer.{i}.mlp.fc2.bias", + f"pretrained.blocks.{i}.attn.proj.weight": f"backbone.encoder.layer.{i}.attention.output.dense.weight", + f"pretrained.blocks.{i}.attn.proj.bias": f"backbone.encoder.layer.{i}.attention.output.dense.bias", + f"pretrained.blocks.{i}.attn.qkv.weight": f"qkv_transform_{i}", + f"pretrained.blocks.{i}.attn.qkv.bias": f"qkv_transform_bias_{i}", + } + ) - # activation postprocessing (readout projections + resize blocks) - # Depth Anything does not use CLS token => readout_projects not required +def add_neck_mappings(): + # Neck mappings for i in range(4): - rename_keys.append((f"depth_head.projects.{i}.weight", f"neck.reassemble_stage.layers.{i}.projection.weight")) - rename_keys.append((f"depth_head.projects.{i}.bias", f"neck.reassemble_stage.layers.{i}.projection.bias")) + KEY_MAPPING.update( + { + f"depth_head.projects.{i}.weight": f"neck.reassemble_stage.layers.{i}.projection.weight", + f"depth_head.projects.{i}.bias": f"neck.reassemble_stage.layers.{i}.projection.bias", + f"depth_head.scratch.layer{i+1}_rn.weight": f"neck.convs.{i}.weight", + } + ) if i != 2: - rename_keys.append((f"depth_head.resize_layers.{i}.weight", f"neck.reassemble_stage.layers.{i}.resize.weight")) - rename_keys.append((f"depth_head.resize_layers.{i}.bias", f"neck.reassemble_stage.layers.{i}.resize.bias")) - - # refinenet (tricky here) - mapping = {1:3, 2:2, 3:1, 4:0} + KEY_MAPPING.update( + { + f"depth_head.resize_layers.{i}.weight": f"neck.reassemble_stage.layers.{i}.resize.weight", + f"depth_head.resize_layers.{i}.bias": f"neck.reassemble_stage.layers.{i}.resize.bias", + } + ) + # Refinenet mappings + mapping = {1: 3, 2: 2, 3: 1, 4: 0} for i in range(1, 5): j = mapping[i] - rename_keys.append((f"depth_head.scratch.refinenet{i}.out_conv.weight", f"neck.fusion_stage.layers.{j}.projection.weight")) - rename_keys.append((f"depth_head.scratch.refinenet{i}.out_conv.bias", f"neck.fusion_stage.layers.{j}.projection.bias")) - rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit1.conv1.weight", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution1.weight")) - rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit1.conv1.bias", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution1.bias")) - rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit1.conv2.weight", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution2.weight")) - rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit1.conv2.bias", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution2.bias")) - rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit2.conv1.weight", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution1.weight")) - rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit2.conv1.bias", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution1.bias")) - rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit2.conv2.weight", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution2.weight")) - rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit2.conv2.bias", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution2.bias")) - rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit_depth.0.weight", f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution1.weight")) - rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit_depth.0.bias", f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution1.bias")) - rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit_depth.2.weight", f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution2.weight")) - rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit_depth.2.bias", f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution2.bias")) - rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit_depth.4.weight", f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution3.weight")) - rename_keys.append((f"depth_head.scratch.refinenet{i}.resConfUnit_depth.4.bias", f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution3.bias")) - - # scratch convolutions - for i in range(4): - rename_keys.append((f"depth_head.scratch.layer{i+1}_rn.weight", f"neck.convs.{i}.weight")) - - # head - rename_keys.append(("depth_head.scratch.output_conv1.weight", "head.conv1.weight")) - rename_keys.append(("depth_head.scratch.output_conv1.bias", "head.conv1.bias")) - rename_keys.append(("depth_head.scratch.output_conv2.0.weight", "head.conv2.weight")) - rename_keys.append(("depth_head.scratch.output_conv2.0.bias", "head.conv2.bias")) - rename_keys.append(("depth_head.scratch.output_conv2.2.weight", "head.conv3.weight")) - rename_keys.append(("depth_head.scratch.output_conv2.2.bias", "head.conv3.bias")) + KEY_MAPPING.update( + { + f"depth_head.scratch.refinenet{i}.out_conv.weight": f"neck.fusion_stage.layers.{j}.projection.weight", + f"depth_head.scratch.refinenet{i}.out_conv.bias": f"neck.fusion_stage.layers.{j}.projection.bias", + f"depth_head.scratch.refinenet{i}.resConfUnit1.conv1.weight": f"neck.fusion_stage.layers.{j}.residual_layer1.convolution1.weight", + f"depth_head.scratch.refinenet{i}.resConfUnit1.conv1.bias": f"neck.fusion_stage.layers.{j}.residual_layer1.convolution1.bias", + f"depth_head.scratch.refinenet{i}.resConfUnit1.conv2.weight": f"neck.fusion_stage.layers.{j}.residual_layer1.convolution2.weight", + f"depth_head.scratch.refinenet{i}.resConfUnit1.conv2.bias": f"neck.fusion_stage.layers.{j}.residual_layer1.convolution2.bias", + f"depth_head.scratch.refinenet{i}.resConfUnit2.conv1.weight": f"neck.fusion_stage.layers.{j}.residual_layer2.convolution1.weight", + f"depth_head.scratch.refinenet{i}.resConfUnit2.conv1.bias": f"neck.fusion_stage.layers.{j}.residual_layer2.convolution1.bias", + f"depth_head.scratch.refinenet{i}.resConfUnit2.conv2.weight": f"neck.fusion_stage.layers.{j}.residual_layer2.convolution2.weight", + f"depth_head.scratch.refinenet{i}.resConfUnit2.conv2.bias": f"neck.fusion_stage.layers.{j}.residual_layer2.convolution2.bias", + f"depth_head.scratch.refinenet{i}.resConfUnit_depth.0.weight": f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution1.weight", + f"depth_head.scratch.refinenet{i}.resConfUnit_depth.0.bias": f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution1.bias", + f"depth_head.scratch.refinenet{i}.resConfUnit_depth.2.weight": f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution2.weight", + f"depth_head.scratch.refinenet{i}.resConfUnit_depth.2.bias": f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution2.bias", + f"depth_head.scratch.refinenet{i}.resConfUnit_depth.4.weight": f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution3.weight", + f"depth_head.scratch.refinenet{i}.resConfUnit_depth.4.bias": f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution3.bias", + } + ) - return rename_keys +def transform_qkv_weights(key, value, config): + if not key.startswith("qkv_transform"): + return value -# we split up the matrix of each encoder layer into queries, keys and values -def read_in_q_k_v(state_dict, config): + layer_idx = int(key.split("_")[-1]) hidden_size = config.backbone_config.hidden_size - for i in range(config.backbone_config.num_hidden_layers): - # read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias) - in_proj_weight = state_dict.pop(f"pretrained.blocks.{i}.attn.qkv.weight") - in_proj_bias = state_dict.pop(f"pretrained.blocks.{i}.attn.qkv.bias") - # next, add query, keys and values (in that order) to the state dict - state_dict[f"backbone.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[:hidden_size, :] - state_dict[f"backbone.encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[:hidden_size] - state_dict[f"backbone.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ - hidden_size : hidden_size * 2, : - ] - state_dict[f"backbone.encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[ - hidden_size : hidden_size * 2 - ] - state_dict[f"backbone.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[-hidden_size:, :] - state_dict[f"backbone.encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-hidden_size:] - - -def rename_key(dct, old, new): - val = dct.pop(old) - dct[new] = val + + if "bias" in key: + # Handle bias + return { + f"backbone.encoder.layer.{layer_idx}.attention.attention.query.bias": value[:hidden_size], + f"backbone.encoder.layer.{layer_idx}.attention.attention.key.bias": value[hidden_size : hidden_size * 2], + f"backbone.encoder.layer.{layer_idx}.attention.attention.value.bias": value[-hidden_size:], + } + else: + # Handle weights + return { + f"backbone.encoder.layer.{layer_idx}.attention.attention.query.weight": value[:hidden_size, :], + f"backbone.encoder.layer.{layer_idx}.attention.attention.key.weight": value[ + hidden_size : hidden_size * 2, : + ], + f"backbone.encoder.layer.{layer_idx}.attention.attention.value.weight": value[-hidden_size:, :], + } name_to_checkpoint = { @@ -196,6 +206,10 @@ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, ve # define DPT configuration config = get_dpt_config(model_name) + # Add dynamic key mappings + add_transformer_mappings(config) + add_neck_mappings() + model_name_to_repo = { "promptda_vits": "depth-anything/promptda_vits", "promptda_vits_transparent": "depth-anything/promptda_vits_transparent", @@ -212,16 +226,21 @@ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, ve state_dict = torch.load(filepath, map_location="cpu")["state_dict"] state_dict = {key[9:]: state_dict[key] for key in state_dict} - # rename keys - rename_keys = create_rename_keys(config) - for src, dest in rename_keys: - rename_key(state_dict, src, dest) - # read in qkv matrices - read_in_q_k_v(state_dict, config) + + # Convert state dict using mappings + new_state_dict = {} + for key, value in state_dict.items(): + if key in KEY_MAPPING: + new_key = KEY_MAPPING[key] + transformed_value = transform_qkv_weights(new_key, value, config) + if isinstance(transformed_value, dict): + new_state_dict.update(transformed_value) + else: + new_state_dict[new_key] = transformed_value # load HuggingFace model model = PromptDepthAnythingForDepthEstimation(config) - model.load_state_dict(state_dict, strict=False) + model.load_state_dict(new_state_dict, strict=False) model.eval() processor = PromptDepthAnythingImageProcessor( diff --git a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py index e5ffdad3db04..bb76acddeaa2 100644 --- a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py @@ -33,7 +33,6 @@ infer_channel_dimension_format, is_scaled_image, is_torch_available, - is_torch_tensor, make_list_of_images, to_numpy_array, valid_images, @@ -275,7 +274,9 @@ def _get_pad(size, size_divisor): pad_size_left, pad_size_right = _get_pad(height, size_divisor) pad_size_top, pad_size_bottom = _get_pad(width, size_divisor) - padded_image = pad(image, ((pad_size_left, pad_size_right), (pad_size_top, pad_size_bottom)), data_format=data_format) + padded_image = pad( + image, ((pad_size_left, pad_size_right), (pad_size_top, pad_size_bottom)), data_format=data_format + ) return padded_image @filter_out_non_signature_kwargs() @@ -307,8 +308,8 @@ def preprocess( Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`. prompt_depth (`ImageInput`, *optional*): - Prompt depth to preprocess, which can be sparse depth obtained from multi-view geometry or - low-resolution depth from a depth sensor. Generally has shape (height, width), where height + Prompt depth to preprocess, which can be sparse depth obtained from multi-view geometry or + low-resolution depth from a depth sensor. Generally has shape (height, width), where height and width can be smaller than the images. do_resize (`bool`, *optional*, defaults to `self.do_resize`): Whether to resize the image. From 8d0a435d67c9ac35f5b6e897aa6729c0bc3ab9a4 Mon Sep 17 00:00:00 2001 From: haotongl Date: Tue, 7 Jan 2025 20:31:35 +0800 Subject: [PATCH 20/58] fix nn.ReLU(False) to nn.ReLU() --- .../prompt_depth_anything/modular_prompt_depth_anything.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index 54224d5ae568..fa12c767c403 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -69,7 +69,7 @@ def __init__(self, config): padding=1, bias=True, ) - self.activation1 = nn.ReLU(False) + self.activation1 = nn.ReLU() self.convolution2 = nn.Conv2d( config.fusion_hidden_size, @@ -79,7 +79,7 @@ def __init__(self, config): padding=1, bias=True, ) - self.activation2 = nn.ReLU(False) + self.activation2 = nn.ReLU() self.convolution3 = nn.Conv2d( config.fusion_hidden_size, From 89956c4d05e6cda258c9935e307314b5d1708ba1 Mon Sep 17 00:00:00 2001 From: haotongl Date: Tue, 7 Jan 2025 20:36:11 +0800 Subject: [PATCH 21/58] rename residual layer as it's a sequential layer --- .../convert_prompt_depth_anything_to_hf.py | 12 ++++++------ .../modular_prompt_depth_anything.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py b/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py index 1d80d48602d9..63abcb7ab1fb 100644 --- a/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py +++ b/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py @@ -155,12 +155,12 @@ def add_neck_mappings(): f"depth_head.scratch.refinenet{i}.resConfUnit2.conv1.bias": f"neck.fusion_stage.layers.{j}.residual_layer2.convolution1.bias", f"depth_head.scratch.refinenet{i}.resConfUnit2.conv2.weight": f"neck.fusion_stage.layers.{j}.residual_layer2.convolution2.weight", f"depth_head.scratch.refinenet{i}.resConfUnit2.conv2.bias": f"neck.fusion_stage.layers.{j}.residual_layer2.convolution2.bias", - f"depth_head.scratch.refinenet{i}.resConfUnit_depth.0.weight": f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution1.weight", - f"depth_head.scratch.refinenet{i}.resConfUnit_depth.0.bias": f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution1.bias", - f"depth_head.scratch.refinenet{i}.resConfUnit_depth.2.weight": f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution2.weight", - f"depth_head.scratch.refinenet{i}.resConfUnit_depth.2.bias": f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution2.bias", - f"depth_head.scratch.refinenet{i}.resConfUnit_depth.4.weight": f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution3.weight", - f"depth_head.scratch.refinenet{i}.resConfUnit_depth.4.bias": f"neck.fusion_stage.layers.{j}.residual_layer_depth.convolution3.bias", + f"depth_head.scratch.refinenet{i}.resConfUnit_depth.0.weight": f"neck.fusion_stage.layers.{j}.prompt_depth_layer.convolution1.weight", + f"depth_head.scratch.refinenet{i}.resConfUnit_depth.0.bias": f"neck.fusion_stage.layers.{j}.prompt_depth_layer.convolution1.bias", + f"depth_head.scratch.refinenet{i}.resConfUnit_depth.2.weight": f"neck.fusion_stage.layers.{j}.prompt_depth_layer.convolution2.weight", + f"depth_head.scratch.refinenet{i}.resConfUnit_depth.2.bias": f"neck.fusion_stage.layers.{j}.prompt_depth_layer.convolution2.bias", + f"depth_head.scratch.refinenet{i}.resConfUnit_depth.4.weight": f"neck.fusion_stage.layers.{j}.prompt_depth_layer.convolution3.weight", + f"depth_head.scratch.refinenet{i}.resConfUnit_depth.4.bias": f"neck.fusion_stage.layers.{j}.prompt_depth_layer.convolution3.bias", } ) diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index fa12c767c403..3e0351113db6 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -58,7 +58,7 @@ class PromptDepthAnythingConfig(DepthAnythingConfig): model_type = "prompt_depth_anything" -class PromptDepthAnythingResidualLayer(nn.Module): +class PromptDepthAnythingLayer(nn.Module): def __init__(self, config): super().__init__() self.convolution1 = nn.Conv2d( @@ -103,7 +103,7 @@ def forward(self, prompt_depth: torch.Tensor) -> torch.Tensor: class PromptDepthAnythingFeatureFusionLayer(DepthAnythingFeatureFusionLayer): def __init__(self, config): super().__init__(config) - self.residual_layer_depth = PromptDepthAnythingResidualLayer(config) + self.prompt_depth_layer = PromptDepthAnythingLayer(config) def forward(self, hidden_state, residual=None, size=None, prompt_depth=None): if residual is not None: @@ -119,7 +119,7 @@ def forward(self, hidden_state, residual=None, size=None, prompt_depth=None): prompt_depth = nn.functional.interpolate( prompt_depth, hidden_state.shape[2:], mode="bilinear", align_corners=False ) - res = self.residual_layer_depth(prompt_depth) + res = self.prompt_depth_layer(prompt_depth) hidden_state = hidden_state + res modifier = {"scale_factor": 2} if size is None else {"size": size} From c713a5edb224d1f3688abdb2a96132ae2363f135 Mon Sep 17 00:00:00 2001 From: haotongl Date: Tue, 7 Jan 2025 20:42:52 +0800 Subject: [PATCH 22/58] move size compute to a separate line/variable for easier debug in modular prompt depth anything --- .../prompt_depth_anything/modular_prompt_depth_anything.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index 3e0351113db6..2ebd65ab4b92 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -171,9 +171,11 @@ def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) hidden_states = hidden_states[self.head_in_index] predicted_depth = self.conv1(hidden_states) + target_height = torch_int(patch_height * self.patch_size) + target_width = torch_int(patch_width * self.patch_size) predicted_depth = nn.functional.interpolate( predicted_depth, - (torch_int(patch_height * self.patch_size), torch_int(patch_width * self.patch_size)), + (target_height, target_width), mode="bilinear", align_corners=True, ) From 777c367360d10a07b2de21b789b1156e3327b6b7 Mon Sep 17 00:00:00 2001 From: haotongl Date: Tue, 7 Jan 2025 21:07:50 +0800 Subject: [PATCH 23/58] fix modular format for prompt depth anything --- .../modeling_prompt_depth_anything.py | 24 ++++---- .../modular_prompt_depth_anything.py | 58 +------------------ 2 files changed, 15 insertions(+), 67 deletions(-) diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index 8432c8a4aa23..0ea5dbc064ac 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -22,7 +22,7 @@ _CONFIG_FOR_DOC = "PromptDepthAnythingConfig" -class PromptDepthAnythingResidualLayer(nn.Module): +class PromptDepthAnythingLayer(nn.Module): def __init__(self, config): super().__init__() self.convolution1 = nn.Conv2d( @@ -33,7 +33,7 @@ def __init__(self, config): padding=1, bias=True, ) - self.activation1 = nn.ReLU(False) + self.activation1 = nn.ReLU() self.convolution2 = nn.Conv2d( config.fusion_hidden_size, @@ -43,7 +43,7 @@ def __init__(self, config): padding=1, bias=True, ) - self.activation2 = nn.ReLU(False) + self.activation2 = nn.ReLU() self.convolution3 = nn.Conv2d( config.fusion_hidden_size, @@ -121,7 +121,7 @@ def __init__(self, config): self.residual_layer1 = PromptDepthAnythingPreActResidualLayer(config) self.residual_layer2 = PromptDepthAnythingPreActResidualLayer(config) - self.residual_layer_depth = PromptDepthAnythingResidualLayer(config) + self.prompt_depth_layer = PromptDepthAnythingLayer(config) def forward(self, hidden_state, residual=None, size=None, prompt_depth=None): if residual is not None: @@ -137,7 +137,7 @@ def forward(self, hidden_state, residual=None, size=None, prompt_depth=None): prompt_depth = nn.functional.interpolate( prompt_depth, hidden_state.shape[2:], mode="bilinear", align_corners=False ) - res = self.residual_layer_depth(prompt_depth) + res = self.prompt_depth_layer(prompt_depth) hidden_state = hidden_state + res modifier = {"scale_factor": 2} if size is None else {"size": size} @@ -212,9 +212,11 @@ def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) hidden_states = hidden_states[self.head_in_index] predicted_depth = self.conv1(hidden_states) + target_height = torch_int(patch_height * self.patch_size) + target_width = torch_int(patch_width * self.patch_size) predicted_depth = nn.functional.interpolate( predicted_depth, - (torch_int(patch_height * self.patch_size), torch_int(patch_width * self.patch_size)), + (target_height, target_width), mode="bilinear", align_corners=True, ) @@ -227,7 +229,6 @@ def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) return predicted_depth -# Copied from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->PromptDepthAnything,dpt->prompt_depth_anything class PromptDepthAnythingPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -340,16 +341,17 @@ class PromptDepthAnythingNeck(nn.Module): def __init__(self, config): super().__init__() self.config = config + self.reassemble_stage = PromptDepthAnythingReassembleStage(config) self.convs = nn.ModuleList() for channel in config.neck_hidden_sizes: self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False)) + + # fusion self.fusion_stage = PromptDepthAnythingFeatureFusionStage(config) - def forward( - self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None, prompt_depth=None - ) -> List[torch.Tensor]: + def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]: """ Args: hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`): @@ -367,7 +369,7 @@ def forward( features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)] # fusion blocks - output = self.fusion_stage(features, prompt_depth=prompt_depth) + output = self.fusion_stage(features) return output diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index 2ebd65ab4b92..fe6387edf2af 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -188,7 +188,6 @@ def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) return predicted_depth -# Copied from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->PromptDepthAnything,dpt->prompt_depth_anything class PromptDepthAnythingPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -229,58 +228,11 @@ def __init__(self, config, channels, factor): class PromptDepthAnythingReassembleStage(DepthAnythingReassembleStage): - """ - This class reassembles the hidden states of the backbone into image-like feature representations at various - resolutions. - - This happens in 3 stages: - 1. Take the patch embeddings and reshape them to image-like feature representations. - 2. Project the channel dimension of the hidden states according to `config.neck_hidden_sizes`. - 3. Resizing the spatial dimensions (height, width). - - Args: - config (`[PromptDepthAnythingConfig]`): - Model configuration class defining the model architecture. - """ - - def __init__(self, config): - super().__init__() - - self.config = config - self.layers = nn.ModuleList() - for channels, factor in zip(config.neck_hidden_sizes, config.reassemble_factors): - self.layers.append(PromptDepthAnythingReassembleLayer(config, channels=channels, factor=factor)) + pass class PromptDepthAnythingNeck(DepthAnythingNeck): - def __init__(self, config): - super().__init__(config) - self.reassemble_stage = PromptDepthAnythingReassembleStage(config) - self.fusion_stage = PromptDepthAnythingFeatureFusionStage(config) - - def forward( - self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None, prompt_depth=None - ) -> List[torch.Tensor]: - """ - Args: - hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`): - List of hidden states from the backbone. - """ - if not isinstance(hidden_states, (tuple, list)): - raise TypeError("hidden_states should be a tuple or list of tensors") - - if len(hidden_states) != len(self.config.neck_hidden_sizes): - raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.") - - # postprocess hidden states - hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width) - - features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)] - - # fusion blocks - output = self.fusion_stage(features, prompt_depth=prompt_depth) - - return output + pass @add_start_docstrings( @@ -290,12 +242,6 @@ def forward( PROMPT_DEPTH_ANYTHING_START_DOCSTRING, ) class PromptDepthAnythingForDepthEstimation(DepthAnythingForDepthEstimation): - def __init__(self, config): - super().__init__(config) - self.neck = PromptDepthAnythingNeck(config) - self.head = PromptDepthAnythingDepthEstimationHead(config) - self.post_init() - @add_start_docstrings_to_model_forward(PROMPT_DEPTH_ANYTHING_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=DepthEstimatorOutput, config_class=_CONFIG_FOR_DOC) def forward( From cc8f4acd0f97e8f949755ff81e9adb7b5f7ba17b Mon Sep 17 00:00:00 2001 From: haotongl Date: Wed, 8 Jan 2025 00:13:05 +0800 Subject: [PATCH 24/58] update modular prompt depth anything --- .../en/model_doc/prompt_depth_anything.md | 4 +- .../convert_prompt_depth_anything_to_hf.py | 20 +++++----- .../modeling_prompt_depth_anything.py | 10 +++-- .../modular_prompt_depth_anything.py | 40 +++++++++++++------ .../test_modeling_prompt_depth_anything.py | 16 ++++---- 5 files changed, 54 insertions(+), 36 deletions(-) diff --git a/docs/source/en/model_doc/prompt_depth_anything.md b/docs/source/en/model_doc/prompt_depth_anything.md index 6b2acfe7f42c..50d6f44399a6 100644 --- a/docs/source/en/model_doc/prompt_depth_anything.md +++ b/docs/source/en/model_doc/prompt_depth_anything.md @@ -44,8 +44,8 @@ The Transformers library allows you to use the model with just a few lines of co >>> url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true" >>> image = Image.open(requests.get(url, stream=True).raw) ->>> image_processor = AutoImageProcessor.from_pretrained("depth-anything/promptda_vits_hf") ->>> model = AutoModelForDepthEstimation.from_pretrained("depth-anything/promptda_vits_hf") +>>> image_processor = AutoImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf") +>>> model = AutoModelForDepthEstimation.from_pretrained("depth-anything/prompt-depth-anything-vits-hf") >>> prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true" >>> prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw) diff --git a/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py b/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py index 63abcb7ab1fb..54854fe0a7a2 100644 --- a/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py +++ b/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py @@ -191,9 +191,9 @@ def transform_qkv_weights(key, value, config): name_to_checkpoint = { - "promptda_vits": "model.ckpt", - "promptda_vits_transparent": "model.ckpt", - "promptda_vitl": "model.ckpt", + "prompt-depth-anything-vits": "model.ckpt", + "prompt-depth-anything-vits-transparent": "model.ckpt", + "prompt-depth-anything-vitl": "model.ckpt", } @@ -211,9 +211,9 @@ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, ve add_neck_mappings() model_name_to_repo = { - "promptda_vits": "depth-anything/promptda_vits", - "promptda_vits_transparent": "depth-anything/promptda_vits_transparent", - "promptda_vitl": "depth-anything/promptda_vitl", + "prompt-depth-anything-vits": "depth-anything/prompt-depth-anything-vits", + "prompt-depth-anything-vits-transparent": "depth-anything/prompt-depth-anything-vits-transparent", + "prompt-depth-anything-vitl": "depth-anything/prompt-depth-anything-vitl", } # load original state_dict @@ -274,15 +274,15 @@ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, ve # assert logits if verify_logits: expected_shape = torch.Size([1, 756, 1008]) - if model_name == "promptda_vits": + if model_name == "prompt-depth-anything-vits": expected_slice = torch.tensor( [[3.0100, 3.0016, 3.0219], [3.0046, 3.0137, 3.0275], [3.0083, 3.0191, 3.0292]] ) - elif model_name == "promptda_vits_transparent": + elif model_name == "prompt-depth-anything-vits-transparent": expected_slice = torch.tensor( [[3.0058, 3.0397, 3.0460], [3.0314, 3.0393, 3.0504], [3.0326, 3.0465, 3.0545]] ) - elif model_name == "promptda_vitl": + elif model_name == "prompt-depth-anything-vitl": expected_slice = torch.tensor( [[3.1336, 3.1358, 3.1363], [3.1368, 3.1267, 3.1414], [3.1397, 3.1385, 3.1448]] ) @@ -309,7 +309,7 @@ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, ve # Required parameters parser.add_argument( "--model_name", - default="promptda_vits", + default="prompt_depth_anything_vits", type=str, choices=name_to_checkpoint.keys(), help="Name of the model you'd like to convert.", diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index 0ea5dbc064ac..b976215a8d1a 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -351,7 +351,9 @@ def __init__(self, config): # fusion self.fusion_stage = PromptDepthAnythingFeatureFusionStage(config) - def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]: + def forward( + self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None, prompt_depth=None + ) -> List[torch.Tensor]: """ Args: hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`): @@ -369,7 +371,7 @@ def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_wi features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)] # fusion blocks - output = self.fusion_stage(features) + output = self.fusion_stage(features, prompt_depth=prompt_depth) return output @@ -455,8 +457,8 @@ def forward( >>> url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true" >>> image = Image.open(requests.get(url, stream=True).raw) - >>> image_processor = AutoImageProcessor.from_pretrained("depth-anything/promptda_vits_hf") - >>> model = AutoModelForDepthEstimation.from_pretrained("depth-anything/promptda_vits_hf") + >>> image_processor = AutoImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf") + >>> model = AutoModelForDepthEstimation.from_pretrained("depth-anything/prompt-depth-anything-vits-hf") >>> prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true" >>> prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw) diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index fe6387edf2af..1cd587770627 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -7,6 +7,7 @@ from transformers.models.depth_anything.modeling_depth_anything import ( DepthAnythingDepthEstimationHead, DepthAnythingFeatureFusionLayer, + DepthAnythingFeatureFusionStage, DepthAnythingForDepthEstimation, DepthAnythingNeck, DepthAnythingReassembleLayer, @@ -135,13 +136,7 @@ def forward(self, hidden_state, residual=None, size=None, prompt_depth=None): return hidden_state -class PromptDepthAnythingFeatureFusionStage(nn.Module): - def __init__(self, config): - super().__init__() - self.layers = nn.ModuleList() - for _ in range(len(config.neck_hidden_sizes)): - self.layers.append(PromptDepthAnythingFeatureFusionLayer(config)) - +class PromptDepthAnythingFeatureFusionStage(DepthAnythingFeatureFusionStage): def forward(self, hidden_states, size=None, prompt_depth=None): # reversing the hidden_states, we start from the last hidden_states = hidden_states[::-1] @@ -164,9 +159,6 @@ def forward(self, hidden_states, size=None, prompt_depth=None): class PromptDepthAnythingDepthEstimationHead(DepthAnythingDepthEstimationHead): - def __init__(self, config): - super().__init__(config) - def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) -> torch.Tensor: hidden_states = hidden_states[self.head_in_index] @@ -232,7 +224,29 @@ class PromptDepthAnythingReassembleStage(DepthAnythingReassembleStage): class PromptDepthAnythingNeck(DepthAnythingNeck): - pass + def forward( + self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None, prompt_depth=None + ) -> List[torch.Tensor]: + """ + Args: + hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`): + List of hidden states from the backbone. + """ + if not isinstance(hidden_states, (tuple, list)): + raise TypeError("hidden_states should be a tuple or list of tensors") + + if len(hidden_states) != len(self.config.neck_hidden_sizes): + raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.") + + # postprocess hidden states + hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width) + + features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)] + + # fusion blocks + output = self.fusion_stage(features, prompt_depth=prompt_depth) + + return output @add_start_docstrings( @@ -270,8 +284,8 @@ def forward( >>> url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true" >>> image = Image.open(requests.get(url, stream=True).raw) - >>> image_processor = AutoImageProcessor.from_pretrained("depth-anything/promptda_vits_hf") - >>> model = AutoModelForDepthEstimation.from_pretrained("depth-anything/promptda_vits_hf") + >>> image_processor = AutoImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf") + >>> model = AutoModelForDepthEstimation.from_pretrained("depth-anything/prompt-depth-anything-vits-hf") >>> prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true" >>> prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw) diff --git a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py index 3bea42d04090..5e12644e98f0 100644 --- a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py +++ b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py @@ -207,7 +207,7 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): @slow def test_model_from_pretrained(self): - model_name = "depth-anything/promptda_vits_hf" + model_name = "depth-anything/prompt-depth-anything-vits-hf" model = PromptDepthAnythingForDepthEstimation.from_pretrained(model_name) self.assertIsNotNone(model) @@ -241,10 +241,10 @@ def prepare_img(): @slow class PromptDepthAnythingModelIntegrationTest(unittest.TestCase): def test_inference(self): - image_processor = DPTImageProcessor.from_pretrained("depth-anything/promptda_vits_hf") - model = PromptDepthAnythingForDepthEstimation.from_pretrained("depth-anything/promptda_vits_hf").to( - torch_device - ) + image_processor = DPTImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf") + model = PromptDepthAnythingForDepthEstimation.from_pretrained( + "depth-anything/prompt-depth-anything-vits-hf" + ).to(torch_device) image = prepare_img() prompt_depth_url = ( @@ -274,11 +274,13 @@ def test_export(self): if not is_torch_greater_or_equal_than_2_4: self.skipTest(reason="This test requires torch >= 2.4 to run.") model = ( - PromptDepthAnythingForDepthEstimation.from_pretrained("depth-anything/promptda_vits_hf") + PromptDepthAnythingForDepthEstimation.from_pretrained( + "depth-anything/prompt-depth-anything-vits-hf" + ) .to(torch_device) .eval() ) - image_processor = DPTImageProcessor.from_pretrained("depth-anything/promptda_vits_hf") + image_processor = DPTImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf") image = prepare_img() inputs = image_processor(images=image, return_tensors="pt").to(torch_device) From 0848054a43ba708727733b3d4286fb10ae85f712 Mon Sep 17 00:00:00 2001 From: linhaotong Date: Thu, 16 Jan 2025 11:54:47 +0800 Subject: [PATCH 25/58] fix scale to meter and some internal funcs warp --- .../image_processing_prompt_depth_anything.py | 142 ++++++++++-------- 1 file changed, 79 insertions(+), 63 deletions(-) diff --git a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py index bb76acddeaa2..98bba22e35cf 100644 --- a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py @@ -57,24 +57,25 @@ logger = logging.get_logger(__name__) -def get_resize_output_image_size( +def _constrain_to_multiple_of(val, multiple, min_val=0, max_val=None): + x = round(val / multiple) * multiple + + if max_val is not None and x > max_val: + x = math.floor(val / multiple) * multiple + + if x < min_val: + x = math.ceil(val / multiple) * multiple + + return x + + +def _get_resize_output_image_size( input_image: np.ndarray, output_size: Union[int, Iterable[int]], keep_aspect_ratio: bool, multiple: int, input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> Tuple[int, int]: - def constrain_to_multiple_of(val, multiple, min_val=0, max_val=None): - x = round(val / multiple) * multiple - - if max_val is not None and x > max_val: - x = math.floor(val / multiple) * multiple - - if x < min_val: - x = math.ceil(val / multiple) * multiple - - return x - output_size = (output_size, output_size) if isinstance(output_size, int) else output_size input_height, input_width = get_image_size(input_image, input_data_format) @@ -93,8 +94,8 @@ def constrain_to_multiple_of(val, multiple, min_val=0, max_val=None): # fit height scale_width = scale_height - new_height = constrain_to_multiple_of(scale_height * input_height, multiple=multiple) - new_width = constrain_to_multiple_of(scale_width * input_width, multiple=multiple) + new_height = _constrain_to_multiple_of(scale_height * input_height, multiple=multiple) + new_width = _constrain_to_multiple_of(scale_width * input_width, multiple=multiple) return (new_height, new_width) @@ -158,7 +159,7 @@ def __init__( size_divisor: int = None, prompt_scale_to_meter: float = 0.001, # default unit is mm **kwargs, - ) -> None: + ): super().__init__(**kwargs) size = size if size is not None else {"height": 384, "width": 384} size = get_size_dict(size) @@ -215,7 +216,7 @@ def resize( if "height" not in size or "width" not in size: raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}") - output_size = get_resize_output_image_size( + output_size = _get_resize_output_image_size( image, output_size=(size["height"], size["width"]), keep_aspect_ratio=keep_aspect_ratio, @@ -233,7 +234,7 @@ def resize( def pad_image( self, - image: np.array, + image: np.ndarray, size_divisor: int, data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, @@ -283,23 +284,24 @@ def _get_pad(size, size_divisor): def preprocess( self, images: ImageInput, - prompt_depth: ImageInput = None, - do_resize: bool = None, - size: int = None, - keep_aspect_ratio: bool = None, - ensure_multiple_of: int = None, - resample: PILImageResampling = None, - do_rescale: bool = None, - rescale_factor: float = None, - do_normalize: bool = None, + prompt_depth: Optional[ImageInput] = None, + do_resize: Optional[bool] = None, + size: Optional[int] = None, + keep_aspect_ratio: Optional[bool] = None, + ensure_multiple_of: Optional[int] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, - do_pad: bool = None, - size_divisor: int = None, + do_pad: Optional[bool] = None, + size_divisor: Optional[int] = None, + prompt_scale_to_meter: Optional[float] = None, return_tensors: Optional[Union[str, TensorType]] = None, data_format: ChannelDimension = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, - ) -> PIL.Image.Image: + ) -> BatchFeature: """ Preprocess an image or batch of images. @@ -335,6 +337,8 @@ def preprocess( Image mean. image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): Image standard deviation. + prompt_scale_to_meter (`float`, *optional*, defaults to `self.prompt_scale_to_meter`): + Scale factor to convert the prompt depth to meters. return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. @@ -399,9 +403,10 @@ def preprocess( # We assume that all images have the same channel dimension format. input_data_format = infer_channel_dimension_format(images[0]) - if do_resize: - images = [ - self.resize( + preprocessed_images = [] + for image in images: + if do_resize: + image = self.resize( image=image, size=size, resample=resample, @@ -409,43 +414,54 @@ def preprocess( ensure_multiple_of=ensure_multiple_of, input_data_format=input_data_format, ) - for image in images - ] - - if do_rescale: - images = [ - self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) - for image in images - ] - - if do_normalize: - images = [ - self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) - for image in images - ] - - if do_pad: - images = [ - self.pad_image(image=image, size_divisor=size_divisor, input_data_format=input_data_format) - for image in images - ] - - images = [ - to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images - ] + + if do_rescale: + image = self.rescale( + image=image, + scale=rescale_factor, + input_data_format=input_data_format + ) + + if do_normalize: + image = self.normalize( + image=image, + mean=image_mean, + std=image_std, + input_data_format=input_data_format + ) + + if do_pad: + image = self.pad_image( + image=image, + size_divisor=size_divisor, + input_data_format=input_data_format + ) + + image = to_channel_dimension_format( + image, + data_format, + input_channel_dim=input_data_format + ) + preprocessed_images.append(image) + + images = preprocessed_images data = {"pixel_values": images} if prompt_depth is not None: # prompt_depth is a list of images with shape (height, width) # we need to convert it to a list of images with shape (1, height, width) prompt_depths = make_list_of_images(prompt_depth) - prompt_depths = [to_numpy_array(depth) for depth in prompt_depths] - prompt_depths = [depth * self.prompt_scale_to_meter for depth in prompt_depths] - prompt_depths = [prompt_depth[..., None].astype(np.float32) for prompt_depth in prompt_depths] - prompt_depths = [ - to_channel_dimension_format(depth, data_format, input_channel_dim=input_data_format) - for depth in prompt_depths - ] + if prompt_scale_to_meter is None: + prompt_scale_to_meter = self.prompt_scale_to_meter + processed_prompt_depths = [] + for depth in prompt_depths: + depth = to_numpy_array(depth) + depth = depth * prompt_scale_to_meter + depth = depth[..., None].astype(np.float32) + depth = to_channel_dimension_format(depth, data_format, input_channel_dim=input_data_format) + + processed_prompt_depths.append(depth) + prompt_depths = processed_prompt_depths data["prompt_depth"] = prompt_depths return BatchFeature(data=data, tensor_type=return_tensors) From 25e1144a905afb2edd4e26ca797e99258e94ebad Mon Sep 17 00:00:00 2001 From: linhaotong Date: Thu, 16 Jan 2025 14:03:28 +0800 Subject: [PATCH 26/58] fix code style in image_processing_prompt_depth_anything.py --- .../image_processing_prompt_depth_anything.py | 31 +++++-------------- 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py index 98bba22e35cf..e9888a3676db 100644 --- a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py @@ -51,7 +51,7 @@ import torch if is_vision_available(): - import PIL + pass logger = logging.get_logger(__name__) @@ -414,36 +414,21 @@ def preprocess( ensure_multiple_of=ensure_multiple_of, input_data_format=input_data_format, ) - + if do_rescale: - image = self.rescale( - image=image, - scale=rescale_factor, - input_data_format=input_data_format - ) + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) if do_normalize: image = self.normalize( - image=image, - mean=image_mean, - std=image_std, - input_data_format=input_data_format + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format ) if do_pad: - image = self.pad_image( - image=image, - size_divisor=size_divisor, - input_data_format=input_data_format - ) + image = self.pad_image(image=image, size_divisor=size_divisor, input_data_format=input_data_format) - image = to_channel_dimension_format( - image, - data_format, - input_channel_dim=input_data_format - ) + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) preprocessed_images.append(image) - + images = preprocessed_images data = {"pixel_values": images} @@ -459,7 +444,7 @@ def preprocess( depth = depth * prompt_scale_to_meter depth = depth[..., None].astype(np.float32) depth = to_channel_dimension_format(depth, data_format, input_channel_dim=input_data_format) - + processed_prompt_depths.append(depth) prompt_depths = processed_prompt_depths data["prompt_depth"] = prompt_depths From 3c8f6c04807453905a45fc746dcbb9e1731ce453 Mon Sep 17 00:00:00 2001 From: linhaotong Date: Thu, 16 Jan 2025 14:36:30 +0800 Subject: [PATCH 27/58] fix issues in image_processing_prompt_depth_anything.py --- .../image_processing_prompt_depth_anything.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py index e9888a3676db..ead82eb573a8 100644 --- a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py @@ -202,9 +202,6 @@ def resize( If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. ensure_multiple_of (`int`, *optional*, defaults to 1): The image is resized to a size that is a multiple of this value. - resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): - Defines the resampling filter to use if resizing the image. Otherwise, the image is resized to size - specified in `size`. resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): Resampling filter to use when resiizing the image. data_format (`str` or `ChannelDimension`, *optional*): @@ -312,7 +309,10 @@ def preprocess( prompt_depth (`ImageInput`, *optional*): Prompt depth to preprocess, which can be sparse depth obtained from multi-view geometry or low-resolution depth from a depth sensor. Generally has shape (height, width), where height - and width can be smaller than the images. + and width can be smaller than those of the images. It's optional and can be None, which means no prompt depth + is used. If it is None, the output depth will be a monocular relative depth. + It is recommended to provide a prompt_scale_to_meter value, which is the scale factor to convert the prompt depth + to meters. This is useful when the prompt depth is not in meters. do_resize (`bool`, *optional*, defaults to `self.do_resize`): Whether to resize the image. size (`Dict[str, int]`, *optional*, defaults to `self.size`): @@ -436,8 +436,16 @@ def preprocess( # prompt_depth is a list of images with shape (height, width) # we need to convert it to a list of images with shape (1, height, width) prompt_depths = make_list_of_images(prompt_depth) + + # Validate prompt_depths has same length as images + if len(prompt_depths) != len(images): + raise ValueError( + f"Number of prompt depth images ({len(prompt_depths)}) does not match number of input images ({len(images)})" + ) + if prompt_scale_to_meter is None: prompt_scale_to_meter = self.prompt_scale_to_meter + processed_prompt_depths = [] for depth in prompt_depths: depth = to_numpy_array(depth) From cf24f48e30403c461399b5d339d55150396fad01 Mon Sep 17 00:00:00 2001 From: linhaotong Date: Thu, 16 Jan 2025 15:18:17 +0800 Subject: [PATCH 28/58] fix issues in image_processing_prompt_depth_anything.py --- .../modeling_prompt_depth_anything.py | 30 +++--- .../modular_prompt_depth_anything.py | 91 ++++++++++--------- 2 files changed, 65 insertions(+), 56 deletions(-) diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index b976215a8d1a..a535e1661947 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -23,7 +23,7 @@ class PromptDepthAnythingLayer(nn.Module): - def __init__(self, config): + def __init__(self, config: PromptDepthAnythingConfig): super().__init__() self.convolution1 = nn.Conv2d( 1, @@ -55,13 +55,12 @@ def __init__(self, config): ) def forward(self, prompt_depth: torch.Tensor) -> torch.Tensor: - residual = prompt_depth - residual = self.convolution1(residual) - residual = self.activation1(residual) - residual = self.convolution2(residual) - residual = self.activation2(residual) - residual = self.convolution3(residual) - return residual + hidden_state = self.convolution1(prompt_depth) + hidden_state = self.activation1(hidden_state) + hidden_state = self.convolution2(hidden_state) + hidden_state = self.activation2(hidden_state) + hidden_state = self.convolution3(hidden_state) + return hidden_state class PromptDepthAnythingPreActResidualLayer(nn.Module): @@ -114,7 +113,7 @@ class PromptDepthAnythingFeatureFusionLayer(nn.Module): Model configuration class defining the model architecture. """ - def __init__(self, config): + def __init__(self, config: PromptDepthAnythingConfig): super().__init__() self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True) @@ -127,7 +126,7 @@ def forward(self, hidden_state, residual=None, size=None, prompt_depth=None): if residual is not None: if hidden_state.shape != residual.shape: residual = nn.functional.interpolate( - residual, size=(hidden_state.shape[2], hidden_state.shape[3]), mode="bilinear", align_corners=False + residual, size=hidden_state.shape[2:], mode="bilinear", align_corners=False ) hidden_state = hidden_state + self.residual_layer1(residual) @@ -135,7 +134,7 @@ def forward(self, hidden_state, residual=None, size=None, prompt_depth=None): if prompt_depth is not None: prompt_depth = nn.functional.interpolate( - prompt_depth, hidden_state.shape[2:], mode="bilinear", align_corners=False + prompt_depth, size=hidden_state.shape[2:], mode="bilinear", align_corners=False ) res = self.prompt_depth_layer(prompt_depth) hidden_state = hidden_state + res @@ -224,7 +223,9 @@ def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) predicted_depth = self.activation1(predicted_depth) predicted_depth = self.conv3(predicted_depth) predicted_depth = self.activation2(predicted_depth) * self.max_depth - predicted_depth = predicted_depth.squeeze(dim=1) # shape (batch_size, height, width) + # (batch_size, 1, height, width) -> (batch_size, height, width), which + # keeps the same behavior as Depth Anything v1 & v2 + predicted_depth = predicted_depth.squeeze(dim=1) return predicted_depth @@ -254,7 +255,7 @@ def _init_weights(self, module): class PromptDepthAnythingReassembleLayer(nn.Module): - def __init__(self, config, channels, factor): + def __init__(self, config: PromptDepthAnythingConfig, channels: int, factor: int): super().__init__() self.projection = nn.Conv2d(in_channels=config.reassemble_hidden_size, out_channels=channels, kernel_size=1) @@ -274,7 +275,8 @@ def __init__(self, config, channels, factor): self.resize = nn.Identity() elif factor < 1: # so should downsample - self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=torch_int(1 / factor), padding=1) + stride = torch_int(1 / factor) + self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1) def forward(self, hidden_state): hidden_state = self.projection(hidden_state) diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index 1cd587770627..e81b2c87acb3 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -26,41 +26,13 @@ _CONFIG_FOR_DOC = "PromptDepthAnythingConfig" -PROMPT_DEPTH_ANYTHING_START_DOCSTRING = r""" - This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it - as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and - behavior. - - Parameters: - config ([`PromptDepthAnythingConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -PROMPT_DEPTH_ANYTHING_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`] - for details. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - prompt_depth (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*): - Prompt depth. - return_dict (`bool`, *optional*): - Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. -""" - class PromptDepthAnythingConfig(DepthAnythingConfig): model_type = "prompt_depth_anything" class PromptDepthAnythingLayer(nn.Module): - def __init__(self, config): + def __init__(self, config: PromptDepthAnythingConfig): super().__init__() self.convolution1 = nn.Conv2d( 1, @@ -92,17 +64,16 @@ def __init__(self, config): ) def forward(self, prompt_depth: torch.Tensor) -> torch.Tensor: - residual = prompt_depth - residual = self.convolution1(residual) - residual = self.activation1(residual) - residual = self.convolution2(residual) - residual = self.activation2(residual) - residual = self.convolution3(residual) - return residual + hidden_state = self.convolution1(prompt_depth) + hidden_state = self.activation1(hidden_state) + hidden_state = self.convolution2(hidden_state) + hidden_state = self.activation2(hidden_state) + hidden_state = self.convolution3(hidden_state) + return hidden_state class PromptDepthAnythingFeatureFusionLayer(DepthAnythingFeatureFusionLayer): - def __init__(self, config): + def __init__(self, config: PromptDepthAnythingConfig): super().__init__(config) self.prompt_depth_layer = PromptDepthAnythingLayer(config) @@ -110,7 +81,7 @@ def forward(self, hidden_state, residual=None, size=None, prompt_depth=None): if residual is not None: if hidden_state.shape != residual.shape: residual = nn.functional.interpolate( - residual, size=(hidden_state.shape[2], hidden_state.shape[3]), mode="bilinear", align_corners=False + residual, size=hidden_state.shape[2:], mode="bilinear", align_corners=False ) hidden_state = hidden_state + self.residual_layer1(residual) @@ -118,7 +89,7 @@ def forward(self, hidden_state, residual=None, size=None, prompt_depth=None): if prompt_depth is not None: prompt_depth = nn.functional.interpolate( - prompt_depth, hidden_state.shape[2:], mode="bilinear", align_corners=False + prompt_depth, size=hidden_state.shape[2:], mode="bilinear", align_corners=False ) res = self.prompt_depth_layer(prompt_depth) hidden_state = hidden_state + res @@ -175,11 +146,46 @@ def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) predicted_depth = self.activation1(predicted_depth) predicted_depth = self.conv3(predicted_depth) predicted_depth = self.activation2(predicted_depth) * self.max_depth - predicted_depth = predicted_depth.squeeze(dim=1) # shape (batch_size, height, width) + # (batch_size, 1, height, width) -> (batch_size, height, width), which + # keeps the same behavior as Depth Anything v1 & v2 + predicted_depth = predicted_depth.squeeze(dim=1) return predicted_depth +PROMPT_DEPTH_ANYTHING_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`PromptDepthAnythingConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +PROMPT_DEPTH_ANYTHING_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`] + for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + prompt_depth (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*): + Prompt depth is the sparse or low-resolution depth obtained from multi-view geometry or a + low-resolution depth sensor. It generally has shape (height, width), where height + and width can be smaller than those of the images. It is optional and can be None, which means no prompt depth + will be used. If it is None, the output will be a monocular relative depth. + The values are recommended to be in meters, but this is not necessary. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" + + class PromptDepthAnythingPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -205,7 +211,7 @@ def _init_weights(self, module): class PromptDepthAnythingReassembleLayer(DepthAnythingReassembleLayer): - def __init__(self, config, channels, factor): + def __init__(self, config: PromptDepthAnythingConfig, channels: int, factor: int): super().__init__(config, channels, factor) self.projection = nn.Conv2d(in_channels=config.reassemble_hidden_size, out_channels=channels, kernel_size=1) @@ -216,7 +222,8 @@ def __init__(self, config, channels, factor): self.resize = nn.Identity() elif factor < 1: # so should downsample - self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=torch_int(1 / factor), padding=1) + stride = torch_int(1 / factor) + self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1) class PromptDepthAnythingReassembleStage(DepthAnythingReassembleStage): From fcd5107d7c86f0bf8b4f1eec53e338a5bb1dba75 Mon Sep 17 00:00:00 2001 From: linhaotong Date: Thu, 16 Jan 2025 17:15:19 +0800 Subject: [PATCH 29/58] fix issues in prompt depth anything --- .../en/model_doc/prompt_depth_anything.md | 1 + .../convert_prompt_depth_anything_to_hf.py | 142 +++++++----------- .../image_processing_prompt_depth_anything.py | 2 + .../modeling_prompt_depth_anything.py | 26 +--- .../modular_prompt_depth_anything.py | 28 ++-- 5 files changed, 74 insertions(+), 125 deletions(-) diff --git a/docs/source/en/model_doc/prompt_depth_anything.md b/docs/source/en/model_doc/prompt_depth_anything.md index 50d6f44399a6..d1fc02009044 100644 --- a/docs/source/en/model_doc/prompt_depth_anything.md +++ b/docs/source/en/model_doc/prompt_depth_anything.md @@ -49,6 +49,7 @@ The Transformers library allows you to use the model with just a few lines of co >>> prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true" >>> prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw) +>>> # the prompt depth can be None, and the model will output a monocular relative depth. >>> # prepare image for the model >>> inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth) diff --git a/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py b/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py index 54854fe0a7a2..0d303898e5a5 100644 --- a/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py +++ b/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py @@ -77,94 +77,58 @@ def get_dpt_config(model_name): return config -KEY_MAPPING = { - # Stem - "pretrained.cls_token": "backbone.embeddings.cls_token", - "pretrained.mask_token": "backbone.embeddings.mask_token", - "pretrained.pos_embed": "backbone.embeddings.position_embeddings", - "pretrained.patch_embed.proj.weight": "backbone.embeddings.patch_embeddings.projection.weight", - "pretrained.patch_embed.proj.bias": "backbone.embeddings.patch_embeddings.projection.bias", - # Head - "pretrained.norm.weight": "backbone.layernorm.weight", - "pretrained.norm.bias": "backbone.layernorm.bias", - # Head - "depth_head.scratch.output_conv1.weight": "head.conv1.weight", - "depth_head.scratch.output_conv1.bias": "head.conv1.bias", - "depth_head.scratch.output_conv2.0.weight": "head.conv2.weight", - "depth_head.scratch.output_conv2.0.bias": "head.conv2.bias", - "depth_head.scratch.output_conv2.2.weight": "head.conv3.weight", - "depth_head.scratch.output_conv2.2.bias": "head.conv3.bias", +ORIGINAL_TO_CONVERTED_KEY_MAPPING = { + r"pretrained.cls_token": r"backbone.embeddings.cls_token", + r"pretrained.mask_token": r"backbone.embeddings.mask_token", + r"pretrained.pos_embed": r"backbone.embeddings.position_embeddings", + r"pretrained.patch_embed.proj.weight": r"backbone.embeddings.patch_embeddings.projection.weight", + r"pretrained.patch_embed.proj.bias": r"backbone.embeddings.patch_embeddings.projection.bias", + r"pretrained.norm.weight": r"backbone.layernorm.weight", + r"pretrained.norm.bias": r"backbone.layernorm.bias", + r"depth_head.scratch.output_conv1.weight": r"head.conv1.weight", + r"depth_head.scratch.output_conv1.bias": r"head.conv1.bias", + r"depth_head.scratch.output_conv2.0.weight": r"head.conv2.weight", + r"depth_head.scratch.output_conv2.0.bias": r"head.conv2.bias", + r"depth_head.scratch.output_conv2.2.weight": r"head.conv3.weight", + r"depth_head.scratch.output_conv2.2.bias": r"head.conv3.bias", + r"pretrained.blocks.(\d+).ls1.gamma": r"backbone.encoder.layer.\1.layer_scale1.lambda1", + r"pretrained.blocks.(\d+).ls2.gamma": r"backbone.encoder.layer.\1.layer_scale2.lambda1", + r"pretrained.blocks.(\d+).norm1.weight": r"backbone.encoder.layer.\1.norm1.weight", + r"pretrained.blocks.(\d+).norm1.bias": r"backbone.encoder.layer.\1.norm1.bias", + r"pretrained.blocks.(\d+).norm2.weight": r"backbone.encoder.layer.\1.norm2.weight", + r"pretrained.blocks.(\d+).norm2.bias": r"backbone.encoder.layer.\1.norm2.bias", + r"pretrained.blocks.(\d+).mlp.fc1.weight": r"backbone.encoder.layer.\1.mlp.fc1.weight", + r"pretrained.blocks.(\d+).mlp.fc1.bias": r"backbone.encoder.layer.\1.mlp.fc1.bias", + r"pretrained.blocks.(\d+).mlp.fc2.weight": r"backbone.encoder.layer.\1.mlp.fc2.weight", + r"pretrained.blocks.(\d+).mlp.fc2.bias": r"backbone.encoder.layer.\1.mlp.fc2.bias", + r"pretrained.blocks.(\d+).attn.proj.weight": r"backbone.encoder.layer.\1.attention.output.dense.weight", + r"pretrained.blocks.(\d+).attn.proj.bias": r"backbone.encoder.layer.\1.attention.output.dense.bias", + r"pretrained.blocks.(\d+).attn.qkv.weight": r"qkv_transform_\1", + r"pretrained.blocks.(\d+).attn.qkv.bias": r"qkv_transform_bias_\1", + r"depth_head.projects.(\d+).weight": r"neck.reassemble_stage.layers.\1.projection.weight", + r"depth_head.projects.(\d+).bias": r"neck.reassemble_stage.layers.\1.projection.bias", + r"depth_head.scratch.layer(\d+)_rn.weight": r"neck.convs.\0.weight", + r"depth_head.resize_layers.(\d+).weight": r"neck.reassemble_stage.layers.\1.resize.weight", + r"depth_head.resize_layers.(\d+).bias": r"neck.reassemble_stage.layers.\1.resize.bias", + r"depth_head.scratch.refinenet(\d+).out_conv.weight": r"neck.fusion_stage.layers.\0.projection.weight", + r"depth_head.scratch.refinenet(\d+).out_conv.bias": r"neck.fusion_stage.layers.\0.projection.bias", + r"depth_head.scratch.refinenet(\d+).resConfUnit1.conv1.weight": r"neck.fusion_stage.layers.\0.residual_layer1.convolution1.weight", + r"depth_head.scratch.refinenet(\d+).resConfUnit1.conv1.bias": r"neck.fusion_stage.layers.\0.residual_layer1.convolution1.bias", + r"depth_head.scratch.refinenet(\d+).resConfUnit1.conv2.weight": r"neck.fusion_stage.layers.\0.residual_layer1.convolution2.weight", + r"depth_head.scratch.refinenet(\d+).resConfUnit1.conv2.bias": r"neck.fusion_stage.layers.\0.residual_layer1.convolution2.bias", + r"depth_head.scratch.refinenet(\d+).resConfUnit2.conv1.weight": r"neck.fusion_stage.layers.\0.residual_layer2.convolution1.weight", + r"depth_head.scratch.refinenet(\d+).resConfUnit2.conv1.bias": r"neck.fusion_stage.layers.\0.residual_layer2.convolution1.bias", + r"depth_head.scratch.refinenet(\d+).resConfUnit2.conv2.weight": r"neck.fusion_stage.layers.\0.residual_layer2.convolution2.weight", + r"depth_head.scratch.refinenet(\d+).resConfUnit2.conv2.bias": r"neck.fusion_stage.layers.\0.residual_layer2.convolution2.bias", + r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.0.weight": r"neck.fusion_stage.layers.\0.prompt_depth_layer.convolution1.weight", + r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.0.bias": r"neck.fusion_stage.layers.\0.prompt_depth_layer.convolution1.bias", + r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.2.weight": r"neck.fusion_stage.layers.\0.prompt_depth_layer.convolution2.weight", + r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.2.bias": r"neck.fusion_stage.layers.\0.prompt_depth_layer.convolution2.bias", + r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.4.weight": r"neck.fusion_stage.layers.\0.prompt_depth_layer.convolution3.weight", + r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.4.bias": r"neck.fusion_stage.layers.\0.prompt_depth_layer.convolution3.bias", } -def add_transformer_mappings(config): - # Transformer encoder mappings - for i in range(config.backbone_config.num_hidden_layers): - KEY_MAPPING.update( - { - f"pretrained.blocks.{i}.ls1.gamma": f"backbone.encoder.layer.{i}.layer_scale1.lambda1", - f"pretrained.blocks.{i}.ls2.gamma": f"backbone.encoder.layer.{i}.layer_scale2.lambda1", - f"pretrained.blocks.{i}.norm1.weight": f"backbone.encoder.layer.{i}.norm1.weight", - f"pretrained.blocks.{i}.norm1.bias": f"backbone.encoder.layer.{i}.norm1.bias", - f"pretrained.blocks.{i}.norm2.weight": f"backbone.encoder.layer.{i}.norm2.weight", - f"pretrained.blocks.{i}.norm2.bias": f"backbone.encoder.layer.{i}.norm2.bias", - f"pretrained.blocks.{i}.mlp.fc1.weight": f"backbone.encoder.layer.{i}.mlp.fc1.weight", - f"pretrained.blocks.{i}.mlp.fc1.bias": f"backbone.encoder.layer.{i}.mlp.fc1.bias", - f"pretrained.blocks.{i}.mlp.fc2.weight": f"backbone.encoder.layer.{i}.mlp.fc2.weight", - f"pretrained.blocks.{i}.mlp.fc2.bias": f"backbone.encoder.layer.{i}.mlp.fc2.bias", - f"pretrained.blocks.{i}.attn.proj.weight": f"backbone.encoder.layer.{i}.attention.output.dense.weight", - f"pretrained.blocks.{i}.attn.proj.bias": f"backbone.encoder.layer.{i}.attention.output.dense.bias", - f"pretrained.blocks.{i}.attn.qkv.weight": f"qkv_transform_{i}", - f"pretrained.blocks.{i}.attn.qkv.bias": f"qkv_transform_bias_{i}", - } - ) - - -def add_neck_mappings(): - # Neck mappings - for i in range(4): - KEY_MAPPING.update( - { - f"depth_head.projects.{i}.weight": f"neck.reassemble_stage.layers.{i}.projection.weight", - f"depth_head.projects.{i}.bias": f"neck.reassemble_stage.layers.{i}.projection.bias", - f"depth_head.scratch.layer{i+1}_rn.weight": f"neck.convs.{i}.weight", - } - ) - - if i != 2: - KEY_MAPPING.update( - { - f"depth_head.resize_layers.{i}.weight": f"neck.reassemble_stage.layers.{i}.resize.weight", - f"depth_head.resize_layers.{i}.bias": f"neck.reassemble_stage.layers.{i}.resize.bias", - } - ) - - # Refinenet mappings - mapping = {1: 3, 2: 2, 3: 1, 4: 0} - for i in range(1, 5): - j = mapping[i] - KEY_MAPPING.update( - { - f"depth_head.scratch.refinenet{i}.out_conv.weight": f"neck.fusion_stage.layers.{j}.projection.weight", - f"depth_head.scratch.refinenet{i}.out_conv.bias": f"neck.fusion_stage.layers.{j}.projection.bias", - f"depth_head.scratch.refinenet{i}.resConfUnit1.conv1.weight": f"neck.fusion_stage.layers.{j}.residual_layer1.convolution1.weight", - f"depth_head.scratch.refinenet{i}.resConfUnit1.conv1.bias": f"neck.fusion_stage.layers.{j}.residual_layer1.convolution1.bias", - f"depth_head.scratch.refinenet{i}.resConfUnit1.conv2.weight": f"neck.fusion_stage.layers.{j}.residual_layer1.convolution2.weight", - f"depth_head.scratch.refinenet{i}.resConfUnit1.conv2.bias": f"neck.fusion_stage.layers.{j}.residual_layer1.convolution2.bias", - f"depth_head.scratch.refinenet{i}.resConfUnit2.conv1.weight": f"neck.fusion_stage.layers.{j}.residual_layer2.convolution1.weight", - f"depth_head.scratch.refinenet{i}.resConfUnit2.conv1.bias": f"neck.fusion_stage.layers.{j}.residual_layer2.convolution1.bias", - f"depth_head.scratch.refinenet{i}.resConfUnit2.conv2.weight": f"neck.fusion_stage.layers.{j}.residual_layer2.convolution2.weight", - f"depth_head.scratch.refinenet{i}.resConfUnit2.conv2.bias": f"neck.fusion_stage.layers.{j}.residual_layer2.convolution2.bias", - f"depth_head.scratch.refinenet{i}.resConfUnit_depth.0.weight": f"neck.fusion_stage.layers.{j}.prompt_depth_layer.convolution1.weight", - f"depth_head.scratch.refinenet{i}.resConfUnit_depth.0.bias": f"neck.fusion_stage.layers.{j}.prompt_depth_layer.convolution1.bias", - f"depth_head.scratch.refinenet{i}.resConfUnit_depth.2.weight": f"neck.fusion_stage.layers.{j}.prompt_depth_layer.convolution2.weight", - f"depth_head.scratch.refinenet{i}.resConfUnit_depth.2.bias": f"neck.fusion_stage.layers.{j}.prompt_depth_layer.convolution2.bias", - f"depth_head.scratch.refinenet{i}.resConfUnit_depth.4.weight": f"neck.fusion_stage.layers.{j}.prompt_depth_layer.convolution3.weight", - f"depth_head.scratch.refinenet{i}.resConfUnit_depth.4.bias": f"neck.fusion_stage.layers.{j}.prompt_depth_layer.convolution3.bias", - } - ) - - def transform_qkv_weights(key, value, config): if not key.startswith("qkv_transform"): return value @@ -206,10 +170,6 @@ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, ve # define DPT configuration config = get_dpt_config(model_name) - # Add dynamic key mappings - add_transformer_mappings(config) - add_neck_mappings() - model_name_to_repo = { "prompt-depth-anything-vits": "depth-anything/prompt-depth-anything-vits", "prompt-depth-anything-vits-transparent": "depth-anything/prompt-depth-anything-vits-transparent", @@ -230,8 +190,8 @@ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, ve # Convert state dict using mappings new_state_dict = {} for key, value in state_dict.items(): - if key in KEY_MAPPING: - new_key = KEY_MAPPING[key] + if key in ORIGINAL_TO_CONVERTED_KEY_MAPPING: + new_key = ORIGINAL_TO_CONVERTED_KEY_MAPPING[key] transformed_value = transform_qkv_weights(new_key, value, config) if isinstance(transformed_value, dict): new_state_dict.update(transformed_value) diff --git a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py index ead82eb573a8..e92f70fdb6a7 100644 --- a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py @@ -450,6 +450,8 @@ def preprocess( for depth in prompt_depths: depth = to_numpy_array(depth) depth = depth * prompt_scale_to_meter + if depth.min() == depth.max(): + raise ValueError("Prompt depth is invalid, min and max are the same.") depth = depth[..., None].astype(np.float32) depth = to_channel_dimension_format(depth, data_format, input_channel_dim=input_data_format) diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index a535e1661947..2b4a6ea0ef7f 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -208,7 +208,7 @@ def __init__(self, config): self.max_depth = config.max_depth def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) -> torch.Tensor: - hidden_states = hidden_states[self.head_in_index] + hidden_states = hidden_states[-1] predicted_depth = self.conv1(hidden_states) target_height = torch_int(patch_height * self.patch_size) @@ -222,7 +222,7 @@ def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) predicted_depth = self.conv2(predicted_depth) predicted_depth = self.activation1(predicted_depth) predicted_depth = self.conv3(predicted_depth) - predicted_depth = self.activation2(predicted_depth) * self.max_depth + predicted_depth = self.activation2(predicted_depth) # (batch_size, 1, height, width) -> (batch_size, height, width), which # keeps the same behavior as Depth Anything v1 & v2 predicted_depth = predicted_depth.squeeze(dim=1) @@ -259,15 +259,6 @@ def __init__(self, config: PromptDepthAnythingConfig, channels: int, factor: int super().__init__() self.projection = nn.Conv2d(in_channels=config.reassemble_hidden_size, out_channels=channels, kernel_size=1) - # up/down sampling depending on factor - if factor > 1: - self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0) - elif factor == 1: - self.resize = nn.Identity() - elif factor < 1: - # so should downsample - self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=int(1 / factor), padding=1) - # up/down sampling depending on factor if factor > 1: self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0) @@ -505,15 +496,10 @@ def forward( if prompt_depth is not None: # normalize prompt depth - B = prompt_depth.shape[0] - depth_min, depth_max = ( - torch.min(prompt_depth.reshape(B, -1), dim=1).values, - torch.max(prompt_depth.reshape(B, -1), dim=1).values, - ) - invalid_mask = (depth_max - depth_min) <= 0 - if invalid_mask.any(): - depth_max[invalid_mask] = depth_min[invalid_mask] + 1e-6 - depth_min, depth_max = depth_min.view(B, 1, 1, 1), depth_max.view(B, 1, 1, 1) + batch_size = prompt_depth.shape[0] + depth_min = torch.min(prompt_depth.reshape(batch_size, -1), dim=1).values + depth_max = torch.max(prompt_depth.reshape(batch_size, -1), dim=1).values + depth_min, depth_max = depth_min.view(batch_size, 1, 1, 1), depth_max.view(batch_size, 1, 1, 1) prompt_depth = (prompt_depth - depth_min) / (depth_max - depth_min) # normalize done diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index e81b2c87acb3..9d14b5fc7e1f 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -10,7 +10,6 @@ DepthAnythingFeatureFusionStage, DepthAnythingForDepthEstimation, DepthAnythingNeck, - DepthAnythingReassembleLayer, DepthAnythingReassembleStage, ) from transformers.utils.generic import torch_int @@ -131,7 +130,7 @@ def forward(self, hidden_states, size=None, prompt_depth=None): class PromptDepthAnythingDepthEstimationHead(DepthAnythingDepthEstimationHead): def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) -> torch.Tensor: - hidden_states = hidden_states[self.head_in_index] + hidden_states = hidden_states[-1] predicted_depth = self.conv1(hidden_states) target_height = torch_int(patch_height * self.patch_size) @@ -145,7 +144,7 @@ def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) predicted_depth = self.conv2(predicted_depth) predicted_depth = self.activation1(predicted_depth) predicted_depth = self.conv3(predicted_depth) - predicted_depth = self.activation2(predicted_depth) * self.max_depth + predicted_depth = self.activation2(predicted_depth) # (batch_size, 1, height, width) -> (batch_size, height, width), which # keeps the same behavior as Depth Anything v1 & v2 predicted_depth = predicted_depth.squeeze(dim=1) @@ -210,9 +209,9 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) -class PromptDepthAnythingReassembleLayer(DepthAnythingReassembleLayer): +class PromptDepthAnythingReassembleLayer(nn.Module): def __init__(self, config: PromptDepthAnythingConfig, channels: int, factor: int): - super().__init__(config, channels, factor) + super().__init__() self.projection = nn.Conv2d(in_channels=config.reassemble_hidden_size, out_channels=channels, kernel_size=1) # up/down sampling depending on factor @@ -225,6 +224,12 @@ def __init__(self, config: PromptDepthAnythingConfig, channels: int, factor: int stride = torch_int(1 / factor) self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1) + def forward(self, hidden_state): + hidden_state = self.projection(hidden_state) + hidden_state = self.resize(hidden_state) + + return hidden_state + class PromptDepthAnythingReassembleStage(DepthAnythingReassembleStage): pass @@ -337,15 +342,10 @@ def forward( if prompt_depth is not None: # normalize prompt depth - B = prompt_depth.shape[0] - depth_min, depth_max = ( - torch.min(prompt_depth.reshape(B, -1), dim=1).values, - torch.max(prompt_depth.reshape(B, -1), dim=1).values, - ) - invalid_mask = (depth_max - depth_min) <= 0 - if invalid_mask.any(): - depth_max[invalid_mask] = depth_min[invalid_mask] + 1e-6 - depth_min, depth_max = depth_min.view(B, 1, 1, 1), depth_max.view(B, 1, 1, 1) + batch_size = prompt_depth.shape[0] + depth_min = torch.min(prompt_depth.reshape(batch_size, -1), dim=1).values + depth_max = torch.max(prompt_depth.reshape(batch_size, -1), dim=1).values + depth_min, depth_max = depth_min.view(batch_size, 1, 1, 1), depth_max.view(batch_size, 1, 1, 1) prompt_depth = (prompt_depth - depth_min) / (depth_max - depth_min) # normalize done From d9f6ecf2f6b3f167f4b3d4428fd5f8c21fd7e4b1 Mon Sep 17 00:00:00 2001 From: linhaotong Date: Fri, 17 Jan 2025 01:52:30 +0800 Subject: [PATCH 30/58] update converting script similar to mllamma --- .../convert_prompt_depth_anything_to_hf.py | 154 +++++++++--------- 1 file changed, 75 insertions(+), 79 deletions(-) diff --git a/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py b/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py index 0d303898e5a5..8dfeff03ad27 100644 --- a/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py +++ b/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py @@ -16,6 +16,7 @@ https://github.com/DepthAnything/PromptDA""" import argparse +import re from pathlib import Path import requests @@ -77,58 +78,6 @@ def get_dpt_config(model_name): return config -ORIGINAL_TO_CONVERTED_KEY_MAPPING = { - r"pretrained.cls_token": r"backbone.embeddings.cls_token", - r"pretrained.mask_token": r"backbone.embeddings.mask_token", - r"pretrained.pos_embed": r"backbone.embeddings.position_embeddings", - r"pretrained.patch_embed.proj.weight": r"backbone.embeddings.patch_embeddings.projection.weight", - r"pretrained.patch_embed.proj.bias": r"backbone.embeddings.patch_embeddings.projection.bias", - r"pretrained.norm.weight": r"backbone.layernorm.weight", - r"pretrained.norm.bias": r"backbone.layernorm.bias", - r"depth_head.scratch.output_conv1.weight": r"head.conv1.weight", - r"depth_head.scratch.output_conv1.bias": r"head.conv1.bias", - r"depth_head.scratch.output_conv2.0.weight": r"head.conv2.weight", - r"depth_head.scratch.output_conv2.0.bias": r"head.conv2.bias", - r"depth_head.scratch.output_conv2.2.weight": r"head.conv3.weight", - r"depth_head.scratch.output_conv2.2.bias": r"head.conv3.bias", - r"pretrained.blocks.(\d+).ls1.gamma": r"backbone.encoder.layer.\1.layer_scale1.lambda1", - r"pretrained.blocks.(\d+).ls2.gamma": r"backbone.encoder.layer.\1.layer_scale2.lambda1", - r"pretrained.blocks.(\d+).norm1.weight": r"backbone.encoder.layer.\1.norm1.weight", - r"pretrained.blocks.(\d+).norm1.bias": r"backbone.encoder.layer.\1.norm1.bias", - r"pretrained.blocks.(\d+).norm2.weight": r"backbone.encoder.layer.\1.norm2.weight", - r"pretrained.blocks.(\d+).norm2.bias": r"backbone.encoder.layer.\1.norm2.bias", - r"pretrained.blocks.(\d+).mlp.fc1.weight": r"backbone.encoder.layer.\1.mlp.fc1.weight", - r"pretrained.blocks.(\d+).mlp.fc1.bias": r"backbone.encoder.layer.\1.mlp.fc1.bias", - r"pretrained.blocks.(\d+).mlp.fc2.weight": r"backbone.encoder.layer.\1.mlp.fc2.weight", - r"pretrained.blocks.(\d+).mlp.fc2.bias": r"backbone.encoder.layer.\1.mlp.fc2.bias", - r"pretrained.blocks.(\d+).attn.proj.weight": r"backbone.encoder.layer.\1.attention.output.dense.weight", - r"pretrained.blocks.(\d+).attn.proj.bias": r"backbone.encoder.layer.\1.attention.output.dense.bias", - r"pretrained.blocks.(\d+).attn.qkv.weight": r"qkv_transform_\1", - r"pretrained.blocks.(\d+).attn.qkv.bias": r"qkv_transform_bias_\1", - r"depth_head.projects.(\d+).weight": r"neck.reassemble_stage.layers.\1.projection.weight", - r"depth_head.projects.(\d+).bias": r"neck.reassemble_stage.layers.\1.projection.bias", - r"depth_head.scratch.layer(\d+)_rn.weight": r"neck.convs.\0.weight", - r"depth_head.resize_layers.(\d+).weight": r"neck.reassemble_stage.layers.\1.resize.weight", - r"depth_head.resize_layers.(\d+).bias": r"neck.reassemble_stage.layers.\1.resize.bias", - r"depth_head.scratch.refinenet(\d+).out_conv.weight": r"neck.fusion_stage.layers.\0.projection.weight", - r"depth_head.scratch.refinenet(\d+).out_conv.bias": r"neck.fusion_stage.layers.\0.projection.bias", - r"depth_head.scratch.refinenet(\d+).resConfUnit1.conv1.weight": r"neck.fusion_stage.layers.\0.residual_layer1.convolution1.weight", - r"depth_head.scratch.refinenet(\d+).resConfUnit1.conv1.bias": r"neck.fusion_stage.layers.\0.residual_layer1.convolution1.bias", - r"depth_head.scratch.refinenet(\d+).resConfUnit1.conv2.weight": r"neck.fusion_stage.layers.\0.residual_layer1.convolution2.weight", - r"depth_head.scratch.refinenet(\d+).resConfUnit1.conv2.bias": r"neck.fusion_stage.layers.\0.residual_layer1.convolution2.bias", - r"depth_head.scratch.refinenet(\d+).resConfUnit2.conv1.weight": r"neck.fusion_stage.layers.\0.residual_layer2.convolution1.weight", - r"depth_head.scratch.refinenet(\d+).resConfUnit2.conv1.bias": r"neck.fusion_stage.layers.\0.residual_layer2.convolution1.bias", - r"depth_head.scratch.refinenet(\d+).resConfUnit2.conv2.weight": r"neck.fusion_stage.layers.\0.residual_layer2.convolution2.weight", - r"depth_head.scratch.refinenet(\d+).resConfUnit2.conv2.bias": r"neck.fusion_stage.layers.\0.residual_layer2.convolution2.bias", - r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.0.weight": r"neck.fusion_stage.layers.\0.prompt_depth_layer.convolution1.weight", - r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.0.bias": r"neck.fusion_stage.layers.\0.prompt_depth_layer.convolution1.bias", - r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.2.weight": r"neck.fusion_stage.layers.\0.prompt_depth_layer.convolution2.weight", - r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.2.bias": r"neck.fusion_stage.layers.\0.prompt_depth_layer.convolution2.bias", - r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.4.weight": r"neck.fusion_stage.layers.\0.prompt_depth_layer.convolution3.weight", - r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.4.bias": r"neck.fusion_stage.layers.\0.prompt_depth_layer.convolution3.bias", -} - - def transform_qkv_weights(key, value, config): if not key.startswith("qkv_transform"): return value @@ -136,31 +85,71 @@ def transform_qkv_weights(key, value, config): layer_idx = int(key.split("_")[-1]) hidden_size = config.backbone_config.hidden_size - if "bias" in key: - # Handle bias - return { - f"backbone.encoder.layer.{layer_idx}.attention.attention.query.bias": value[:hidden_size], - f"backbone.encoder.layer.{layer_idx}.attention.attention.key.bias": value[hidden_size : hidden_size * 2], - f"backbone.encoder.layer.{layer_idx}.attention.attention.value.bias": value[-hidden_size:], - } - else: - # Handle weights - return { - f"backbone.encoder.layer.{layer_idx}.attention.attention.query.weight": value[:hidden_size, :], - f"backbone.encoder.layer.{layer_idx}.attention.attention.key.weight": value[ - hidden_size : hidden_size * 2, : - ], - f"backbone.encoder.layer.{layer_idx}.attention.attention.value.weight": value[-hidden_size:, :], - } + suffix = "bias" if "bias" in key else "weight" + return { + f"backbone.encoder.layer.{layer_idx}.attention.attention.query.{suffix}": value[:hidden_size], + f"backbone.encoder.layer.{layer_idx}.attention.attention.key.{suffix}": value[hidden_size : hidden_size * 2], + f"backbone.encoder.layer.{layer_idx}.attention.attention.value.{suffix}": value[-hidden_size:], + } -name_to_checkpoint = { - "prompt-depth-anything-vits": "model.ckpt", - "prompt-depth-anything-vits-transparent": "model.ckpt", - "prompt-depth-anything-vitl": "model.ckpt", +ORIGINAL_TO_CONVERTED_KEY_MAPPING = { + # Stem + r"pretrained.cls_token": r"backbone.embeddings.cls_token", + r"pretrained.mask_token": r"backbone.embeddings.mask_token", + r"pretrained.pos_embed": r"backbone.embeddings.position_embeddings", + r"pretrained.patch_embed.proj.(weight|bias)": r"backbone.embeddings.patch_embeddings.projection.\1", + # Backbone + r"pretrained.norm.(weight|bias)": r"backbone.layernorm.\1", + # Transformer layers + r"pretrained.blocks.(\d+).ls1.gamma": r"backbone.encoder.layer.\1.layer_scale1.lambda1", + r"pretrained.blocks.(\d+).ls2.gamma": r"backbone.encoder.layer.\1.layer_scale2.lambda1", + r"pretrained.blocks.(\d+).norm1.(weight|bias)": r"backbone.encoder.layer.\1.norm1.\2", + r"pretrained.blocks.(\d+).norm2.(weight|bias)": r"backbone.encoder.layer.\1.norm2.\2", + r"pretrained.blocks.(\d+).mlp.fc1.(weight|bias)": r"backbone.encoder.layer.\1.mlp.fc1.\2", + r"pretrained.blocks.(\d+).mlp.fc2.(weight|bias)": r"backbone.encoder.layer.\1.mlp.fc2.\2", + r"pretrained.blocks.(\d+).attn.proj.(weight|bias)": r"backbone.encoder.layer.\1.attention.output.dense.\2", + r"pretrained.blocks.(\d+).attn.qkv.(weight|bias)": r"qkv_transform_\2_\1", + # Neck + r"depth_head.projects.(\d+).(weight|bias)": r"neck.reassemble_stage.layers.\1.projection.\2", + r"depth_head.scratch.layer(\d+)_rn.weight": lambda m: f"neck.convs.{int(m.group(1))-1}.weight", + r"depth_head.resize_layers.(\d+).(weight|bias)": r"neck.reassemble_stage.layers.\1.resize.\2", + # Refinenet (with reversed indices) + r"depth_head.scratch.refinenet(\d+).out_conv.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.projection.{m.group(2)}", + r"depth_head.scratch.refinenet(\d+).resConfUnit1.conv1.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.residual_layer1.convolution1.{m.group(2)}", + r"depth_head.scratch.refinenet(\d+).resConfUnit1.conv2.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.residual_layer1.convolution2.{m.group(2)}", + r"depth_head.scratch.refinenet(\d+).resConfUnit2.conv1.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.residual_layer2.convolution1.{m.group(2)}", + r"depth_head.scratch.refinenet(\d+).resConfUnit2.conv2.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.residual_layer2.convolution2.{m.group(2)}", + r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.0.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.prompt_depth_layer.convolution1.{m.group(2)}", + r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.2.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.prompt_depth_layer.convolution2.{m.group(2)}", + r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.4.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.prompt_depth_layer.convolution3.{m.group(2)}", + # Head + r"depth_head.scratch.output_conv1.(weight|bias)": r"head.conv1.\1", + r"depth_head.scratch.output_conv2.0.(weight|bias)": r"head.conv2.\1", + r"depth_head.scratch.output_conv2.2.(weight|bias)": r"head.conv3.\1", } +def convert_old_keys_to_new_keys(state_dict_keys: dict = None): + """ + Convert old state dict keys to new keys using regex patterns. + """ + output_dict = {} + if state_dict_keys is not None: + for old_key in state_dict_keys: + new_key = old_key + for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items(): + match = re.match(pattern, old_key) + if match: + if callable(replacement): + new_key = replacement(match) + else: + new_key = re.sub(pattern, replacement, old_key) + break + output_dict[old_key] = new_key + return output_dict + + @torch.no_grad() def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, verify_logits): """ @@ -188,15 +177,15 @@ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, ve state_dict = {key[9:]: state_dict[key] for key in state_dict} # Convert state dict using mappings + key_mapping = convert_old_keys_to_new_keys(state_dict.keys()) new_state_dict = {} for key, value in state_dict.items(): - if key in ORIGINAL_TO_CONVERTED_KEY_MAPPING: - new_key = ORIGINAL_TO_CONVERTED_KEY_MAPPING[key] - transformed_value = transform_qkv_weights(new_key, value, config) - if isinstance(transformed_value, dict): - new_state_dict.update(transformed_value) - else: - new_state_dict[new_key] = transformed_value + new_key = key_mapping[key] + transformed_value = transform_qkv_weights(new_key, value, config) + if isinstance(transformed_value, dict): + new_state_dict.update(transformed_value) + else: + new_state_dict[new_key] = transformed_value # load HuggingFace model model = PromptDepthAnythingForDepthEstimation(config) @@ -264,6 +253,13 @@ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, ve processor.push_to_hub(repo_id=f"{model_name.title()}-hf") +name_to_checkpoint = { + "prompt-depth-anything-vits": "model.ckpt", + "prompt-depth-anything-vits-transparent": "model.ckpt", + "prompt-depth-anything-vitl": "model.ckpt", +} + + if __name__ == "__main__": parser = argparse.ArgumentParser() # Required parameters From 357cc122765c7ab4281a81150a40a7d5d723e533 Mon Sep 17 00:00:00 2001 From: linhaotong Date: Fri, 17 Jan 2025 02:10:51 +0800 Subject: [PATCH 31/58] update testing for modeling prompt depth anything --- .../test_modeling_prompt_depth_anything.py | 48 ++++++++++++++----- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py index 5e12644e98f0..136058d5f483 100644 --- a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py +++ b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py @@ -16,7 +16,6 @@ import unittest -import numpy as np import requests from transformers import Dinov2Config, PromptDepthAnythingConfig @@ -38,7 +37,7 @@ if is_vision_available(): from PIL import Image - from transformers import DPTImageProcessor + from transformers import AutoImageProcessor class PromptDepthAnythingModelTester: @@ -236,27 +235,52 @@ def prepare_img(): return image +def prepare_prompt_depth(): + prompt_depth_url = ( + "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true" + ) + prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw) + return prompt_depth + + @require_torch @require_vision @slow class PromptDepthAnythingModelIntegrationTest(unittest.TestCase): - def test_inference(self): - image_processor = DPTImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf") + def test_inference_wo_prompt_depth(self): + image_processor = AutoImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf") model = PromptDepthAnythingForDepthEstimation.from_pretrained( "depth-anything/prompt-depth-anything-vits-hf" ).to(torch_device) image = prepare_img() - prompt_depth_url = ( - "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true" - ) - prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw) - prompt_depth = torch.tensor((np.asarray(prompt_depth) / 1000.0).astype(np.float32)) - prompt_depth = prompt_depth.unsqueeze(0).unsqueeze(0) inputs = image_processor(images=image, return_tensors="pt").to(torch_device) with torch.no_grad(): - outputs = model(pixel_values=inputs.pixel_values, prompt_depth=prompt_depth) + outputs = model(**inputs) + predicted_depth = outputs.predicted_depth + + expected_shape = torch.Size([1, 756, 1008]) + self.assertEqual(predicted_depth.shape, expected_shape) + + expected_slice = torch.tensor( + [[0.5029, 0.5120, 0.5176], [0.4998, 0.5147, 0.5197], [0.4973, 0.5201, 0.5241]] + ).to(torch_device) + + self.assertTrue(torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=1e-3)) + + def test_inference(self): + image_processor = AutoImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf") + model = PromptDepthAnythingForDepthEstimation.from_pretrained( + "depth-anything/prompt-depth-anything-vits-hf" + ).to(torch_device) + + image = prepare_img() + prompt_depth = prepare_prompt_depth() + inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) predicted_depth = outputs.predicted_depth expected_shape = torch.Size([1, 756, 1008]) @@ -280,7 +304,7 @@ def test_export(self): .to(torch_device) .eval() ) - image_processor = DPTImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf") + image_processor = AutoImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf") image = prepare_img() inputs = image_processor(images=image, return_tensors="pt").to(torch_device) From f79f91205e4dcab4deff5650d8b098bec65ba338 Mon Sep 17 00:00:00 2001 From: linhaotong Date: Fri, 17 Jan 2025 02:34:52 +0800 Subject: [PATCH 32/58] update testing for image_processing_prompt_depth_anything --- .../image_processing_prompt_depth_anything.py | 9 +- ..._image_processing_prompt_depth_anything.py | 139 ++++++++++++++++++ .../test_modeling_prompt_depth_anything.py | 5 +- 3 files changed, 149 insertions(+), 4 deletions(-) create mode 100644 tests/models/prompt_depth_anything/test_image_processing_prompt_depth_anything.py diff --git a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py index e92f70fdb6a7..8040800f75ee 100644 --- a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py @@ -435,7 +435,7 @@ def preprocess( if prompt_depth is not None: # prompt_depth is a list of images with shape (height, width) # we need to convert it to a list of images with shape (1, height, width) - prompt_depths = make_list_of_images(prompt_depth) + prompt_depths = make_list_of_images(prompt_depth, expected_ndims=2) # Validate prompt_depths has same length as images if len(prompt_depths) != len(images): @@ -451,7 +451,12 @@ def preprocess( depth = to_numpy_array(depth) depth = depth * prompt_scale_to_meter if depth.min() == depth.max(): - raise ValueError("Prompt depth is invalid, min and max are the same.") + # Prompt depth is invalid, min and max are the same. + # We can simply randomly select one pixel and set it to a small value. + EPS = 1e-6 + random_x = np.random.randint(0, depth.shape[0]) + random_y = np.random.randint(0, depth.shape[1]) + depth[random_x, random_y] = depth[0, 0] + EPS depth = depth[..., None].astype(np.float32) depth = to_channel_dimension_format(depth, data_format, input_channel_dim=input_data_format) diff --git a/tests/models/prompt_depth_anything/test_image_processing_prompt_depth_anything.py b/tests/models/prompt_depth_anything/test_image_processing_prompt_depth_anything.py new file mode 100644 index 000000000000..7becbe5dfa50 --- /dev/null +++ b/tests/models/prompt_depth_anything/test_image_processing_prompt_depth_anything.py @@ -0,0 +1,139 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np + +from transformers.file_utils import is_vision_available +from transformers.testing_utils import require_torch, require_vision + +from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs + + +if is_vision_available(): + from transformers import PromptDepthAnythingImageProcessor + + +class PromptDepthAnythingImageProcessingTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + image_size=18, + min_resolution=30, + max_resolution=400, + do_resize=True, + size=None, + do_normalize=True, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5], + ): + super().__init__() + size = size if size is not None else {"height": 18, "width": 18} + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.size = size + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + + def prepare_image_processor_dict(self): + return { + "image_mean": self.image_mean, + "image_std": self.image_std, + "do_normalize": self.do_normalize, + "do_resize": self.do_resize, + "size": self.size, + } + + def expected_output_image_shape(self, images): + return self.num_channels, self.size["height"], self.size["width"] + + def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): + return prepare_image_inputs( + batch_size=self.batch_size, + num_channels=self.num_channels, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + equal_resolution=equal_resolution, + numpify=numpify, + torchify=torchify, + ) + + +@require_torch +@require_vision +class PromptDepthAnythingImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): + image_processing_class = PromptDepthAnythingImageProcessor if is_vision_available() else None + + def setUp(self): + super().setUp() + self.image_processor_tester = PromptDepthAnythingImageProcessingTester(self) + + @property + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + def test_image_processor_properties(self): + image_processing = self.image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "do_rescale")) + self.assertTrue(hasattr(image_processing, "rescale_factor")) + self.assertTrue(hasattr(image_processing, "do_pad")) + self.assertTrue(hasattr(image_processing, "size_divisor")) + self.assertTrue(hasattr(image_processing, "prompt_scale_to_meter")) + + def test_image_processor_from_dict_with_kwargs(self): + image_processor = self.image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"height": 18, "width": 18}) + + image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42) + self.assertEqual(image_processor.size, {"height": 42, "width": 42}) + + def test_keep_aspect_ratio(self): + size = {"height": 512, "width": 512} + image_processor = PromptDepthAnythingImageProcessor(size=size, keep_aspect_ratio=True, ensure_multiple_of=32) + + image = np.zeros((489, 640, 3)) + + pixel_values = image_processor(image, return_tensors="pt").pixel_values + + self.assertEqual(list(pixel_values.shape), [1, 3, 512, 672]) + + def test_prompt_depth_processing(self): + size = {"height": 756, "width": 756} + image_processor = PromptDepthAnythingImageProcessor(size=size, keep_aspect_ratio=True, ensure_multiple_of=32) + + image = np.zeros((756, 1008, 3)) + prompt_depth = np.random.random((192, 256)) + + outputs = image_processor(image, prompt_depth=prompt_depth, return_tensors="pt") + pixel_values = outputs.pixel_values + prompt_depth_values = outputs.prompt_depth + + self.assertEqual(list(pixel_values.shape), [1, 3, 768, 1024]) + self.assertEqual(list(prompt_depth_values.shape), [1, 1, 192, 256]) diff --git a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py index 136058d5f483..3e95670fc460 100644 --- a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py +++ b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py @@ -306,11 +306,12 @@ def test_export(self): ) image_processor = AutoImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf") image = prepare_img() - inputs = image_processor(images=image, return_tensors="pt").to(torch_device) + prompt_depth = prepare_prompt_depth() + inputs = image_processor(images=image, prompt_depth=prompt_depth, return_tensors="pt").to(torch_device) exported_program = torch.export.export( model, - args=(inputs["pixel_values"],), + args=(inputs["pixel_values"], inputs["prompt_depth"]), strict=strict, ) with torch.no_grad(): From 2aa3363b7eef98e9c07851f54aa8a2fb76ebbcb2 Mon Sep 17 00:00:00 2001 From: linhaotong Date: Fri, 17 Jan 2025 02:44:06 +0800 Subject: [PATCH 33/58] fix assertion in image_processing_prompt_depth_anything --- .../image_processing_prompt_depth_anything.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py index 8040800f75ee..75de664cc46a 100644 --- a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py @@ -436,6 +436,7 @@ def preprocess( # prompt_depth is a list of images with shape (height, width) # we need to convert it to a list of images with shape (1, height, width) prompt_depths = make_list_of_images(prompt_depth, expected_ndims=2) + assert len(prompt_depths) == len(images) # Validate prompt_depths has same length as images if len(prompt_depths) != len(images): From 17bd1688efb5a3c084473f6a60cfa94a19d3ff82 Mon Sep 17 00:00:00 2001 From: Haotong LIN Date: Wed, 22 Jan 2025 23:56:42 +0800 Subject: [PATCH 34/58] Update src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py Co-authored-by: Pavel Iakubovskii --- .../prompt_depth_anything/modular_prompt_depth_anything.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index 9d14b5fc7e1f..f86a50821807 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -237,7 +237,11 @@ class PromptDepthAnythingReassembleStage(DepthAnythingReassembleStage): class PromptDepthAnythingNeck(DepthAnythingNeck): def forward( - self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None, prompt_depth=None + self, + hidden_states: List[torch.Tensor], + patch_height: Optional[int] = None, + patch_width: Optional[int] = None, + prompt_depth: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: """ Args: From 1d7a6d05a294675744cc2fd71084dfea70cd7902 Mon Sep 17 00:00:00 2001 From: Haotong LIN Date: Wed, 22 Jan 2025 23:56:55 +0800 Subject: [PATCH 35/58] Update src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py Co-authored-by: Pavel Iakubovskii --- .../prompt_depth_anything/modular_prompt_depth_anything.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index f86a50821807..54356cf6d20d 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -129,7 +129,7 @@ def forward(self, hidden_states, size=None, prompt_depth=None): class PromptDepthAnythingDepthEstimationHead(DepthAnythingDepthEstimationHead): - def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) -> torch.Tensor: + def forward(self, hidden_states: List[torch.Tensor], patch_height: int, patch_width: int) -> torch.Tensor: hidden_states = hidden_states[-1] predicted_depth = self.conv1(hidden_states) From ab381cacebdfc1ae6ddedd96a847f7df56941878 Mon Sep 17 00:00:00 2001 From: Haotong LIN Date: Wed, 22 Jan 2025 23:57:17 +0800 Subject: [PATCH 36/58] Update src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py Co-authored-by: Pavel Iakubovskii --- .../image_processing_prompt_depth_anything.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py index 75de664cc46a..2efc1070e9d8 100644 --- a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py @@ -453,11 +453,8 @@ def preprocess( depth = depth * prompt_scale_to_meter if depth.min() == depth.max(): # Prompt depth is invalid, min and max are the same. - # We can simply randomly select one pixel and set it to a small value. - EPS = 1e-6 - random_x = np.random.randint(0, depth.shape[0]) - random_y = np.random.randint(0, depth.shape[1]) - depth[random_x, random_y] = depth[0, 0] + EPS + # We can simply select one pixel and set it to a small value. + depth[0, 0] = depth[0, 0] + 1e-6 depth = depth[..., None].astype(np.float32) depth = to_channel_dimension_format(depth, data_format, input_channel_dim=input_data_format) From a509ad1575d645d76da8cd43b487ee4905b5ad3e Mon Sep 17 00:00:00 2001 From: Haotong LIN Date: Wed, 22 Jan 2025 23:57:56 +0800 Subject: [PATCH 37/58] Update src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py Co-authored-by: Pavel Iakubovskii --- .../image_processing_prompt_depth_anything.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py index 2efc1070e9d8..422015f21771 100644 --- a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py @@ -436,7 +436,6 @@ def preprocess( # prompt_depth is a list of images with shape (height, width) # we need to convert it to a list of images with shape (1, height, width) prompt_depths = make_list_of_images(prompt_depth, expected_ndims=2) - assert len(prompt_depths) == len(images) # Validate prompt_depths has same length as images if len(prompt_depths) != len(images): From 188b88d5bc9a8d52bbbd6d34037a61082bee1571 Mon Sep 17 00:00:00 2001 From: Haotong LIN Date: Wed, 22 Jan 2025 23:58:11 +0800 Subject: [PATCH 38/58] Update src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py Co-authored-by: Pavel Iakubovskii --- .../image_processing_prompt_depth_anything.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py index 422015f21771..acf908792bcd 100644 --- a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py @@ -316,7 +316,7 @@ def preprocess( do_resize (`bool`, *optional*, defaults to `self.do_resize`): Whether to resize the image. size (`Dict[str, int]`, *optional*, defaults to `self.size`): - Size of the image after reszing. If `keep_aspect_ratio` is `True`, the image is resized to the largest + Size of the image after resizing. If `keep_aspect_ratio` is `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is set, the image is resized to a size that is a multiple of this value. keep_aspect_ratio (`bool`, *optional*, defaults to `self.keep_aspect_ratio`): From 9d48d97c223d409b5ddfc06fd2b440187fcc0dd9 Mon Sep 17 00:00:00 2001 From: Haotong LIN Date: Wed, 22 Jan 2025 23:58:31 +0800 Subject: [PATCH 39/58] Update docs/source/en/model_doc/prompt_depth_anything.md Co-authored-by: Pavel Iakubovskii --- docs/source/en/model_doc/prompt_depth_anything.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/en/model_doc/prompt_depth_anything.md b/docs/source/en/model_doc/prompt_depth_anything.md index d1fc02009044..8cd663591274 100644 --- a/docs/source/en/model_doc/prompt_depth_anything.md +++ b/docs/source/en/model_doc/prompt_depth_anything.md @@ -35,11 +35,12 @@ alt="drawing" width="600"/> The Transformers library allows you to use the model with just a few lines of code: ```python ->>> from transformers import AutoImageProcessor, AutoModelForDepthEstimation >>> import torch +>>> import requests >>> import numpy as np + >>> from PIL import Image ->>> import requests +>>> from transformers import AutoImageProcessor, AutoModelForDepthEstimation >>> url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true" >>> image = Image.open(requests.get(url, stream=True).raw) From c033e6cf2b9f8cb72c9da3e01d3ce2f821bac07e Mon Sep 17 00:00:00 2001 From: Haotong LIN Date: Thu, 23 Jan 2025 01:27:23 +0800 Subject: [PATCH 40/58] Update docs/source/en/model_doc/prompt_depth_anything.md Co-authored-by: Pavel Iakubovskii --- docs/source/en/model_doc/prompt_depth_anything.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/prompt_depth_anything.md b/docs/source/en/model_doc/prompt_depth_anything.md index 8cd663591274..f2efb8aa4811 100644 --- a/docs/source/en/model_doc/prompt_depth_anything.md +++ b/docs/source/en/model_doc/prompt_depth_anything.md @@ -92,4 +92,5 @@ If you're interested in submitting a resource to be included here, please feel f ## PromptDepthAnythingImageProcessor [[autodoc]] PromptDepthAnythingImageProcessor - - preprocess \ No newline at end of file + - preprocess + - post_process_depth_estimation \ No newline at end of file From c2693f8ff4adefb398092f4d318ea465b4d9dab2 Mon Sep 17 00:00:00 2001 From: linhaotong Date: Mon, 27 Jan 2025 00:01:47 +0800 Subject: [PATCH 41/58] update some testing --- .../modeling_prompt_depth_anything.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index 2b4a6ea0ef7f..033edb2d1c1c 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -207,7 +207,7 @@ def __init__(self, config): raise ValueError(f"Unknown depth estimation type: {config.depth_estimation_type}") self.max_depth = config.max_depth - def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) -> torch.Tensor: + def forward(self, hidden_states: List[torch.Tensor], patch_height: int, patch_width: int) -> torch.Tensor: hidden_states = hidden_states[-1] predicted_depth = self.conv1(hidden_states) @@ -345,7 +345,11 @@ def __init__(self, config): self.fusion_stage = PromptDepthAnythingFeatureFusionStage(config) def forward( - self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None, prompt_depth=None + self, + hidden_states: List[torch.Tensor], + patch_height: Optional[int] = None, + patch_width: Optional[int] = None, + prompt_depth: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: """ Args: From d957f56424f2012bcc9ef543562cc26607f780fe Mon Sep 17 00:00:00 2001 From: linhaotong Date: Mon, 27 Jan 2025 00:39:00 +0800 Subject: [PATCH 42/58] fix testing --- .../prompt_depth_anything/test_modeling_prompt_depth_anything.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py index 3e95670fc460..aff7f4b1196f 100644 --- a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py +++ b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py @@ -293,6 +293,7 @@ def test_inference(self): self.assertTrue(torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=1e-3)) def test_export(self): + for strict in [True, False]: with self.subTest(strict=strict): if not is_torch_greater_or_equal_than_2_4: From b34e35a1cb0e53a8eefee085bff9f7935d7f74a0 Mon Sep 17 00:00:00 2001 From: linhaotong Date: Mon, 27 Jan 2025 01:42:30 +0800 Subject: [PATCH 43/58] fix --- .../prompt_depth_anything/test_modeling_prompt_depth_anything.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py index aff7f4b1196f..3e95670fc460 100644 --- a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py +++ b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py @@ -293,7 +293,6 @@ def test_inference(self): self.assertTrue(torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=1e-3)) def test_export(self): - for strict in [True, False]: with self.subTest(strict=strict): if not is_torch_greater_or_equal_than_2_4: From 27e2e494c5a770b1b4056a2d7bfb586c944d1b97 Mon Sep 17 00:00:00 2001 From: linhaotong Date: Mon, 3 Feb 2025 23:14:40 +0800 Subject: [PATCH 44/58] add return doc for forward of prompt depth anything --- .../prompt_depth_anything/modeling_prompt_depth_anything.py | 1 + .../prompt_depth_anything/modular_prompt_depth_anything.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index 033edb2d1c1c..3d67fcffd7fc 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -441,6 +441,7 @@ def forward( Ground truth depth estimation maps for computing the loss. Returns: + DepthEstimatorOutput: A DepthEstimatorOutput containing the depth prediction. Examples: diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index 54356cf6d20d..85574838f354 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -288,6 +288,7 @@ def forward( Ground truth depth estimation maps for computing the loss. Returns: + DepthEstimatorOutput: A DepthEstimatorOutput containing the depth prediction. Examples: ```python From cc9c0b759a422b785091331a1a1fc7fddc671b8e Mon Sep 17 00:00:00 2001 From: Haotong LIN Date: Mon, 3 Feb 2025 23:26:30 +0800 Subject: [PATCH 45/58] Update src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py Co-authored-by: Pavel Iakubovskii --- .../prompt_depth_anything/modular_prompt_depth_anything.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index 85574838f354..97d0f0e8a5f7 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -277,10 +277,10 @@ class PromptDepthAnythingForDepthEstimation(DepthAnythingForDepthEstimation): def forward( self, pixel_values: torch.FloatTensor, + prompt_depth: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - prompt_depth: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]: r""" From e8ac52681389bf0f2c0393df2e2c4db01746360a Mon Sep 17 00:00:00 2001 From: Haotong LIN Date: Mon, 3 Feb 2025 23:26:55 +0800 Subject: [PATCH 46/58] Update tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py Co-authored-by: Pavel Iakubovskii --- .../test_modeling_prompt_depth_anything.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py index 3e95670fc460..c1e9111cadda 100644 --- a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py +++ b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py @@ -316,7 +316,7 @@ def test_export(self): ) with torch.no_grad(): eager_outputs = model(**inputs) - exported_outputs = exported_program.module().forward(inputs["pixel_values"]) + exported_outputs = exported_program.module().forward(inputs["pixel_values"], inputs["prompt_depth"]) self.assertEqual(eager_outputs.predicted_depth.shape, exported_outputs.predicted_depth.shape) self.assertTrue( torch.allclose(eager_outputs.predicted_depth, exported_outputs.predicted_depth, atol=1e-4) From 0ffa387060ffb3eafd004e7d4118380e727bd6ab Mon Sep 17 00:00:00 2001 From: linhaotong Date: Mon, 3 Feb 2025 23:28:51 +0800 Subject: [PATCH 47/58] fix prompt depth order --- .../prompt_depth_anything/modeling_prompt_depth_anything.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index 3d67fcffd7fc..33612400f299 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -424,10 +424,10 @@ def __init__(self, config): def forward( self, pixel_values: torch.FloatTensor, + prompt_depth: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - prompt_depth: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]: r""" From 51acb23db2fc28f2e95799c8197659cf7b6bc395 Mon Sep 17 00:00:00 2001 From: linhaotong Date: Mon, 3 Feb 2025 23:32:27 +0800 Subject: [PATCH 48/58] fix format for testing prompt depth anything --- .../test_modeling_prompt_depth_anything.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py index c1e9111cadda..77cb96ccea02 100644 --- a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py +++ b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py @@ -316,7 +316,9 @@ def test_export(self): ) with torch.no_grad(): eager_outputs = model(**inputs) - exported_outputs = exported_program.module().forward(inputs["pixel_values"], inputs["prompt_depth"]) + exported_outputs = exported_program.module().forward( + inputs["pixel_values"], inputs["prompt_depth"] + ) self.assertEqual(eager_outputs.predicted_depth.shape, exported_outputs.predicted_depth.shape) self.assertTrue( torch.allclose(eager_outputs.predicted_depth, exported_outputs.predicted_depth, atol=1e-4) From 149225b8844b58a5c184d5b1b76261304c680623 Mon Sep 17 00:00:00 2001 From: linhaotong Date: Tue, 4 Feb 2025 00:02:53 +0800 Subject: [PATCH 49/58] fix minor issues in prompt depth anything doc --- docs/source/en/model_doc/prompt_depth_anything.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/prompt_depth_anything.md b/docs/source/en/model_doc/prompt_depth_anything.md index f2efb8aa4811..910298fa8c71 100644 --- a/docs/source/en/model_doc/prompt_depth_anything.md +++ b/docs/source/en/model_doc/prompt_depth_anything.md @@ -78,7 +78,7 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h - [Prompt Depth Anything Demo](https://huggingface.co/spaces/depth-anything/PromptDA) - [Prompt Depth Anything Interactive Results](https://promptda.github.io/interactive.html) -If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource. +If you are interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource. ## PromptDepthAnythingConfig From f0942579d4a35feb38efeb59191977584f9ddd4a Mon Sep 17 00:00:00 2001 From: linhaotong Date: Tue, 4 Feb 2025 00:20:27 +0800 Subject: [PATCH 50/58] fix format for modular prompt depth anything --- .../modeling_prompt_depth_anything.py | 16 +++++++++------- .../modular_prompt_depth_anything.py | 16 +++++++++------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index 33612400f299..c8458cf4f69c 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -522,14 +522,16 @@ def forward( output = (predicted_depth,) + outputs[1:] else: output = (predicted_depth,) + outputs[2:] - return ((loss,) + output) if loss is not None else output + ret = ((loss,) + output) if loss is not None else output + else: + ret = DepthEstimatorOutput( + loss=loss, + predicted_depth=predicted_depth, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) - return DepthEstimatorOutput( - loss=loss, - predicted_depth=predicted_depth, - hidden_states=outputs.hidden_states if output_hidden_states else None, - attentions=outputs.attentions, - ) + return ret __all__ = ["PromptDepthAnythingForDepthEstimation", "PromptDepthAnythingPreTrainedModel"] diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index 97d0f0e8a5f7..eacbf399adb7 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -368,14 +368,16 @@ def forward( output = (predicted_depth,) + outputs[1:] else: output = (predicted_depth,) + outputs[2:] - return ((loss,) + output) if loss is not None else output + ret = ((loss,) + output) if loss is not None else output + else: + ret = DepthEstimatorOutput( + loss=loss, + predicted_depth=predicted_depth, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) - return DepthEstimatorOutput( - loss=loss, - predicted_depth=predicted_depth, - hidden_states=outputs.hidden_states if output_hidden_states else None, - attentions=outputs.attentions, - ) + return ret __all__ = [ From a6089ff7db8dfa6262db5d424ebe8ad5549e52cc Mon Sep 17 00:00:00 2001 From: linhaotong Date: Tue, 4 Feb 2025 00:53:08 +0800 Subject: [PATCH 51/58] revert format for modular prompt depth anything --- .../modeling_prompt_depth_anything.py | 17 +++++++---------- .../modular_prompt_depth_anything.py | 18 ++++++++---------- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index c8458cf4f69c..5d9c3b313dc3 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -18,7 +18,6 @@ from .configuration_prompt_depth_anything import PromptDepthAnythingConfig -# General docstring _CONFIG_FOR_DOC = "PromptDepthAnythingConfig" @@ -522,16 +521,14 @@ def forward( output = (predicted_depth,) + outputs[1:] else: output = (predicted_depth,) + outputs[2:] - ret = ((loss,) + output) if loss is not None else output - else: - ret = DepthEstimatorOutput( - loss=loss, - predicted_depth=predicted_depth, - hidden_states=outputs.hidden_states if output_hidden_states else None, - attentions=outputs.attentions, - ) + return ((loss,) + output) if loss is not None else output - return ret + return DepthEstimatorOutput( + loss=loss, + predicted_depth=predicted_depth, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) __all__ = ["PromptDepthAnythingForDepthEstimation", "PromptDepthAnythingPreTrainedModel"] diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index eacbf399adb7..211079e7cceb 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -368,16 +368,14 @@ def forward( output = (predicted_depth,) + outputs[1:] else: output = (predicted_depth,) + outputs[2:] - ret = ((loss,) + output) if loss is not None else output - else: - ret = DepthEstimatorOutput( - loss=loss, - predicted_depth=predicted_depth, - hidden_states=outputs.hidden_states if output_hidden_states else None, - attentions=outputs.attentions, - ) - - return ret + return ((loss,) + output) if loss is not None else output + + return DepthEstimatorOutput( + loss=loss, + predicted_depth=predicted_depth, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) __all__ = [ From 2caf27ca6ee5e40f94f69a4fcddb432f225db3a8 Mon Sep 17 00:00:00 2001 From: linhaotong Date: Tue, 4 Feb 2025 00:58:18 +0800 Subject: [PATCH 52/58] revert format for modular prompt depth anything --- .../prompt_depth_anything/modeling_prompt_depth_anything.py | 1 + .../prompt_depth_anything/modular_prompt_depth_anything.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index 5d9c3b313dc3..33612400f299 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -18,6 +18,7 @@ from .configuration_prompt_depth_anything import PromptDepthAnythingConfig +# General docstring _CONFIG_FOR_DOC = "PromptDepthAnythingConfig" diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index 211079e7cceb..97d0f0e8a5f7 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -369,7 +369,7 @@ def forward( else: output = (predicted_depth,) + outputs[2:] return ((loss,) + output) if loss is not None else output - + return DepthEstimatorOutput( loss=loss, predicted_depth=predicted_depth, From fca03fb0cd97d462780429065015a754a052a93e Mon Sep 17 00:00:00 2001 From: linhaotong Date: Tue, 4 Feb 2025 01:21:04 +0800 Subject: [PATCH 53/58] update format for modular prompt depth anything --- .../prompt_depth_anything/modeling_prompt_depth_anything.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index 33612400f299..5d9c3b313dc3 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -18,7 +18,6 @@ from .configuration_prompt_depth_anything import PromptDepthAnythingConfig -# General docstring _CONFIG_FOR_DOC = "PromptDepthAnythingConfig" From f16f9b486311701a9efad86cca9d222094b0948c Mon Sep 17 00:00:00 2001 From: linhaotong Date: Tue, 4 Feb 2025 02:03:52 +0800 Subject: [PATCH 54/58] fix parallel testing errors --- .../prompt_depth_anything/modeling_prompt_depth_anything.py | 3 ++- .../prompt_depth_anything/modular_prompt_depth_anything.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index 5d9c3b313dc3..5430a08bcd65 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -512,7 +512,8 @@ def forward( predicted_depth = self.head(hidden_states, patch_height, patch_width) if prompt_depth is not None: # denormalize predicted depth - depth_min, depth_max = depth_min.squeeze(1), depth_max.squeeze(1) + depth_min = depth_min.squeeze(1).to(predicted_depth.device) + depth_max = depth_max.squeeze(1).to(predicted_depth.device) predicted_depth = predicted_depth * (depth_max - depth_min) + depth_min # denormalize done diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index 97d0f0e8a5f7..6e37f4ed8a33 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -359,7 +359,8 @@ def forward( predicted_depth = self.head(hidden_states, patch_height, patch_width) if prompt_depth is not None: # denormalize predicted depth - depth_min, depth_max = depth_min.squeeze(1), depth_max.squeeze(1) + depth_min = depth_min.squeeze(1).to(predicted_depth.device) + depth_max = depth_max.squeeze(1).to(predicted_depth.device) predicted_depth = predicted_depth * (depth_max - depth_min) + depth_min # denormalize done From 47178a6d4cef37abc35681949ef3534ff36a7656 Mon Sep 17 00:00:00 2001 From: linhaotong Date: Tue, 4 Feb 2025 16:50:07 +0800 Subject: [PATCH 55/58] fix doc for prompt depth anything --- .../modeling_prompt_depth_anything.py | 7 ------- .../prompt_depth_anything/modular_prompt_depth_anything.py | 7 ------- 2 files changed, 14 deletions(-) diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index 5430a08bcd65..e0cff66521ff 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -435,13 +435,6 @@ def forward( Returns: - Examples: - labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): - Ground truth depth estimation maps for computing the loss. - - Returns: - DepthEstimatorOutput: A DepthEstimatorOutput containing the depth prediction. - Examples: ```python diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index 6e37f4ed8a33..da4a05735ad9 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -284,13 +284,6 @@ def forward( return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]: r""" - labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): - Ground truth depth estimation maps for computing the loss. - - Returns: - DepthEstimatorOutput: A DepthEstimatorOutput containing the depth prediction. - - Examples: ```python >>> from transformers import AutoImageProcessor, AutoModelForDepthEstimation >>> import torch From fe7ec93c20cb2f439e86f671b2db19f61a0a09d2 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 20 Mar 2025 15:53:10 +0000 Subject: [PATCH 56/58] Add header --- .../modular_prompt_depth_anything.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index da4a05735ad9..ad9b254a8a26 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -1,3 +1,15 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import List, Optional, Tuple, Union import torch From 95abf2604655e8f300c5ff491fe6149e7629b516 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 20 Mar 2025 15:54:04 +0000 Subject: [PATCH 57/58] Fix imports --- .../image_processing_prompt_depth_anything.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py index acf908792bcd..b4fdea0d4c11 100644 --- a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py @@ -41,7 +41,6 @@ from ...utils import ( TensorType, filter_out_non_signature_kwargs, - is_vision_available, logging, requires_backends, ) @@ -50,9 +49,6 @@ if is_torch_available(): import torch -if is_vision_available(): - pass - logger = logging.get_logger(__name__) From e7330fcdc1cd68ca7e0ceb235233f6b54b43419d Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 20 Mar 2025 15:56:14 +0000 Subject: [PATCH 58/58] Licence header --- .../configuration_prompt_depth_anything.py | 12 ++++++++++++ .../modeling_prompt_depth_anything.py | 18 ++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py index 4852afb9c84f..cf213133c142 100644 --- a/src/transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py @@ -4,6 +4,18 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_prompt_depth_anything.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import copy diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index e0cff66521ff..7653a72f0695 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -4,6 +4,18 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_prompt_depth_anything.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import List, Optional, Tuple, Union import torch @@ -394,6 +406,12 @@ def forward( output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + prompt_depth (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*): + Prompt depth is the sparse or low-resolution depth obtained from multi-view geometry or a + low-resolution depth sensor. It generally has shape (height, width), where height + and width can be smaller than those of the images. It is optional and can be None, which means no prompt depth + will be used. If it is None, the output will be a monocular relative depth. + The values are recommended to be in meters, but this is not necessary. return_dict (`bool`, *optional*): Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. """