diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 7557f295d968..2fd44c1ae5d1 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -653,6 +653,8 @@
title: Depth Anything
- local: model_doc/depth_anything_v2
title: Depth Anything V2
+ - local: model_doc/depth_pro
+ title: DepthPro
- local: model_doc/deta
title: DETA
- local: model_doc/detr
diff --git a/docs/source/en/index.md b/docs/source/en/index.md
index 9c3c5c76954d..3ac18788c8fe 100644
--- a/docs/source/en/index.md
+++ b/docs/source/en/index.md
@@ -123,6 +123,7 @@ Flax), PyTorch, and/or TensorFlow.
| [DeiT](model_doc/deit) | ✅ | ✅ | ❌ |
| [DePlot](model_doc/deplot) | ✅ | ❌ | ❌ |
| [Depth Anything](model_doc/depth_anything) | ✅ | ❌ | ❌ |
+| [DepthPro](model_doc/depth_pro) | ✅ | ❌ | ❌ |
| [DETA](model_doc/deta) | ✅ | ❌ | ❌ |
| [DETR](model_doc/detr) | ✅ | ❌ | ❌ |
| [DialoGPT](model_doc/dialogpt) | ✅ | ✅ | ✅ |
diff --git a/docs/source/en/model_doc/depth_pro.md b/docs/source/en/model_doc/depth_pro.md
new file mode 100644
index 000000000000..2447b7d93dd5
--- /dev/null
+++ b/docs/source/en/model_doc/depth_pro.md
@@ -0,0 +1,183 @@
+
+
+# DepthPro
+
+## Overview
+
+The DepthPro model was proposed in [Depth Pro: Sharp Monocular Metric Depth in Less Than a Second](https://arxiv.org/abs/2410.02073) by Aleksei Bochkovskii, Amaël Delaunoy, Hugo Germain, Marcel Santos, Yichao Zhou, Stephan R. Richter, Vladlen Koltun.
+
+DepthPro is a foundation model for zero-shot metric monocular depth estimation, designed to generate high-resolution depth maps with remarkable sharpness and fine-grained details. It employs a multi-scale Vision Transformer (ViT)-based architecture, where images are downsampled, divided into patches, and processed using a shared Dinov2 encoder. The extracted patch-level features are merged, upsampled, and refined using a DPT-like fusion stage, enabling precise depth estimation.
+
+The abstract from the paper is the following:
+
+*We present a foundation model for zero-shot metric monocular depth estimation. Our model, Depth Pro, synthesizes high-resolution depth maps with unparalleled sharpness and high-frequency details. The predictions are metric, with absolute scale, without relying on the availability of metadata such as camera intrinsics. And the model is fast, producing a 2.25-megapixel depth map in 0.3 seconds on a standard GPU. These characteristics are enabled by a number of technical contributions, including an efficient multi-scale vision transformer for dense prediction, a training protocol that combines real and synthetic datasets to achieve high metric accuracy alongside fine boundary tracing, dedicated evaluation metrics for boundary accuracy in estimated depth maps, and state-of-the-art focal length estimation from a single image. Extensive experiments analyze specific design choices and demonstrate that Depth Pro outperforms prior work along multiple dimensions.*
+
+
+
+ DepthPro Outputs. Taken from the official code.
+
+This model was contributed by [geetu040](https://github.com/geetu040). The original code can be found [here](https://github.com/apple/ml-depth-pro).
+
+## Usage Tips
+
+The DepthPro model processes an input image by first downsampling it at multiple scales and splitting each scaled version into patches. These patches are then encoded using a shared Vision Transformer (ViT)-based Dinov2 patch encoder, while the full image is processed by a separate image encoder. The extracted patch features are merged into feature maps, upsampled, and fused using a DPT-like decoder to generate the final depth estimation. If enabled, an additional Field of View (FOV) encoder processes the image for estimating the camera's field of view, aiding in depth accuracy.
+
+```py
+>>> import requests
+>>> from PIL import Image
+>>> import torch
+>>> from transformers import DepthProImageProcessorFast, DepthProForDepthEstimation
+
+>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
+>>> image = Image.open(requests.get(url, stream=True).raw)
+
+>>> image_processor = DepthProImageProcessorFast.from_pretrained("apple/DepthPro-hf")
+>>> model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf").to(device)
+
+>>> inputs = image_processor(images=image, return_tensors="pt").to(device)
+
+>>> with torch.no_grad():
+... outputs = model(**inputs)
+
+>>> post_processed_output = image_processor.post_process_depth_estimation(
+... outputs, target_sizes=[(image.height, image.width)],
+... )
+
+>>> field_of_view = post_processed_output[0]["field_of_view"]
+>>> focal_length = post_processed_output[0]["focal_length"]
+>>> depth = post_processed_output[0]["predicted_depth"]
+>>> depth = (depth - depth.min()) / depth.max()
+>>> depth = depth * 255.
+>>> depth = depth.detach().cpu().numpy()
+>>> depth = Image.fromarray(depth.astype("uint8"))
+```
+
+### Architecture and Configuration
+
+
+
+ DepthPro architecture. Taken from the original paper.
+
+The `DepthProForDepthEstimation` model uses a `DepthProEncoder`, for encoding the input image and a `FeatureFusionStage` for fusing the output features from encoder.
+
+The `DepthProEncoder` further uses two encoders:
+- `patch_encoder`
+ - Input image is scaled with multiple ratios, as specified in the `scaled_images_ratios` configuration.
+ - Each scaled image is split into smaller **patches** of size `patch_size` with overlapping areas determined by `scaled_images_overlap_ratios`.
+ - These patches are processed by the **`patch_encoder`**
+- `image_encoder`
+ - Input image is also rescaled to `patch_size` and processed by the **`image_encoder`**
+
+Both these encoders can be configured via `patch_model_config` and `image_model_config` respectively, both of which are seperate `Dinov2Model` by default.
+
+Outputs from both encoders (`last_hidden_state`) and selected intermediate states (`hidden_states`) from **`patch_encoder`** are fused by a `DPT`-based `FeatureFusionStage` for depth estimation.
+
+### Field-of-View (FOV) Prediction
+
+The network is supplemented with a focal length estimation head. A small convolutional head ingests frozen features from the depth estimation network and task-specific features from a separate ViT image encoder to predict the horizontal angular field-of-view.
+
+The `use_fov_model` parameter in `DepthProConfig` controls whether **FOV prediction** is enabled. By default, it is set to `False` to conserve memory and computation. When enabled, the **FOV encoder** is instantiated based on the `fov_model_config` parameter, which defaults to a `Dinov2Model`. The `use_fov_model` parameter can also be passed when initializing the `DepthProForDepthEstimation` model.
+
+The pretrained model at checkpoint `apple/DepthPro-hf` uses the FOV encoder. To use the pretrained-model without FOV encoder, set `use_fov_model=False` when loading the model, which saves computation.
+```py
+>>> from transformers import DepthProForDepthEstimation
+>>> model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf", use_fov_model=False)
+```
+
+To instantiate a new model with FOV encoder, set `use_fov_model=True` in the config.
+```py
+>>> from transformers import DepthProConfig, DepthProForDepthEstimation
+>>> config = DepthProConfig(use_fov_model=True)
+>>> model = DepthProForDepthEstimation(config)
+```
+
+Or set `use_fov_model=True` when initializing the model, which overrides the value in config.
+```py
+>>> from transformers import DepthProConfig, DepthProForDepthEstimation
+>>> config = DepthProConfig()
+>>> model = DepthProForDepthEstimation(config, use_fov_model=True)
+```
+
+### Using Scaled Dot Product Attention (SDPA)
+
+PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
+encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
+[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
+or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
+page for more information.
+
+SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
+`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
+
+```py
+from transformers import DepthProForDepthEstimation
+model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf", attn_implementation="sdpa", torch_dtype=torch.float16)
+```
+
+For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
+
+On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `google/vit-base-patch16-224` model, we saw the following speedups during inference.
+
+| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
+|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
+| 1 | 7 | 6 | 1.17 |
+| 2 | 8 | 6 | 1.33 |
+| 4 | 8 | 6 | 1.33 |
+| 8 | 8 | 6 | 1.33 |
+
+## Resources
+
+A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with DepthPro:
+
+- Research Paper: [Depth Pro: Sharp Monocular Metric Depth in Less Than a Second](https://arxiv.org/pdf/2410.02073)
+- Official Implementation: [apple/ml-depth-pro](https://github.com/apple/ml-depth-pro)
+- DepthPro Inference Notebook: [DepthPro Inference](https://github.com/qubvel/transformers-notebooks/blob/main/notebooks/DepthPro_inference.ipynb)
+- DepthPro for Super Resolution and Image Segmentation
+ - Read blog on Medium: [Depth Pro: Beyond Depth](https://medium.com/@raoarmaghanshakir040/depth-pro-beyond-depth-9d822fc557ba)
+ - Code on Github: [geetu040/depthpro-beyond-depth](https://github.com/geetu040/depthpro-beyond-depth)
+
+If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
+
+## DepthProConfig
+
+[[autodoc]] DepthProConfig
+
+## DepthProImageProcessor
+
+[[autodoc]] DepthProImageProcessor
+ - preprocess
+ - post_process_depth_estimation
+
+## DepthProImageProcessorFast
+
+[[autodoc]] DepthProImageProcessorFast
+ - preprocess
+ - post_process_depth_estimation
+
+## DepthProModel
+
+[[autodoc]] DepthProModel
+ - forward
+
+## DepthProForDepthEstimation
+
+[[autodoc]] DepthProForDepthEstimation
+ - forward
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index 5ffdefec9e84..06143c757000 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -244,6 +244,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [data2vec_vision](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecVisionModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
+* [DepthPro](https://huggingface.co/docs/transformers/model_doc/depth_pro#transformers.DepthProModel)
* [DiffLlama](https://huggingface.co/docs/transformers/model_doc/diffllama#transformers.DiffLlamaModel)
* [Dinov2](https://huggingface.co/docs/transformers/en/model_doc/dinov2)
* [Dinov2_with_registers](https://huggingface.co/docs/transformers/en/model_doc/dinov2)
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index ec6805f504b0..d9db0a0fd6e3 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -400,6 +400,7 @@
"models.deprecated.vit_hybrid": ["ViTHybridConfig"],
"models.deprecated.xlm_prophetnet": ["XLMProphetNetConfig"],
"models.depth_anything": ["DepthAnythingConfig"],
+ "models.depth_pro": ["DepthProConfig"],
"models.detr": ["DetrConfig"],
"models.dialogpt": [],
"models.diffllama": ["DiffLlamaConfig"],
@@ -1236,6 +1237,7 @@
_import_structure["models.deprecated.efficientformer"].append("EfficientFormerImageProcessor")
_import_structure["models.deprecated.tvlt"].append("TvltImageProcessor")
_import_structure["models.deprecated.vit_hybrid"].extend(["ViTHybridImageProcessor"])
+ _import_structure["models.depth_pro"].extend(["DepthProImageProcessor", "DepthProImageProcessorFast"])
_import_structure["models.detr"].extend(["DetrFeatureExtractor", "DetrImageProcessor"])
_import_structure["models.donut"].extend(["DonutFeatureExtractor", "DonutImageProcessor"])
_import_structure["models.dpt"].extend(["DPTFeatureExtractor", "DPTImageProcessor"])
@@ -1313,6 +1315,7 @@
_import_structure["models.convnext"].append("ConvNextImageProcessorFast")
_import_structure["models.deformable_detr"].append("DeformableDetrImageProcessorFast")
_import_structure["models.deit"].append("DeiTImageProcessorFast")
+ _import_structure["models.depth_pro"].append("DepthProImageProcessorFast")
_import_structure["models.detr"].append("DetrImageProcessorFast")
_import_structure["models.llava"].append("LlavaImageProcessorFast")
_import_structure["models.llava_next"].append("LlavaNextImageProcessorFast")
@@ -2180,6 +2183,13 @@
"DepthAnythingPreTrainedModel",
]
)
+ _import_structure["models.depth_pro"].extend(
+ [
+ "DepthProForDepthEstimation",
+ "DepthProModel",
+ "DepthProPreTrainedModel",
+ ]
+ )
_import_structure["models.detr"].extend(
[
"DetrForObjectDetection",
@@ -5494,6 +5504,7 @@
XLMProphetNetConfig,
)
from .models.depth_anything import DepthAnythingConfig
+ from .models.depth_pro import DepthProConfig
from .models.detr import DetrConfig
from .models.diffllama import DiffLlamaConfig
from .models.dinat import DinatConfig
@@ -6362,6 +6373,7 @@
from .models.deprecated.efficientformer import EfficientFormerImageProcessor
from .models.deprecated.tvlt import TvltImageProcessor
from .models.deprecated.vit_hybrid import ViTHybridImageProcessor
+ from .models.depth_pro import DepthProImageProcessor, DepthProImageProcessorFast
from .models.detr import DetrFeatureExtractor, DetrImageProcessor
from .models.donut import DonutFeatureExtractor, DonutImageProcessor
from .models.dpt import DPTFeatureExtractor, DPTImageProcessor
@@ -6455,6 +6467,7 @@
from .models.convnext import ConvNextImageProcessorFast
from .models.deformable_detr import DeformableDetrImageProcessorFast
from .models.deit import DeiTImageProcessorFast
+ from .models.depth_pro import DepthProImageProcessorFast
from .models.detr import DetrImageProcessorFast
from .models.llava import LlavaImageProcessorFast
from .models.llava_next import LlavaNextImageProcessorFast
@@ -7173,6 +7186,11 @@
DepthAnythingForDepthEstimation,
DepthAnythingPreTrainedModel,
)
+ from .models.depth_pro import (
+ DepthProForDepthEstimation,
+ DepthProModel,
+ DepthProPreTrainedModel,
+ )
from .models.detr import (
DetrForObjectDetection,
DetrForSegmentation,
diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py
index cb7d1c46aa79..d21d35212144 100644
--- a/src/transformers/image_processing_utils_fast.py
+++ b/src/transformers/image_processing_utils_fast.py
@@ -283,6 +283,7 @@ def resize(
image: "torch.Tensor",
size: SizeDict,
interpolation: "F.InterpolationMode" = None,
+ antialias: bool = True,
**kwargs,
) -> "torch.Tensor":
"""
@@ -324,7 +325,7 @@ def resize(
"Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
f" {size}."
)
- return F.resize(image, new_size, interpolation=interpolation)
+ return F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
def rescale(
self,
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 1667edbf37ab..97b58a238b89 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -74,6 +74,7 @@
deit,
deprecated,
depth_anything,
+ depth_pro,
detr,
dialogpt,
diffllama,
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index bd6fcb4a9d82..1f7751b31dc0 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -91,6 +91,7 @@
("deformable_detr", "DeformableDetrConfig"),
("deit", "DeiTConfig"),
("depth_anything", "DepthAnythingConfig"),
+ ("depth_pro", "DepthProConfig"),
("deta", "DetaConfig"),
("detr", "DetrConfig"),
("diffllama", "DiffLlamaConfig"),
@@ -414,6 +415,7 @@
("deplot", "DePlot"),
("depth_anything", "Depth Anything"),
("depth_anything_v2", "Depth Anything V2"),
+ ("depth_pro", "DepthPro"),
("deta", "DETA"),
("detr", "DETR"),
("dialogpt", "DialoGPT"),
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 23f2021532a8..724137bd62cd 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -74,6 +74,7 @@
("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")),
("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")),
("depth_anything", ("DPTImageProcessor",)),
+ ("depth_pro", ("DepthProImageProcessor", "DepthProImageProcessorFast")),
("deta", ("DetaImageProcessor",)),
("detr", ("DetrImageProcessor", "DetrImageProcessorFast")),
("dinat", ("ViTImageProcessor", "ViTImageProcessorFast")),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 87e2dab68708..c50232d85e68 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -89,6 +89,7 @@
("decision_transformer", "DecisionTransformerModel"),
("deformable_detr", "DeformableDetrModel"),
("deit", "DeiTModel"),
+ ("depth_pro", "DepthProModel"),
("deta", "DetaModel"),
("detr", "DetrModel"),
("diffllama", "DiffLlamaModel"),
@@ -597,6 +598,7 @@
("data2vec-vision", "Data2VecVisionModel"),
("deformable_detr", "DeformableDetrModel"),
("deit", "DeiTModel"),
+ ("depth_pro", "DepthProModel"),
("deta", "DetaModel"),
("detr", "DetrModel"),
("dinat", "DinatModel"),
@@ -916,6 +918,7 @@
[
# Model for depth estimation mapping
("depth_anything", "DepthAnythingForDepthEstimation"),
+ ("depth_pro", "DepthProForDepthEstimation"),
("dpt", "DPTForDepthEstimation"),
("glpn", "GLPNForDepthEstimation"),
("zoedepth", "ZoeDepthForDepthEstimation"),
diff --git a/src/transformers/models/depth_pro/__init__.py b/src/transformers/models/depth_pro/__init__.py
new file mode 100644
index 000000000000..5968aae67b52
--- /dev/null
+++ b/src/transformers/models/depth_pro/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_depth_pro import *
+ from .image_processing_depth_pro import *
+ from .image_processing_depth_pro_fast import *
+ from .modeling_depth_pro import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/depth_pro/configuration_depth_pro.py b/src/transformers/models/depth_pro/configuration_depth_pro.py
new file mode 100644
index 000000000000..36de741b704a
--- /dev/null
+++ b/src/transformers/models/depth_pro/configuration_depth_pro.py
@@ -0,0 +1,205 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""DepthPro model configuration"""
+
+from copy import deepcopy
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto.configuration_auto import CONFIG_MAPPING, AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class DepthProConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`DepthProModel`]. It is used to instantiate a
+ DepthPro model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the DepthPro
+ [apple/DepthPro](https://huggingface.co/apple/DepthPro) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ fusion_hidden_size (`int`, *optional*, defaults to 256):
+ The number of channels before fusion.
+ patch_size (`int`, *optional*, defaults to 384):
+ The size (resolution) of each patch. This is also the image_size for backbone model.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ intermediate_hook_ids (`List[int]`, *optional*, defaults to `[11, 5]`):
+ Indices of the intermediate hidden states from the patch encoder to use for fusion.
+ intermediate_feature_dims (`List[int]`, *optional*, defaults to `[256, 256]`):
+ Hidden state dimensions during upsampling for each intermediate hidden state in `intermediate_hook_ids`.
+ scaled_images_ratios (`List[float]`, *optional*, defaults to `[0.25, 0.5, 1]`):
+ Ratios of scaled images to be used by the patch encoder.
+ scaled_images_overlap_ratios (`List[float]`, *optional*, defaults to `[0.0, 0.5, 0.25]`):
+ Overlap ratios between patches for each scaled image in `scaled_images_ratios`.
+ scaled_images_feature_dims (`List[int]`, *optional*, defaults to `[1024, 1024, 512]`):
+ Hidden state dimensions during upsampling for each scaled image in `scaled_images_ratios`.
+ merge_padding_value (`int`, *optional*, defaults to 3):
+ When merging smaller patches back to the image size, overlapping sections of this size are removed.
+ use_batch_norm_in_fusion_residual (`bool`, *optional*, defaults to `False`):
+ Whether to use batch normalization in the pre-activate residual units of the fusion blocks.
+ use_bias_in_fusion_residual (`bool`, *optional*, defaults to `True`):
+ Whether to use bias in the pre-activate residual units of the fusion blocks.
+ use_fov_model (`bool`, *optional*, defaults to `False`):
+ Whether to use `DepthProFovModel` to generate the field of view.
+ num_fov_head_layers (`int`, *optional*, defaults to 2):
+ Number of convolution layers in the head of `DepthProFovModel`.
+ image_model_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*):
+ The configuration of the image encoder model, which is loaded using the [`AutoModel`] API.
+ By default, Dinov2 model is used as backbone.
+ patch_model_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*):
+ The configuration of the patch encoder model, which is loaded using the [`AutoModel`] API.
+ By default, Dinov2 model is used as backbone.
+ fov_model_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*):
+ The configuration of the fov encoder model, which is loaded using the [`AutoModel`] API.
+ By default, Dinov2 model is used as backbone.
+
+ Example:
+
+ ```python
+ >>> from transformers import DepthProConfig, DepthProModel
+
+ >>> # Initializing a DepthPro apple/DepthPro style configuration
+ >>> configuration = DepthProConfig()
+
+ >>> # Initializing a model (with random weights) from the apple/DepthPro style configuration
+ >>> model = DepthProModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "depth_pro"
+ sub_configs = {"image_model_config": AutoConfig, "patch_model_config": AutoConfig, "fov_model_config": AutoConfig}
+
+ def __init__(
+ self,
+ fusion_hidden_size=256,
+ patch_size=384,
+ initializer_range=0.02,
+ intermediate_hook_ids=[11, 5],
+ intermediate_feature_dims=[256, 256],
+ scaled_images_ratios=[0.25, 0.5, 1],
+ scaled_images_overlap_ratios=[0.0, 0.5, 0.25],
+ scaled_images_feature_dims=[1024, 1024, 512],
+ merge_padding_value=3,
+ use_batch_norm_in_fusion_residual=False,
+ use_bias_in_fusion_residual=True,
+ use_fov_model=False,
+ num_fov_head_layers=2,
+ image_model_config=None,
+ patch_model_config=None,
+ fov_model_config=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ # scaled_images_ratios is sorted
+ if scaled_images_ratios != sorted(scaled_images_ratios):
+ raise ValueError(
+ f"Values in scaled_images_ratios={scaled_images_ratios} " "should be sorted from low to high"
+ )
+
+ # scaled_images_ratios, scaled_images_overlap_ratios, scaled_images_feature_dims should be consistent
+ if not (len(scaled_images_ratios) == len(scaled_images_overlap_ratios) == len(scaled_images_feature_dims)):
+ raise ValueError(
+ f"len(scaled_images_ratios)={len(scaled_images_ratios)} and "
+ f"len(scaled_images_overlap_ratios)={len(scaled_images_overlap_ratios)} and "
+ f"len(scaled_images_feature_dims)={len(scaled_images_feature_dims)}, "
+ f"should match in config."
+ )
+
+ # intermediate_hook_ids, intermediate_feature_dims should be consistent
+ if not (len(intermediate_hook_ids) == len(intermediate_feature_dims)):
+ raise ValueError(
+ f"len(intermediate_hook_ids)={len(intermediate_hook_ids)} and "
+ f"len(intermediate_feature_dims)={len(intermediate_feature_dims)}, "
+ f"should match in config."
+ )
+
+ # fusion_hidden_size should be consistent with num_fov_head_layers
+ if fusion_hidden_size // 2**num_fov_head_layers == 0:
+ raise ValueError(
+ f"fusion_hidden_size={fusion_hidden_size} should be consistent with num_fov_head_layers={num_fov_head_layers} "
+ "i.e fusion_hidden_size // 2**num_fov_head_layers > 0"
+ )
+
+ self.fusion_hidden_size = fusion_hidden_size
+ self.patch_size = patch_size
+ self.initializer_range = initializer_range
+ self.use_batch_norm_in_fusion_residual = use_batch_norm_in_fusion_residual
+ self.use_bias_in_fusion_residual = use_bias_in_fusion_residual
+ self.use_fov_model = use_fov_model
+ self.num_fov_head_layers = num_fov_head_layers
+ self.intermediate_hook_ids = intermediate_hook_ids
+ self.intermediate_feature_dims = intermediate_feature_dims
+ self.scaled_images_ratios = scaled_images_ratios
+ self.scaled_images_overlap_ratios = scaled_images_overlap_ratios
+ self.scaled_images_feature_dims = scaled_images_feature_dims
+ self.merge_padding_value = merge_padding_value
+ self.image_model_config = image_model_config
+ self.patch_model_config = patch_model_config
+ self.fov_model_config = fov_model_config
+
+ for sub_config_key in self.sub_configs.keys():
+ sub_config = getattr(self, sub_config_key)
+
+ if sub_config is None:
+ sub_config = CONFIG_MAPPING["dinov2"](image_size=patch_size)
+ logger.info(
+ f"`{sub_config_key}` is `None`. Initializing `{sub_config_key}` with the `Dinov2Config` "
+ f"with default values except `{sub_config_key}.image_size` is set to `config.patch_size`."
+ )
+ elif isinstance(sub_config, dict):
+ sub_config = deepcopy(sub_config)
+ if "model_type" not in sub_config:
+ raise KeyError(
+ f"The `model_type` key is missing in the `{sub_config_key}` dictionary. Please provide the model type."
+ )
+ elif sub_config["model_type"] not in CONFIG_MAPPING:
+ raise ValueError(
+ f"The model type `{sub_config['model_type']}` in `{sub_config_key}` is not supported. Please provide a valid model type."
+ )
+ image_size = sub_config.get("image_size")
+ if image_size != patch_size:
+ logger.info(
+ f"The `image_size` in `{sub_config_key}` is set to `{image_size}`, "
+ f"but it does not match the required `patch_size` of `{patch_size}`. "
+ f"Updating `image_size` to `{patch_size}` for consistency. "
+ f"Ensure that `image_size` aligns with `patch_size` in the configuration."
+ )
+ sub_config.update({"image_size": patch_size})
+ sub_config = CONFIG_MAPPING[sub_config["model_type"]](**sub_config)
+ elif isinstance(sub_config, PretrainedConfig):
+ sub_config = sub_config
+ image_size = getattr(sub_config, "image_size", None)
+ if image_size != patch_size:
+ raise ValueError(
+ f"`config.{sub_config_key}.image_size={image_size}` should match `config.patch_size={patch_size}`."
+ )
+ else:
+ raise TypeError(
+ f"Invalid type for `sub_config`. Expected `PretrainedConfig`, `dict`, or `None`, but got {type(sub_config)}."
+ )
+
+ setattr(self, sub_config_key, sub_config)
+
+
+__all__ = ["DepthProConfig"]
diff --git a/src/transformers/models/depth_pro/convert_depth_pro_weights_to_hf.py b/src/transformers/models/depth_pro/convert_depth_pro_weights_to_hf.py
new file mode 100644
index 000000000000..b24c6a5174f0
--- /dev/null
+++ b/src/transformers/models/depth_pro/convert_depth_pro_weights_to_hf.py
@@ -0,0 +1,254 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import gc
+import os
+
+import regex as re
+import torch
+from huggingface_hub import hf_hub_download
+
+from transformers import (
+ DepthProConfig,
+ DepthProForDepthEstimation,
+ DepthProImageProcessorFast,
+)
+
+
+# fmt: off
+ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
+
+ # encoder
+ r"encoder.(patch|image)_encoder.cls_token": r"depth_pro.encoder.\1_encoder.model.embeddings.cls_token",
+ r"encoder.(patch|image)_encoder.pos_embed": r"depth_pro.encoder.\1_encoder.model.embeddings.position_embeddings",
+ r"encoder.(patch|image)_encoder.patch_embed.proj.(weight|bias)": r"depth_pro.encoder.\1_encoder.model.embeddings.patch_embeddings.projection.\2",
+ r"encoder.(patch|image)_encoder.blocks.(\d+).norm(\d+).(weight|bias)": r"depth_pro.encoder.\1_encoder.model.encoder.layer.\2.norm\3.\4",
+ r"encoder.(patch|image)_encoder.blocks.(\d+).attn.qkv.(weight|bias)": r"depth_pro.encoder.\1_encoder.model.encoder.layer.\2.attention.attention.(query|key|value).\3",
+ r"encoder.(patch|image)_encoder.blocks.(\d+).attn.proj.(weight|bias)": r"depth_pro.encoder.\1_encoder.model.encoder.layer.\2.attention.output.dense.\3",
+ r"encoder.(patch|image)_encoder.blocks.(\d+).ls(\d+).gamma": r"depth_pro.encoder.\1_encoder.model.encoder.layer.\2.layer_scale\3.lambda1",
+ r"encoder.(patch|image)_encoder.blocks.(\d+).mlp.fc(\d+).(weight|bias)": r"depth_pro.encoder.\1_encoder.model.encoder.layer.\2.mlp.fc\3.\4",
+ r"encoder.(patch|image)_encoder.norm.(weight|bias)": r"depth_pro.encoder.\1_encoder.model.layernorm.\2",
+ r"encoder.fuse_lowres.(weight|bias)": r"depth_pro.neck.fuse_image_with_low_res.\1",
+
+ # fov
+ r"fov.encoder.0.cls_token": r"fov_model.fov_encoder.model.embeddings.cls_token",
+ r"fov.encoder.0.pos_embed": r"fov_model.fov_encoder.model.embeddings.position_embeddings",
+ r"fov.encoder.0.patch_embed.proj.(weight|bias)": r"fov_model.fov_encoder.model.embeddings.patch_embeddings.projection.\1",
+ r"fov.encoder.0.blocks.(\d+).norm(\d+).(weight|bias)": r"fov_model.fov_encoder.model.encoder.layer.\1.norm\2.\3",
+ r"fov.encoder.0.blocks.(\d+).attn.qkv.(weight|bias)": r"fov_model.fov_encoder.model.encoder.layer.\1.attention.attention.(query|key|value).\2",
+ r"fov.encoder.0.blocks.(\d+).attn.proj.(weight|bias)": r"fov_model.fov_encoder.model.encoder.layer.\1.attention.output.dense.\2",
+ r"fov.encoder.0.blocks.(\d+).ls(\d+).gamma": r"fov_model.fov_encoder.model.encoder.layer.\1.layer_scale\2.lambda1",
+ r"fov.encoder.0.blocks.(\d+).mlp.fc(\d+).(weight|bias)": r"fov_model.fov_encoder.model.encoder.layer.\1.mlp.fc\2.\3",
+ r"fov.encoder.0.norm.(weight|bias)": r"fov_model.fov_encoder.model.layernorm.\1",
+ r"fov.downsample.0.(weight|bias)": r"fov_model.conv.\1",
+ r"fov.encoder.1.(weight|bias)": r"fov_model.fov_encoder.neck.\1",
+ r"fov.head.(\d+).(weight|bias)": r"fov_model.head.layers.\1.\2",
+
+ # head
+ r"head.(\d+).(weight|bias)": r"head.layers.\1.\2",
+
+ # upsamples
+ r"encoder.upsample_lowres.(weight|bias)": r"depth_pro.neck.feature_upsample.image_block.layers.0.\1",
+ r"encoder.upsample_latent(\d+).(\d+).(weight|bias)": lambda match: (
+ f"depth_pro.neck.feature_upsample.intermediate.{1-int(match.group(1))}.layers.{match.group(2)}.{match.group(3)}"
+ ),
+ r"encoder.upsample(\d+).(\d+).(weight|bias)": lambda match: (
+ f"depth_pro.neck.feature_upsample.scaled_images.{2-int(match.group(1))}.layers.{match.group(2)}.{match.group(3)}"
+ ),
+
+ # projections between encoder and fusion
+ r"decoder.convs.(\d+).weight": lambda match: (
+ f"depth_pro.neck.feature_projection.projections.{4-int(match.group(1))}.weight"
+ ),
+
+ # fusion stage
+ r"decoder.fusions.([1234]).resnet(\d+).residual.(\d+).(weight|bias)": lambda match: (
+ f"fusion_stage.intermediate.{4-int(match.group(1))}.residual_layer{match.group(2)}.convolution{(int(match.group(3))+1)//2}.{match.group(4)}"
+ ),
+ r"decoder.fusions.0.resnet(\d+).residual.(\d+).(weight|bias)": lambda match: (
+ f"fusion_stage.final.residual_layer{match.group(1)}.convolution{(int(match.group(2))+1)//2}.{match.group(3)}"
+ ),
+ r"decoder.fusions.([1234]).out_conv.(weight|bias)": lambda match: (
+ f"fusion_stage.intermediate.{4-int(match.group(1))}.projection.{match.group(2)}"
+ ),
+ r"decoder.fusions.0.out_conv.(weight|bias)": lambda match: (
+ f"fusion_stage.final.projection.{match.group(1)}"
+ ),
+ r"decoder.fusions.(\d+).deconv.(weight|bias)": lambda match: (
+ f"fusion_stage.intermediate.{4-int(match.group(1))}.deconv.{match.group(2)}"
+ ),
+}
+# fmt: on
+
+
+def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
+ output_dict = {}
+ if state_dict_keys is not None:
+ old_text = "\n".join(state_dict_keys)
+ new_text = old_text
+ for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
+ if replacement is None:
+ new_text = re.sub(pattern, "", new_text) # an empty line
+ continue
+ new_text = re.sub(pattern, replacement, new_text)
+ output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
+ return output_dict
+
+
+def get_qkv_state_dict(key, parameter):
+ """
+ new key which looks like this
+ xxxx.(q|k|v).xxx (m, n)
+
+ is converted to
+ xxxx.q.xxxx (m//3, n)
+ xxxx.k.xxxx (m//3, n)
+ xxxx.v.xxxx (m//3, n)
+ """
+ qkv_state_dict = {}
+ placeholder = re.search(r"(\(.*?\))", key).group(1) # finds "(query|key|value)"
+ replacements_keys = placeholder[1:-1].split("|") # creates ['query', 'key', 'value']
+ replacements_vals = torch.split(
+ parameter, split_size_or_sections=parameter.size(0) // len(replacements_keys), dim=0
+ )
+ for replacement_key, replacement_val in zip(replacements_keys, replacements_vals):
+ qkv_state_dict[key.replace(placeholder, replacement_key)] = replacement_val
+ return qkv_state_dict
+
+
+def write_model(
+ hf_repo_id: str,
+ output_dir: str,
+ safe_serialization: bool = True,
+):
+ os.makedirs(output_dir, exist_ok=True)
+
+ # ------------------------------------------------------------
+ # Create and save config
+ # ------------------------------------------------------------
+
+ # create config
+ backbone_config = {
+ "model_type": "dinov2",
+ "num_hidden_layers": 24,
+ "patch_size": 16,
+ "hidden_size": 1024,
+ "num_attention_heads": 16,
+ "image_size": 384,
+ "use_mask_token": False,
+ }
+ config = DepthProConfig(
+ # original implementation uses same config for all 3 models
+ image_model_config=backbone_config,
+ patch_model_config=backbone_config,
+ fov_model_config=backbone_config,
+ use_fov_model=True,
+ )
+
+ # save config
+ config.save_pretrained(output_dir)
+ print("Model config saved successfully...")
+
+ # ------------------------------------------------------------
+ # Convert weights
+ # ------------------------------------------------------------
+
+ # download and load state_dict from hf repo
+ file_path = hf_hub_download(hf_repo_id, "depth_pro.pt")
+ loaded = torch.load(file_path, weights_only=True)
+
+ print("Converting model...")
+ all_keys = list(loaded.keys())
+ new_keys = convert_old_keys_to_new_keys(all_keys)
+
+ state_dict = {}
+ for key in all_keys:
+ new_key = new_keys[key]
+ current_parameter = loaded.pop(key)
+
+ if "qkv" in key:
+ qkv_state_dict = get_qkv_state_dict(new_key, current_parameter)
+ state_dict.update(qkv_state_dict)
+ else:
+ state_dict[new_key] = current_parameter
+
+ print("Loading the checkpoint in a DepthPro model.")
+ model = DepthProForDepthEstimation(config)
+ model.load_state_dict(state_dict, strict=True, assign=True)
+ print("Checkpoint loaded successfully.")
+
+ print("Saving the model.")
+ model.save_pretrained(output_dir, safe_serialization=safe_serialization)
+ del state_dict, model
+
+ # Safety check: reload the converted model
+ gc.collect()
+ print("Reloading the model to check if it's saved correctly.")
+ model = DepthProForDepthEstimation.from_pretrained(output_dir, device_map="auto")
+ print("Model reloaded successfully.")
+ return model
+
+
+def write_image_processor(output_dir: str):
+ image_processor = DepthProImageProcessorFast()
+ image_processor.save_pretrained(output_dir)
+ return image_processor
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--hf_repo_id",
+ default="apple/DepthPro",
+ help="Location of official weights from apple on HF",
+ )
+ parser.add_argument(
+ "--output_dir",
+ default="apple_DepthPro",
+ help="Location to write the converted model and processor",
+ )
+ parser.add_argument(
+ "--safe_serialization", default=True, type=bool, help="Whether or not to save using `safetensors`."
+ )
+ parser.add_argument(
+ "--push_to_hub",
+ action=argparse.BooleanOptionalAction,
+ help="Whether or not to push the converted model to the huggingface hub.",
+ )
+ parser.add_argument(
+ "--hub_repo_id",
+ default="apple/DepthPro-hf",
+ help="Huggingface hub repo to write the converted model and processor",
+ )
+ args = parser.parse_args()
+
+ model = write_model(
+ hf_repo_id=args.hf_repo_id,
+ output_dir=args.output_dir,
+ safe_serialization=args.safe_serialization,
+ )
+
+ image_processor = write_image_processor(
+ output_dir=args.output_dir,
+ )
+
+ if args.push_to_hub:
+ print("Pushing to hub...")
+ model.push_to_hub(args.hub_repo_id)
+ image_processor.push_to_hub(args.hub_repo_id)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/transformers/models/depth_pro/image_processing_depth_pro.py b/src/transformers/models/depth_pro/image_processing_depth_pro.py
new file mode 100644
index 000000000000..5871e0f764cd
--- /dev/null
+++ b/src/transformers/models/depth_pro/image_processing_depth_pro.py
@@ -0,0 +1,389 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for DepthPro."""
+
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+
+
+if TYPE_CHECKING:
+ from .modeling_depth_pro import DepthProDepthEstimatorOutput
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import to_channel_dimension_format
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ is_torch_available,
+ make_list_of_images,
+ pil_torch_interpolation_mapping,
+ to_numpy_array,
+ valid_images,
+)
+from ...utils import (
+ TensorType,
+ filter_out_non_signature_kwargs,
+ logging,
+ requires_backends,
+)
+
+
+if is_torch_available():
+ import torch
+
+
+logger = logging.get_logger(__name__)
+
+
+class DepthProImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a DepthPro image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
+ size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
+ size (`dict`, *optional*, defaults to `{"height": 1536, "width": 1536}`):
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
+ method.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+ `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+ parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+ `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[Dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 1536, "width": 1536}
+ size = get_size_dict(size)
+ self.do_resize = do_resize
+ self.do_rescale = do_rescale
+ self.do_normalize = do_normalize
+ self.size = size
+ self.resample = resample
+ self.rescale_factor = rescale_factor
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ Returns:
+ `np.ndarray`: The resized images.
+ """
+ requires_backends(self, "torch")
+
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
+ output_size = (size["height"], size["width"])
+
+ # we use torch interpolation instead of image.resize because DepthProImageProcessor
+ # rescales, then normalizes, which may cause some values to become negative, before resizing the image.
+ # image.resize expects all values to be in range [0, 1] or [0, 255] and throws an exception otherwise,
+ # however pytorch interpolation works with negative values.
+ # relevant issue here: https://github.com/huggingface/transformers/issues/34920
+ # input should be (B, C, H, W)
+ image_tensor = torch.from_numpy(image).unsqueeze(0)
+ resized_image = torch.nn.functional.interpolate(
+ input=image_tensor,
+ size=output_size,
+ mode=pil_torch_interpolation_mapping[resample].value,
+ )
+ resized_image = resized_image.squeeze(0).numpy()
+ return resized_image
+
+ def _validate_input_arguments(
+ self,
+ do_resize: bool,
+ size: Dict[str, int],
+ resample: PILImageResampling,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Union[float, List[float]],
+ image_std: Union[float, List[float]],
+ data_format: Union[str, ChannelDimension],
+ ):
+ if do_resize and None in (size, resample):
+ raise ValueError("Size and resample must be specified if do_resize is True.")
+
+ if do_rescale and rescale_factor is None:
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+ if do_normalize and None in (image_mean, image_std):
+ raise ValueError("Image mean and standard deviation must be specified if do_normalize is True.")
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[Dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
+ resizing.
+ resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
+ `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
+ an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use if `do_normalize` is set to `True`.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use if `do_normalize` is set to `True`.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ resample = resample if resample is not None else self.resample
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ size = size if size is not None else self.size
+
+ images = make_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+ self._validate_input_arguments(
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ data_format=data_format,
+ )
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if is_scaled_image(images[0]) and do_rescale:
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ all_images = []
+ for image in images:
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_normalize:
+ image = self.normalize(
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
+ )
+
+ # depth-pro rescales and normalizes the image before resizing it
+ # uses torch interpolation which requires ChannelDimension.FIRST
+ if do_resize:
+ image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format)
+ image = self.resize(image=image, size=size, resample=resample)
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=ChannelDimension.FIRST)
+ else:
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+
+ all_images.append(image)
+
+ data = {"pixel_values": all_images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def post_process_depth_estimation(
+ self,
+ outputs: "DepthProDepthEstimatorOutput",
+ target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None,
+ ) -> Dict[str, List[TensorType]]:
+ """
+ Post-processes the raw depth predictions from the model to generate
+ final depth predictions which is caliberated using the field of view if provided
+ and resized to specified target sizes if provided.
+
+ Args:
+ outputs ([`DepthProDepthEstimatorOutput`]):
+ Raw outputs of the model.
+ target_sizes (`Optional[Union[TensorType, List[Tuple[int, int]], None]]`, *optional*, defaults to `None`):
+ Target sizes to resize the depth predictions. Can be a tensor of shape `(batch_size, 2)`
+ or a list of tuples `(height, width)` for each image in the batch. If `None`, no resizing
+ is performed.
+
+ Returns:
+ `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
+ predictions, and field of view (degrees) and focal length (pixels) if `field_of_view` is given in `outputs`.
+
+ Raises:
+ `ValueError`:
+ If the lengths of `predicted_depths`, `fovs`, or `target_sizes` are mismatched.
+ """
+ requires_backends(self, "torch")
+
+ predicted_depth = outputs.predicted_depth
+ fov = outputs.field_of_view
+
+ batch_size = len(predicted_depth)
+
+ if target_sizes is not None and batch_size != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many fov values as the batch dimension of the predicted depth"
+ )
+
+ results = []
+ fov = [None] * batch_size if fov is None else fov
+ target_sizes = [None] * batch_size if target_sizes is None else target_sizes
+ for depth, fov_value, target_size in zip(predicted_depth, fov, target_sizes):
+ focal_length = None
+ if target_size is not None:
+ # scale image w.r.t fov
+ if fov_value is not None:
+ width = target_size[1]
+ focal_length = 0.5 * width / torch.tan(0.5 * torch.deg2rad(fov_value))
+ depth = depth * width / focal_length
+
+ # interpolate
+ depth = torch.nn.functional.interpolate(
+ # input should be (B, C, H, W)
+ input=depth.unsqueeze(0).unsqueeze(1),
+ size=target_size,
+ mode=pil_torch_interpolation_mapping[self.resample].value,
+ ).squeeze()
+
+ # inverse the depth
+ depth = 1.0 / torch.clamp(depth, min=1e-4, max=1e4)
+
+ results.append(
+ {
+ "predicted_depth": depth,
+ "field_of_view": fov_value,
+ "focal_length": focal_length,
+ }
+ )
+
+ return results
+
+
+__all__ = ["DepthProImageProcessor"]
diff --git a/src/transformers/models/depth_pro/image_processing_depth_pro_fast.py b/src/transformers/models/depth_pro/image_processing_depth_pro_fast.py
new file mode 100644
index 000000000000..43a23bf10b5e
--- /dev/null
+++ b/src/transformers/models/depth_pro/image_processing_depth_pro_fast.py
@@ -0,0 +1,187 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for DepthPro."""
+
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+from ...image_processing_base import BatchFeature
+from ...image_processing_utils_fast import (
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
+ BaseImageProcessorFast,
+ group_images_by_shape,
+ reorder_images,
+)
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ PILImageResampling,
+ SizeDict,
+)
+from ...utils import (
+ TensorType,
+ add_start_docstrings,
+ is_torch_available,
+ is_torchvision_available,
+ is_torchvision_v2_available,
+ logging,
+ requires_backends,
+)
+
+
+if TYPE_CHECKING:
+ from .modeling_depth_pro import DepthProDepthEstimatorOutput
+
+logger = logging.get_logger(__name__)
+
+
+if is_torch_available():
+ import torch
+
+
+if is_torchvision_available():
+ from ...image_utils import pil_torch_interpolation_mapping
+
+ if is_torchvision_v2_available():
+ from torchvision.transforms.v2 import functional as F
+ else:
+ from torchvision.transforms import functional as F
+
+
+@add_start_docstrings(
+ "Constructs a fast DepthPro image processor.",
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
+)
+class DepthProImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BILINEAR
+ image_mean = IMAGENET_STANDARD_MEAN
+ image_std = IMAGENET_STANDARD_STD
+ size = {"height": 1536, "width": 1536}
+ do_resize = True
+ do_rescale = True
+ do_normalize = True
+
+ # DepthPro resizes image after rescaling and normalizing,
+ # which makes it different from BaseImageProcessorFast._preprocess
+ def _preprocess(
+ self,
+ images: List["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"],
+ do_center_crop: bool,
+ crop_size: SizeDict,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, List[float]]],
+ image_std: Optional[Union[float, List[float]]],
+ return_tensors: Optional[Union[str, TensorType]],
+ ) -> BatchFeature:
+ # Group images by size for batched scaling
+ grouped_images, grouped_images_index = group_images_by_shape(images)
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ # Fused rescale and normalize
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ if do_resize:
+ stacked_images = self.resize(
+ image=stacked_images,
+ size=size,
+ interpolation=interpolation,
+ antialias=False,
+ )
+ processed_images_grouped[shape] = stacked_images
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+ return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
+
+ # Copied from transformers.models.depth_pro.image_processing_depth_pro.DepthProImageProcessor.post_process_depth_estimation
+ def post_process_depth_estimation(
+ self,
+ outputs: "DepthProDepthEstimatorOutput",
+ target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None,
+ ) -> Dict[str, List[TensorType]]:
+ """
+ Post-processes the raw depth predictions from the model to generate
+ final depth predictions which is caliberated using the field of view if provided
+ and resized to specified target sizes if provided.
+
+ Args:
+ outputs ([`DepthProDepthEstimatorOutput`]):
+ Raw outputs of the model.
+ target_sizes (`Optional[Union[TensorType, List[Tuple[int, int]], None]]`, *optional*, defaults to `None`):
+ Target sizes to resize the depth predictions. Can be a tensor of shape `(batch_size, 2)`
+ or a list of tuples `(height, width)` for each image in the batch. If `None`, no resizing
+ is performed.
+
+ Returns:
+ `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
+ predictions, and field of view (degrees) and focal length (pixels) if `field_of_view` is given in `outputs`.
+
+ Raises:
+ `ValueError`:
+ If the lengths of `predicted_depths`, `fovs`, or `target_sizes` are mismatched.
+ """
+ requires_backends(self, "torch")
+
+ predicted_depth = outputs.predicted_depth
+ fov = outputs.field_of_view
+
+ batch_size = len(predicted_depth)
+
+ if target_sizes is not None and batch_size != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many fov values as the batch dimension of the predicted depth"
+ )
+
+ results = []
+ fov = [None] * batch_size if fov is None else fov
+ target_sizes = [None] * batch_size if target_sizes is None else target_sizes
+ for depth, fov_value, target_size in zip(predicted_depth, fov, target_sizes):
+ focal_length = None
+ if target_size is not None:
+ # scale image w.r.t fov
+ if fov_value is not None:
+ width = target_size[1]
+ focal_length = 0.5 * width / torch.tan(0.5 * torch.deg2rad(fov_value))
+ depth = depth * width / focal_length
+
+ # interpolate
+ depth = torch.nn.functional.interpolate(
+ # input should be (B, C, H, W)
+ input=depth.unsqueeze(0).unsqueeze(1),
+ size=target_size,
+ mode=pil_torch_interpolation_mapping[self.resample].value,
+ ).squeeze()
+
+ # inverse the depth
+ depth = 1.0 / torch.clamp(depth, min=1e-4, max=1e4)
+
+ results.append(
+ {
+ "predicted_depth": depth,
+ "field_of_view": fov_value,
+ "focal_length": focal_length,
+ }
+ )
+
+ return results
+
+
+__all__ = ["DepthProImageProcessorFast"]
diff --git a/src/transformers/models/depth_pro/modeling_depth_pro.py b/src/transformers/models/depth_pro/modeling_depth_pro.py
new file mode 100644
index 000000000000..67715723d133
--- /dev/null
+++ b/src/transformers/models/depth_pro/modeling_depth_pro.py
@@ -0,0 +1,1218 @@
+# coding=utf-8
+# Copyright 2024 The Apple Research Team Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch DepthPro model."""
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+ torch_int,
+)
+from ..auto import AutoModel
+from .configuration_depth_pro import DepthProConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class DepthProOutput(ModelOutput):
+ """
+ Base class for DepthPro's outputs.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, n_patches_per_batch, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ features (`Union[torch.FloatTensor, List[torch.FloatTensor]]`, *optional*):
+ Features from encoders. Can be a single feature or a list of features.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, n_patches_per_batch, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer and the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, n_patches_per_batch, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ features: Union[torch.FloatTensor, List[torch.FloatTensor]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class DepthProDepthEstimatorOutput(ModelOutput):
+ """
+ Base class for DepthProForDepthEstimation's output.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ predicted_depth (`torch.FloatTensor` of shape `(batch_size, height, width)`):
+ Predicted depth for each pixel.
+ field_of_view (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned when `use_fov_model` is provided):
+ Field of View Scaler.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, n_patches_per_batch, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer and the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, n_patches_per_batch, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ predicted_depth: torch.FloatTensor = None
+ field_of_view: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+def split_to_patches(pixel_values: torch.Tensor, patch_size: int, overlap_ratio: float) -> torch.Tensor:
+ """Creates Patches from Batch."""
+ batch_size, num_channels, height, width = pixel_values.shape
+
+ if height == width == patch_size:
+ # create patches only if scaled image is not already equal to patch size
+ return pixel_values
+
+ stride = torch_int(patch_size * (1 - overlap_ratio))
+
+ patches = F.unfold(pixel_values, kernel_size=(patch_size, patch_size), stride=(stride, stride))
+ patches = patches.permute(2, 0, 1)
+ patches = patches.reshape(-1, num_channels, patch_size, patch_size)
+
+ return patches
+
+
+def reshape_features(hidden_states: torch.Tensor) -> torch.Tensor:
+ """Discard class token and reshape 1D feature map to a 2D grid."""
+ n_samples, seq_len, hidden_size = hidden_states.shape
+ size = torch_int(seq_len**0.5)
+
+ hidden_states = hidden_states[:, -(size**2) :, :] # remove special tokens if there are any
+ hidden_states = hidden_states.reshape(n_samples, size, size, hidden_size)
+ hidden_states = hidden_states.permute(0, 3, 1, 2)
+
+ return hidden_states
+
+
+def merge_patches(patches: torch.Tensor, batch_size: int, padding: int) -> torch.Tensor:
+ """Merges smaller patches into image-like feature map."""
+ n_patches, hidden_size, out_size, out_size = patches.shape
+ n_patches_per_batch = n_patches // batch_size
+ sqrt_n_patches_per_batch = torch_int(n_patches_per_batch**0.5)
+ new_out_size = sqrt_n_patches_per_batch * out_size
+
+ if n_patches == batch_size:
+ # merge only if the patches were created from scaled image
+ # patches are not created when scaled image size is equal to patch size
+ return patches
+
+ if n_patches_per_batch < 4:
+ # for each batch, atleast 4 small patches are required to
+ # recreate a large square patch from merging them and later padding is applied
+ # 3 x (8x8) patches becomes 1 x ( 8x8 ) patch (extra patch ignored, no padding)
+ # 4 x (8x8) patches becomes 1 x (16x16) patch (padding later)
+ # 5 x (8x8) patches becomes 1 x (16x16) patch (extra patch ignored, padding later)
+ # 9 x (8x8) patches becomes 1 x (24x24) patch (padding later)
+ # thus the following code only rearranges the patches and removes extra ones
+ padding = 0
+
+ # make sure padding is not large enough to remove more than half of the patch
+ padding = min(out_size // 4, padding)
+
+ if padding == 0:
+ # faster when no padding is required
+ merged = patches.reshape(n_patches_per_batch, batch_size, hidden_size, out_size, out_size)
+ merged = merged.permute(1, 2, 0, 3, 4)
+ merged = merged[:, :, : sqrt_n_patches_per_batch**2, :, :]
+ merged = merged.reshape(
+ batch_size, hidden_size, sqrt_n_patches_per_batch, sqrt_n_patches_per_batch, out_size, out_size
+ )
+ merged = merged.permute(0, 1, 2, 4, 3, 5)
+ merged = merged.reshape(batch_size, hidden_size, new_out_size, new_out_size)
+ else:
+ # padding example:
+ # let out_size = 8, new_out_size = 32, padding = 2
+ # each patch is separated by "|"
+ # and padding is applied to the merging edges of each patch
+ # 00 01 02 03 04 05 06 07 | 08 09 10 11 12 13 14 15 | 16 17 18 19 20 21 22 23 | 24 25 26 27 28 29 30 31
+ # 00 01 02 03 04 05 -- -- | -- -- 10 11 12 13 -- -- | -- -- 18 19 20 21 -- -- | -- -- 26 27 28 29 30 31
+ i = 0
+ boxes = []
+ for h in range(sqrt_n_patches_per_batch):
+ boxes_in_row = []
+ for w in range(sqrt_n_patches_per_batch):
+ box = patches[batch_size * i : batch_size * (i + 1)]
+
+ # collect paddings
+ paddings = [0, 0, 0, 0]
+ if h != 0:
+ # remove pad from height if box is not at top border
+ paddings[0] = padding
+ if w != 0:
+ # remove pad from width if box is not at left border
+ paddings[2] = padding
+ if h != sqrt_n_patches_per_batch - 1:
+ # remove pad from height if box is not at bottom border
+ paddings[1] = padding
+ if w != sqrt_n_patches_per_batch - 1:
+ # remove pad from width if box is not at right border
+ paddings[3] = padding
+
+ # remove paddings
+ _, _, box_h, box_w = box.shape
+ pad_top, pad_bottom, pad_left, pad_right = paddings
+ box = box[:, :, pad_top : box_h - pad_bottom, pad_left : box_w - pad_right]
+
+ boxes_in_row.append(box)
+ i += 1
+ boxes_in_row = torch.cat(boxes_in_row, dim=-1)
+ boxes.append(boxes_in_row)
+ merged = torch.cat(boxes, dim=-2)
+
+ return merged
+
+
+def reconstruct_feature_maps(
+ hidden_state: torch.Tensor, batch_size: int, padding: int, output_size: Tuple[float, float]
+) -> torch.Tensor:
+ """
+ Reconstructs feature maps from the hidden state produced by any of the encoder. Converts the hidden state of shape
+ `(n_patches_per_batch * batch_size, seq_len, hidden_size)` to feature maps of shape
+ `(batch_size, hidden_size, output_size[0], output_size[1])`.
+
+ Args:
+ hidden_state (torch.Tensor): Input tensor of shape `(n_patches_per_batch * batch_size, seq_len, hidden_size)`
+ representing the encoded patches.
+ batch_size (int): The number of samples in a batch.
+ padding (int): The amount of padding to be removed when merging patches.
+ output_size (Tuple[float, float]): The desired output size for the feature maps, specified as `(height, width)`.
+
+ Returns:
+ torch.Tensor: Reconstructed feature maps of shape `(batch_size, hidden_size, output_size[0], output_size[1])`.
+ """
+ # reshape back to image like
+ features = reshape_features(hidden_state)
+
+ # merge all patches in a batch to create one large patch per batch
+ features = merge_patches(
+ features,
+ batch_size=batch_size,
+ padding=padding,
+ )
+
+ # interpolate patches to base size
+ features = F.interpolate(
+ features,
+ size=output_size,
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ return features
+
+
+class DepthProPatchEncoder(nn.Module):
+ def __init__(self, config: DepthProConfig):
+ super().__init__()
+ self.config = config
+
+ self.intermediate_hook_ids = config.intermediate_hook_ids
+ self.intermediate_feature_dims = config.intermediate_feature_dims
+ self.scaled_images_ratios = config.scaled_images_ratios
+ self.scaled_images_overlap_ratios = config.scaled_images_overlap_ratios
+ self.scaled_images_feature_dims = config.scaled_images_feature_dims
+ self.merge_padding_value = config.merge_padding_value
+
+ self.n_scaled_images = len(config.scaled_images_ratios)
+ self.n_intermediate_hooks = len(config.intermediate_hook_ids)
+ self.out_size = config.image_model_config.image_size // config.image_model_config.patch_size
+
+ self.model = AutoModel.from_config(config.patch_model_config)
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ ) -> List[torch.Tensor]:
+ batch_size, num_channels, height, width = pixel_values.shape
+
+ if min(self.scaled_images_ratios) * min(height, width) < self.config.patch_size:
+ raise ValueError(
+ f"Image size {height}x{width} is too small to be scaled "
+ f"with scaled_images_ratios={self.scaled_images_ratios} "
+ f"when patch_size={self.config.patch_size}."
+ )
+
+ # STEP 1: create 3-level image
+
+ scaled_images = []
+ for ratio in self.scaled_images_ratios:
+ scaled_images.append(
+ F.interpolate(
+ pixel_values,
+ scale_factor=ratio,
+ mode="bilinear",
+ align_corners=False,
+ )
+ )
+
+ # STEP 2: create patches
+
+ for i in range(self.n_scaled_images):
+ scaled_images[i] = split_to_patches(
+ scaled_images[i],
+ patch_size=self.config.patch_size,
+ overlap_ratio=self.scaled_images_overlap_ratios[i],
+ )
+ n_patches_per_scaled_image = [len(i) for i in scaled_images]
+ patches = torch.cat(scaled_images[::-1], dim=0) # -1 as patch encoder expects high res patches first
+
+ # STEP 3: apply patch encoder
+
+ encodings = self.model(
+ # each patch is processed as a separate batch
+ patches,
+ head_mask=head_mask,
+ # required for intermediate features
+ output_hidden_states=self.n_intermediate_hooks > 0,
+ )
+
+ scaled_images_last_hidden_state = torch.split_with_sizes(encodings[0], n_patches_per_scaled_image[::-1])
+ # -1 (reverse list) as patch encoder returns high res patches first, we need low res first
+ scaled_images_last_hidden_state = scaled_images_last_hidden_state[::-1]
+
+ # calculate base height and width
+ # base height and width are the dimensions of the lowest resolution features
+ exponent_value = torch_int(math.log2(width / self.out_size))
+ base_height = height // 2**exponent_value
+ base_width = width // 2**exponent_value
+
+ # STEP 4: get patch features (high_res, med_res, low_res) - (3-5) in diagram
+
+ scaled_images_features = []
+ for i in range(self.n_scaled_images):
+ hidden_state = scaled_images_last_hidden_state[i]
+ batch_size = batch_size
+ padding = torch_int(self.merge_padding_value * (1 / self.scaled_images_ratios[i]))
+ output_height = base_height * 2**i
+ output_width = base_width * 2**i
+ features = reconstruct_feature_maps(
+ hidden_state,
+ batch_size=batch_size,
+ padding=padding,
+ output_size=(output_height, output_width),
+ )
+ scaled_images_features.append(features)
+
+ # STEP 5: get intermediate features - (1-2) in diagram
+
+ intermediate_features = []
+ for i in range(self.n_intermediate_hooks):
+ # +1 to correct index position as hidden_states contain embedding output as well
+ hidden_state = encodings[2][self.intermediate_hook_ids[i] + 1]
+ padding = torch_int(self.merge_padding_value * (1 / self.scaled_images_ratios[-1]))
+ output_height = base_height * 2 ** (self.n_scaled_images - 1)
+ output_width = base_width * 2 ** (self.n_scaled_images - 1)
+ features = reconstruct_feature_maps(
+ hidden_state,
+ batch_size=batch_size,
+ padding=padding,
+ output_size=(output_height, output_width),
+ )
+ intermediate_features.append(features)
+
+ # STEP 7: combine all features
+ features = [*scaled_images_features, *intermediate_features]
+
+ return features
+
+
+class DepthProImageEncoder(nn.Module):
+ def __init__(self, config: DepthProConfig):
+ super().__init__()
+ self.config = config
+ self.out_size = config.image_model_config.image_size // config.image_model_config.patch_size
+
+ self.model = AutoModel.from_config(config.image_model_config)
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ) -> Union[tuple, DepthProOutput]:
+ batch_size, num_channels, height, width = pixel_values.shape
+
+ # scale the image for image_encoder
+ size = self.config.image_model_config.image_size
+ pixel_values = F.interpolate(
+ pixel_values,
+ size=(size, size),
+ mode="bilinear",
+ align_corners=False,
+ )
+ encodings = self.model(
+ pixel_values=pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ # calculate base height and width
+ # base height and width are the dimensions of the lowest resolution features
+ exponent_value = torch_int(math.log2(width / self.out_size))
+ base_height = height // 2**exponent_value
+ base_width = width // 2**exponent_value
+
+ features = reconstruct_feature_maps(
+ encodings[0],
+ batch_size=batch_size,
+ padding=0,
+ output_size=(base_height, base_width),
+ )
+
+ if not return_dict:
+ return (encodings[0], features) + encodings[2:] # ignore last_hidden_state and poooler output
+
+ return DepthProOutput(
+ last_hidden_state=encodings.last_hidden_state,
+ features=features,
+ hidden_states=encodings.hidden_states,
+ attentions=encodings.attentions,
+ )
+
+
+class DepthProEncoder(nn.Module):
+ def __init__(self, config: DepthProConfig):
+ super().__init__()
+ self.config = config
+ self.intermediate_hook_ids = config.intermediate_hook_ids
+ self.intermediate_feature_dims = config.intermediate_feature_dims
+ self.scaled_images_ratios = config.scaled_images_ratios
+ self.scaled_images_overlap_ratios = config.scaled_images_overlap_ratios
+ self.scaled_images_feature_dims = config.scaled_images_feature_dims
+ self.merge_padding_value = config.merge_padding_value
+
+ self.n_scaled_images = len(self.scaled_images_ratios)
+ self.n_intermediate_hooks = len(self.intermediate_hook_ids)
+
+ self.patch_encoder = DepthProPatchEncoder(config)
+ self.image_encoder = DepthProImageEncoder(config)
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ) -> Union[tuple, DepthProOutput]:
+ batch_size, num_channels, height, width = pixel_values.shape
+
+ patch_features = self.patch_encoder(
+ pixel_values,
+ head_mask=head_mask,
+ )
+ image_encodings = self.image_encoder(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ image_features = image_encodings[1] # index 1 contains features
+
+ features = [image_features, *patch_features]
+
+ if not return_dict:
+ return (image_encodings[0], features) + image_encodings[2:]
+
+ return DepthProOutput(
+ last_hidden_state=image_encodings.last_hidden_state,
+ features=features,
+ hidden_states=image_encodings.hidden_states,
+ attentions=image_encodings.attentions,
+ )
+
+
+class DepthProFeatureUpsampleBlock(nn.Module):
+ def __init__(
+ self,
+ config: DepthProConfig,
+ input_dims: int,
+ intermediate_dims: int,
+ output_dims: int,
+ n_upsample_layers: int,
+ use_proj: bool = True,
+ bias: bool = False,
+ ):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList()
+
+ # create first projection layer
+ if use_proj:
+ proj = nn.Conv2d(
+ in_channels=input_dims,
+ out_channels=intermediate_dims,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ self.layers.append(proj)
+
+ # create following upsample layers
+ for i in range(n_upsample_layers):
+ in_channels = intermediate_dims if i == 0 else output_dims
+ layer = nn.ConvTranspose2d(
+ in_channels=in_channels,
+ out_channels=output_dims,
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=bias,
+ )
+ self.layers.append(layer)
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ for layer in self.layers:
+ features = layer(features)
+ return features
+
+
+class DepthProFeatureUpsample(nn.Module):
+ def __init__(self, config: DepthProConfig):
+ super().__init__()
+ self.config = config
+ self.n_scaled_images = len(self.config.scaled_images_ratios)
+ self.n_intermediate_hooks = len(self.config.intermediate_hook_ids)
+
+ # for image_features
+ self.image_block = DepthProFeatureUpsampleBlock(
+ config=config,
+ input_dims=config.image_model_config.hidden_size,
+ intermediate_dims=config.image_model_config.hidden_size,
+ output_dims=config.scaled_images_feature_dims[0],
+ n_upsample_layers=1,
+ use_proj=False,
+ bias=True,
+ )
+
+ # for scaled_images_features
+ self.scaled_images = nn.ModuleList()
+ for i, feature_dims in enumerate(config.scaled_images_feature_dims):
+ block = DepthProFeatureUpsampleBlock(
+ config=config,
+ input_dims=config.patch_model_config.hidden_size,
+ intermediate_dims=feature_dims,
+ output_dims=feature_dims,
+ n_upsample_layers=1,
+ )
+ self.scaled_images.append(block)
+
+ # for intermediate_features
+ self.intermediate = nn.ModuleList()
+ for i, feature_dims in enumerate(config.intermediate_feature_dims):
+ intermediate_dims = config.fusion_hidden_size if i == 0 else feature_dims
+ block = DepthProFeatureUpsampleBlock(
+ config=config,
+ input_dims=config.patch_model_config.hidden_size,
+ intermediate_dims=intermediate_dims,
+ output_dims=feature_dims,
+ n_upsample_layers=2 + i,
+ )
+ self.intermediate.append(block)
+
+ def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
+ features[0] = self.image_block(features[0])
+
+ for i in range(self.n_scaled_images):
+ features[i + 1] = self.scaled_images[i](features[i + 1])
+
+ for i in range(self.n_intermediate_hooks):
+ features[self.n_scaled_images + i + 1] = self.intermediate[i](features[self.n_scaled_images + i + 1])
+
+ return features
+
+
+class DepthProFeatureProjection(nn.Module):
+ def __init__(self, config: DepthProConfig):
+ super().__init__()
+ self.config = config
+
+ combined_feature_dims = config.scaled_images_feature_dims + config.intermediate_feature_dims
+ self.projections = nn.ModuleList()
+ for i, in_channels in enumerate(combined_feature_dims):
+ if i == len(combined_feature_dims) - 1 and in_channels == config.fusion_hidden_size:
+ # projection for last layer can be ignored if input and output channels already match
+ self.projections.append(nn.Identity())
+ else:
+ self.projections.append(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=config.fusion_hidden_size,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ )
+ )
+
+ def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
+ projected_features = []
+ for i, projection in enumerate(self.projections):
+ upsampled_feature = projection(features[i])
+ projected_features.append(upsampled_feature)
+ return projected_features
+
+
+class DepthProNeck(nn.Module):
+ def __init__(self, config: DepthProConfig):
+ super().__init__()
+ self.config = config
+
+ self.feature_upsample = DepthProFeatureUpsample(config)
+ self.fuse_image_with_low_res = nn.Conv2d(
+ in_channels=config.scaled_images_feature_dims[0] * 2,
+ out_channels=config.scaled_images_feature_dims[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True,
+ )
+ self.feature_projection = DepthProFeatureProjection(config)
+
+ def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
+ features = self.feature_upsample(features)
+ # global features = low res features + image features
+ global_features = torch.cat((features[1], features[0]), dim=1)
+ global_features = self.fuse_image_with_low_res(global_features)
+ features = [global_features, *features[2:]]
+ features = self.feature_projection(features)
+ return features
+
+
+# General docstring
+_CONFIG_FOR_DOC = "DepthProConfig"
+
+
+DEPTH_PRO_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`DepthProConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+DEPTH_PRO_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`]
+ for details.
+
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+"""
+
+DEPTH_PRO_FOR_DEPTH_ESTIMATION_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`DepthProConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ use_fov_model (`bool`, *optional*, defaults to `True`):
+ Whether to use `DepthProFovModel` to generate the field of view.
+"""
+
+
+class DepthProPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = DepthProConfig
+ base_model_prefix = "depth_pro"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+ _supports_sdpa = True
+ _no_split_modules = ["DepthProPreActResidualLayer"]
+ _keys_to_ignore_on_load_unexpected = ["fov_model.*"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
+ nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+
+@add_start_docstrings(
+ "The bare DepthPro Model transformer outputting raw hidden-states without any specific head on top.",
+ DEPTH_PRO_START_DOCSTRING,
+)
+class DepthProModel(DepthProPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+ self.encoder = DepthProEncoder(config)
+ self.neck = DepthProNeck(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.encoder.image_encoder.model.get_input_embeddings()
+
+ @add_start_docstrings_to_model_forward(DEPTH_PRO_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, DepthProOutput]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, DepthProModel
+
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> checkpoint = "apple/DepthPro-hf"
+ >>> processor = AutoProcessor.from_pretrained(checkpoint)
+ >>> model = DepthProModel.from_pretrained(checkpoint)
+
+ >>> # prepare image for the model
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... output = model(**inputs)
+
+ >>> output.last_hidden_state.shape
+ torch.Size([1, 35, 577, 1024])
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encodings = self.encoder(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ features = encodings[1] # index 1 contains features
+ features = self.neck(features)
+
+ if not return_dict:
+ return (encodings[0], features) + encodings[2:]
+
+ return DepthProOutput(
+ last_hidden_state=encodings.last_hidden_state,
+ features=features,
+ hidden_states=encodings.hidden_states,
+ attentions=encodings.attentions,
+ )
+
+
+# Copied from transformers.models.dpt.modeling_dpt.DPTPreActResidualLayer DPT->DepthPro
+class DepthProPreActResidualLayer(nn.Module):
+ """
+ ResidualConvUnit, pre-activate residual unit.
+
+ Args:
+ config (`[DepthProConfig]`):
+ Model configuration class defining the model architecture.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.use_batch_norm = config.use_batch_norm_in_fusion_residual
+ use_bias_in_fusion_residual = (
+ config.use_bias_in_fusion_residual
+ if config.use_bias_in_fusion_residual is not None
+ else not self.use_batch_norm
+ )
+
+ self.activation1 = nn.ReLU()
+ self.convolution1 = nn.Conv2d(
+ config.fusion_hidden_size,
+ config.fusion_hidden_size,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=use_bias_in_fusion_residual,
+ )
+
+ self.activation2 = nn.ReLU()
+ self.convolution2 = nn.Conv2d(
+ config.fusion_hidden_size,
+ config.fusion_hidden_size,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=use_bias_in_fusion_residual,
+ )
+
+ if self.use_batch_norm:
+ self.batch_norm1 = nn.BatchNorm2d(config.fusion_hidden_size)
+ self.batch_norm2 = nn.BatchNorm2d(config.fusion_hidden_size)
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ residual = hidden_state
+ hidden_state = self.activation1(hidden_state)
+
+ hidden_state = self.convolution1(hidden_state)
+
+ if self.use_batch_norm:
+ hidden_state = self.batch_norm1(hidden_state)
+
+ hidden_state = self.activation2(hidden_state)
+ hidden_state = self.convolution2(hidden_state)
+
+ if self.use_batch_norm:
+ hidden_state = self.batch_norm2(hidden_state)
+
+ return hidden_state + residual
+
+
+# Modified from transformers.models.dpt.modeling_dpt.DPTFeatureFusionLayer
+# except it uses deconv and skip_add and needs no interpolation
+class DepthProFeatureFusionLayer(nn.Module):
+ def __init__(self, config: DepthProConfig, use_deconv: bool = True):
+ super().__init__()
+ self.config = config
+ self.use_deconv = use_deconv
+
+ self.residual_layer1 = DepthProPreActResidualLayer(config)
+ self.residual_layer2 = DepthProPreActResidualLayer(config)
+
+ if self.use_deconv:
+ self.deconv = nn.ConvTranspose2d(
+ in_channels=config.fusion_hidden_size,
+ out_channels=config.fusion_hidden_size,
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=False,
+ )
+
+ self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)
+
+ def forward(self, hidden_state: torch.Tensor, residual: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if residual is not None:
+ residual = self.residual_layer1(residual)
+ hidden_state = hidden_state + residual
+
+ hidden_state = self.residual_layer2(hidden_state)
+ if self.use_deconv:
+ hidden_state = self.deconv(hidden_state)
+ hidden_state = self.projection(hidden_state)
+
+ return hidden_state
+
+
+# Modified from transformers.models.dpt.modeling_dpt.DPTFeatureFusionStage with DPT->DepthPro
+# with deconv and reversed layers
+class DepthProFeatureFusionStage(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ self.num_layers = len(config.intermediate_hook_ids) + len(config.scaled_images_ratios)
+ self.intermediate = nn.ModuleList()
+ for _ in range(self.num_layers - 1):
+ self.intermediate.append(DepthProFeatureFusionLayer(config))
+
+ # final layer doesnot require deconvolution
+ self.final = DepthProFeatureFusionLayer(config, use_deconv=False)
+
+ def forward(self, hidden_states: List[torch.Tensor]) -> List[torch.Tensor]:
+ if self.num_layers != len(hidden_states):
+ raise ValueError(
+ f"num_layers={self.num_layers} in DepthProFeatureFusionStage"
+ f"doesnot match len(hidden_states)={len(hidden_states)}"
+ )
+
+ fused_hidden_states = []
+ fused_hidden_state = None
+ for hidden_state, layer in zip(hidden_states[:-1], self.intermediate):
+ if fused_hidden_state is None:
+ # first layer only uses the last hidden_state
+ fused_hidden_state = layer(hidden_state)
+ else:
+ fused_hidden_state = layer(fused_hidden_state, hidden_state)
+ fused_hidden_states.append(fused_hidden_state)
+
+ hidden_state = hidden_states[-1]
+ fused_hidden_state = self.final(fused_hidden_state, hidden_state)
+ fused_hidden_states.append(fused_hidden_state)
+
+ return fused_hidden_states
+
+
+class DepthProFovEncoder(nn.Module):
+ def __init__(self, config: DepthProConfig):
+ super().__init__()
+ self.config = config
+ self.out_size = config.image_model_config.image_size // config.image_model_config.patch_size
+
+ self.model = AutoModel.from_config(config.fov_model_config)
+ self.neck = nn.Linear(config.fov_model_config.hidden_size, config.fusion_hidden_size // 2)
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, num_channels, height, width = pixel_values.shape
+
+ # scale the image for fov_encoder
+ size = self.config.fov_model_config.image_size
+ pixel_values = F.interpolate(
+ pixel_values,
+ size=(size, size),
+ mode="bilinear",
+ align_corners=False,
+ )
+ encodings = self.model(
+ pixel_values=pixel_values,
+ head_mask=head_mask,
+ )
+ hidden_state = encodings[0]
+ hidden_state = self.neck(hidden_state)
+
+ # calculate base height and width
+ # base height and width are the dimensions of the lowest resolution features
+ exponent_value = torch_int(math.log2(width / self.out_size))
+ base_height = height // 2**exponent_value
+ base_width = width // 2**exponent_value
+
+ features = reconstruct_feature_maps(
+ hidden_state,
+ batch_size=batch_size,
+ padding=0,
+ output_size=(base_height, base_width),
+ )
+
+ return features
+
+
+class DepthProFovHead(nn.Module):
+ def __init__(self, config: DepthProConfig):
+ super().__init__()
+ self.config = config
+ self.fusion_hidden_size = config.fusion_hidden_size
+ self.out_size = config.image_model_config.image_size // config.image_model_config.patch_size
+
+ # create initial head layers
+ self.layers = nn.ModuleList()
+ for i in range(config.num_fov_head_layers):
+ self.layers.append(
+ nn.Conv2d(
+ math.ceil(self.fusion_hidden_size / 2 ** (i + 1)),
+ math.ceil(self.fusion_hidden_size / 2 ** (i + 2)),
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ )
+ )
+ self.layers.append(nn.ReLU(True))
+ # calculate expected shapes to finally generate a scalar output from final head layer
+ final_in_channels = math.ceil(self.fusion_hidden_size / 2 ** (config.num_fov_head_layers + 1))
+ final_kernel_size = torch_int((self.out_size - 1) / 2**config.num_fov_head_layers + 1)
+ self.layers.append(
+ nn.Conv2d(
+ in_channels=final_in_channels, out_channels=1, kernel_size=final_kernel_size, stride=1, padding=0
+ )
+ )
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ features = F.interpolate(
+ features,
+ size=(self.out_size, self.out_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+ for layer in self.layers:
+ features = layer(features)
+ return features
+
+
+class DepthProFovModel(nn.Module):
+ def __init__(self, config: DepthProConfig):
+ super().__init__()
+ self.config = config
+ self.fusion_hidden_size = config.fusion_hidden_size
+
+ self.fov_encoder = DepthProFovEncoder(config)
+ self.conv = nn.Conv2d(
+ self.fusion_hidden_size, self.fusion_hidden_size // 2, kernel_size=3, stride=2, padding=1
+ )
+ self.activation = nn.ReLU(inplace=True)
+ self.head = DepthProFovHead(config)
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ global_features: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ fov_features = self.fov_encoder(pixel_values, head_mask)
+
+ global_features = self.conv(global_features)
+ global_features = self.activation(global_features)
+
+ fov_features = fov_features + global_features
+ fov_output = self.head(fov_features)
+ fov_output = fov_output.flatten()
+
+ return fov_output
+
+
+class DepthProDepthEstimationHead(nn.Module):
+ """
+ The DepthProDepthEstimationHead module serves as the output head for depth estimation tasks.
+ This module comprises a sequence of convolutional and transposed convolutional layers
+ that process the feature map from the fusion to produce a single-channel depth map.
+ Key operations include dimensionality reduction and upsampling to match the input resolution.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ features = config.fusion_hidden_size
+ self.layers = nn.ModuleList(
+ [
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+ nn.ConvTranspose2d(
+ in_channels=features // 2,
+ out_channels=features // 2,
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ ),
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(),
+ ]
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ for layer in self.layers:
+ hidden_states = layer(hidden_states)
+
+ predicted_depth = hidden_states.squeeze(dim=1)
+ return predicted_depth
+
+
+@add_start_docstrings(
+ """
+ DepthPro Model with a depth estimation head on top (consisting of 3 convolutional layers).
+ """,
+ DEPTH_PRO_FOR_DEPTH_ESTIMATION_START_DOCSTRING,
+)
+class DepthProForDepthEstimation(DepthProPreTrainedModel):
+ def __init__(self, config, use_fov_model=None):
+ super().__init__(config)
+ self.config = config
+ self.use_fov_model = use_fov_model if use_fov_model is not None else self.config.use_fov_model
+
+ # dinov2 (vit) like encoders
+ self.depth_pro = DepthProModel(config)
+
+ # dpt (vit) like fusion stage
+ self.fusion_stage = DepthProFeatureFusionStage(config)
+
+ # depth estimation head
+ self.head = DepthProDepthEstimationHead(config)
+
+ # dinov2 (vit) like encoder
+ self.fov_model = DepthProFovModel(config) if self.use_fov_model else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(DEPTH_PRO_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=DepthProDepthEstimatorOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ head_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], DepthProDepthEstimatorOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
+ Ground truth depth estimation maps for computing the loss.
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, DepthProForDepthEstimation
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> checkpoint = "apple/DepthPro-hf"
+ >>> processor = AutoImageProcessor.from_pretrained(checkpoint)
+ >>> model = DepthProForDepthEstimation.from_pretrained(checkpoint)
+
+ >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ >>> model.to(device)
+
+ >>> # prepare image for the model
+ >>> inputs = processor(images=image, return_tensors="pt").to(device)
+
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs)
+
+ >>> # interpolate to original size
+ >>> post_processed_output = processor.post_process_depth_estimation(
+ ... outputs, target_sizes=[(image.height, image.width)],
+ ... )
+
+ >>> # get the field of view (fov) predictions
+ >>> field_of_view = post_processed_output[0]["field_of_view"]
+ >>> focal_length = post_processed_output[0]["focal_length"]
+
+ >>> # visualize the prediction
+ >>> predicted_depth = post_processed_output[0]["predicted_depth"]
+ >>> depth = predicted_depth * 255 / predicted_depth.max()
+ >>> depth = depth.detach().cpu().numpy()
+ >>> depth = Image.fromarray(depth.astype("uint8"))
+ ```"""
+ loss = None
+ if labels is not None:
+ raise NotImplementedError("Training is not implemented yet")
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+ depth_pro_outputs = self.depth_pro(
+ pixel_values=pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ )
+ features = depth_pro_outputs.features
+ fused_hidden_states = self.fusion_stage(features)
+ predicted_depth = self.head(fused_hidden_states[-1])
+
+ if self.use_fov_model:
+ # frozen features from encoder are used
+ features_for_fov = features[0].detach()
+ fov = self.fov_model(
+ pixel_values=pixel_values,
+ global_features=features_for_fov,
+ head_mask=head_mask,
+ )
+ else:
+ fov = None
+
+ if not return_dict:
+ outputs = [loss, predicted_depth, fov, depth_pro_outputs.hidden_states, depth_pro_outputs.attentions]
+ return tuple(v for v in outputs if v is not None)
+
+ return DepthProDepthEstimatorOutput(
+ loss=loss,
+ predicted_depth=predicted_depth,
+ field_of_view=fov,
+ hidden_states=depth_pro_outputs.hidden_states,
+ attentions=depth_pro_outputs.attentions,
+ )
+
+
+__all__ = ["DepthProPreTrainedModel", "DepthProModel", "DepthProForDepthEstimation"]
diff --git a/src/transformers/models/dinov2/configuration_dinov2.py b/src/transformers/models/dinov2/configuration_dinov2.py
index dfc339f49da7..f4b29273a509 100644
--- a/src/transformers/models/dinov2/configuration_dinov2.py
+++ b/src/transformers/models/dinov2/configuration_dinov2.py
@@ -88,6 +88,8 @@ class Dinov2Config(BackboneConfigMixin, PretrainedConfig):
Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in
case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size,
seq_len, hidden_size)`.
+ use_mask_token (`bool`, *optional*, defaults to `True`):
+ Whether to use mask_token in embeddings.
Example:
@@ -128,6 +130,7 @@ def __init__(
out_indices=None,
apply_layernorm=True,
reshape_hidden_states=True,
+ use_mask_token=True,
**kwargs,
):
super().__init__(**kwargs)
@@ -154,6 +157,7 @@ def __init__(
)
self.apply_layernorm = apply_layernorm
self.reshape_hidden_states = reshape_hidden_states
+ self.use_mask_token = use_mask_token
class Dinov2OnnxConfig(OnnxConfig):
diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py
index 71e0029d22d8..33ec1c054990 100644
--- a/src/transformers/models/dinov2/modeling_dinov2.py
+++ b/src/transformers/models/dinov2/modeling_dinov2.py
@@ -67,12 +67,14 @@ def __init__(self, config: Dinov2Config) -> None:
super().__init__()
self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
- self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
+ if config.use_mask_token:
+ self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
self.patch_embeddings = Dinov2PatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.patch_size = config.patch_size
+ self.use_mask_token = config.use_mask_token
self.config = config
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
@@ -120,7 +122,7 @@ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Te
target_dtype = self.patch_embeddings.projection.weight.dtype
embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
- if bool_masked_pos is not None:
+ if bool_masked_pos is not None and self.use_mask_token:
embeddings = torch.where(
bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings
)
diff --git a/src/transformers/models/dinov2/modeling_flax_dinov2.py b/src/transformers/models/dinov2/modeling_flax_dinov2.py
index 82d1bf95fa40..cf2a6e04c4ea 100644
--- a/src/transformers/models/dinov2/modeling_flax_dinov2.py
+++ b/src/transformers/models/dinov2/modeling_flax_dinov2.py
@@ -136,11 +136,12 @@ def setup(self):
jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
(1, 1, self.config.hidden_size),
)
- self.mask_token = self.param(
- "mask_token",
- jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
- (1, self.config.hidden_size),
- )
+ if self.config.use_mask_token:
+ self.mask_token = self.param(
+ "mask_token",
+ jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
+ (1, self.config.hidden_size),
+ )
self.patch_embeddings = FlaxDinov2PatchEmbeddings(self.config, dtype=self.dtype)
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = self.param(
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index 349ad988df40..8f4f8dfb9767 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -3551,6 +3551,27 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+class DepthProForDepthEstimation(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class DepthProModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class DepthProPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class DetrForObjectDetection(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/transformers/utils/dummy_torchvision_objects.py b/src/transformers/utils/dummy_torchvision_objects.py
index f1b75efc2071..87b60fbc0463 100644
--- a/src/transformers/utils/dummy_torchvision_objects.py
+++ b/src/transformers/utils/dummy_torchvision_objects.py
@@ -44,6 +44,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torchvision"])
+class DepthProImageProcessorFast(metaclass=DummyObject):
+ _backends = ["torchvision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torchvision"])
+
+
class DetrImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]
diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py
index f9802ba42bbf..aeccf53742ae 100644
--- a/src/transformers/utils/dummy_vision_objects.py
+++ b/src/transformers/utils/dummy_vision_objects.py
@@ -184,6 +184,20 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
+class DepthProImageProcessor(metaclass=DummyObject):
+ _backends = ["vision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
+
+
+class DepthProImageProcessorFast(metaclass=DummyObject):
+ _backends = ["vision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
+
+
class DetrFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]
diff --git a/tests/models/depth_pro/__init__.py b/tests/models/depth_pro/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/depth_pro/test_image_processing_depth_pro.py b/tests/models/depth_pro/test_image_processing_depth_pro.py
new file mode 100644
index 000000000000..b30931a86cdb
--- /dev/null
+++ b/tests/models/depth_pro/test_image_processing_depth_pro.py
@@ -0,0 +1,124 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+from transformers.testing_utils import is_flaky, require_torch, require_vision
+from transformers.utils import is_torchvision_available, is_vision_available
+
+from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
+
+
+if is_vision_available():
+ from transformers import DepthProImageProcessor
+
+ if is_torchvision_available():
+ from transformers import DepthProImageProcessorFast
+
+
+class DepthProImageProcessingTester(unittest.TestCase):
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ num_channels=3,
+ image_size=18,
+ min_resolution=30,
+ max_resolution=400,
+ do_resize=True,
+ size=None,
+ do_rescale=True,
+ do_normalize=True,
+ image_mean=[0.5, 0.5, 0.5],
+ image_std=[0.5, 0.5, 0.5],
+ ):
+ super().__init__()
+ size = size if size is not None else {"height": 18, "width": 18}
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.min_resolution = min_resolution
+ self.max_resolution = max_resolution
+ self.do_resize = do_resize
+ self.size = size
+ self.do_rescale = do_rescale
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+
+ def prepare_image_processor_dict(self):
+ return {
+ "image_mean": self.image_mean,
+ "image_std": self.image_std,
+ "do_rescale": self.do_rescale,
+ "do_normalize": self.do_normalize,
+ "do_resize": self.do_resize,
+ "size": self.size,
+ }
+
+ def expected_output_image_shape(self, images):
+ return self.num_channels, self.size["height"], self.size["width"]
+
+ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
+ return prepare_image_inputs(
+ batch_size=self.batch_size,
+ num_channels=self.num_channels,
+ min_resolution=self.min_resolution,
+ max_resolution=self.max_resolution,
+ equal_resolution=equal_resolution,
+ numpify=numpify,
+ torchify=torchify,
+ )
+
+
+@require_torch
+@require_vision
+class DepthProImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
+ image_processing_class = DepthProImageProcessor if is_vision_available() else None
+ fast_image_processing_class = DepthProImageProcessorFast if is_torchvision_available() else None
+
+ def setUp(self):
+ super().setUp()
+ self.image_processor_tester = DepthProImageProcessingTester(self)
+
+ @property
+ def image_processor_dict(self):
+ return self.image_processor_tester.prepare_image_processor_dict()
+
+ def test_image_processor_properties(self):
+ image_processing = self.image_processing_class(**self.image_processor_dict)
+ self.assertTrue(hasattr(image_processing, "image_mean"))
+ self.assertTrue(hasattr(image_processing, "image_std"))
+ self.assertTrue(hasattr(image_processing, "do_normalize"))
+ self.assertTrue(hasattr(image_processing, "do_resize"))
+ self.assertTrue(hasattr(image_processing, "size"))
+ self.assertTrue(hasattr(image_processing, "do_rescale"))
+ self.assertTrue(hasattr(image_processing, "rescale_factor"))
+ self.assertTrue(hasattr(image_processing, "resample"))
+
+ def test_image_processor_from_dict_with_kwargs(self):
+ image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
+ self.assertEqual(image_processor.size, {"height": 18, "width": 18})
+
+ image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
+ self.assertEqual(image_processor.size, {"height": 42, "width": 42})
+
+ @is_flaky(
+ description="fast and slow, both processors use torch implementation, see: https://github.com/huggingface/transformers/issues/34920",
+ )
+ def test_fast_is_faster_than_slow(self):
+ super().test_fast_is_faster_than_slow()
diff --git a/tests/models/depth_pro/test_modeling_depth_pro.py b/tests/models/depth_pro/test_modeling_depth_pro.py
new file mode 100644
index 000000000000..44529270fd94
--- /dev/null
+++ b/tests/models/depth_pro/test_modeling_depth_pro.py
@@ -0,0 +1,398 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch DepthPro model."""
+
+import unittest
+
+from transformers import DepthProConfig
+from transformers.file_utils import is_torch_available, is_vision_available
+from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+ from transformers import DepthProForDepthEstimation, DepthProModel
+ from transformers.models.auto.modeling_auto import MODEL_MAPPING_NAMES
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import DepthProImageProcessor
+
+
+class DepthProModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=8,
+ image_size=64,
+ patch_size=16,
+ num_channels=3,
+ is_training=True,
+ use_labels=True,
+ fusion_hidden_size=16,
+ intermediate_hook_ids=[1, 0],
+ intermediate_feature_dims=[10, 8],
+ scaled_images_ratios=[0.5, 1.0],
+ scaled_images_overlap_ratios=[0.0, 0.2],
+ scaled_images_feature_dims=[12, 12],
+ initializer_range=0.02,
+ use_fov_model=False,
+ image_model_config={
+ "model_type": "dinov2",
+ "num_hidden_layers": 2,
+ "hidden_size": 16,
+ "num_attention_heads": 1,
+ "patch_size": 4,
+ },
+ patch_model_config={
+ "model_type": "vit",
+ "num_hidden_layers": 2,
+ "hidden_size": 24,
+ "num_attention_heads": 2,
+ "patch_size": 6,
+ },
+ fov_model_config={
+ "model_type": "vit",
+ "num_hidden_layers": 2,
+ "hidden_size": 32,
+ "num_attention_heads": 4,
+ "patch_size": 8,
+ },
+ num_labels=3,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.fusion_hidden_size = fusion_hidden_size
+ self.intermediate_hook_ids = intermediate_hook_ids
+ self.intermediate_feature_dims = intermediate_feature_dims
+ self.scaled_images_ratios = scaled_images_ratios
+ self.scaled_images_overlap_ratios = scaled_images_overlap_ratios
+ self.scaled_images_feature_dims = scaled_images_feature_dims
+ self.initializer_range = initializer_range
+ self.use_fov_model = use_fov_model
+ self.image_model_config = image_model_config
+ self.patch_model_config = patch_model_config
+ self.fov_model_config = fov_model_config
+ self.num_labels = num_labels
+
+ self.hidden_size = image_model_config["hidden_size"]
+ self.num_hidden_layers = image_model_config["num_hidden_layers"]
+ self.num_attention_heads = image_model_config["num_attention_heads"]
+
+ # may be different for a backbone other than dinov2
+ self.out_size = patch_size // image_model_config["patch_size"]
+ self.seq_length = self.out_size**2 + 1 # we add 1 for the [CLS] token
+
+ n_fusion_blocks = len(intermediate_hook_ids) + len(scaled_images_ratios)
+ self.expected_depth_size = 2 ** (n_fusion_blocks + 1) * self.out_size
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels)
+
+ config = self.get_config()
+
+ return config, pixel_values, labels
+
+ def get_config(self):
+ return DepthProConfig(
+ patch_size=self.patch_size,
+ fusion_hidden_size=self.fusion_hidden_size,
+ intermediate_hook_ids=self.intermediate_hook_ids,
+ intermediate_feature_dims=self.intermediate_feature_dims,
+ scaled_images_ratios=self.scaled_images_ratios,
+ scaled_images_overlap_ratios=self.scaled_images_overlap_ratios,
+ scaled_images_feature_dims=self.scaled_images_feature_dims,
+ initializer_range=self.initializer_range,
+ image_model_config=self.image_model_config,
+ patch_model_config=self.patch_model_config,
+ fov_model_config=self.fov_model_config,
+ use_fov_model=self.use_fov_model,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels):
+ model = DepthProModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+ def create_and_check_for_depth_estimation(self, config, pixel_values, labels):
+ config.num_labels = self.num_labels
+ model = DepthProForDepthEstimation(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+ self.parent.assertEqual(
+ result.predicted_depth.shape, (self.batch_size, self.expected_depth_size, self.expected_depth_size)
+ )
+
+ def create_and_check_for_fov(self, config, pixel_values, labels):
+ model = DepthProForDepthEstimation(config, use_fov_model=True)
+ model.to(torch_device)
+ model.eval()
+
+ # check if the fov_model (DinoV2-based encoder) is created
+ self.parent.assertIsNotNone(model.fov_model)
+
+ batched_pixel_values = pixel_values
+ row_pixel_values = pixel_values[:1]
+
+ with torch.no_grad():
+ model_batched_output_fov = model(batched_pixel_values).field_of_view
+ model_row_output_fov = model(row_pixel_values).field_of_view
+
+ # check if fov is returned
+ self.parent.assertIsNotNone(model_batched_output_fov)
+ self.parent.assertIsNotNone(model_row_output_fov)
+
+ # check output shape consistency for fov
+ self.parent.assertEqual(model_batched_output_fov.shape, (self.batch_size,))
+
+ # check equivalence between batched and single row outputs for fov
+ diff = torch.max(torch.abs(model_row_output_fov - model_batched_output_fov[:1]))
+ model_name = model.__class__.__name__
+ self.parent.assertTrue(
+ diff <= 1e-03,
+ msg=(f"Batched and Single row outputs are not equal in {model_name} for fov. " f"Difference={diff}."),
+ )
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, labels = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_torch
+class DepthProModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_common.py, as DepthPro does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ all_model_classes = (DepthProModel, DepthProForDepthEstimation) if is_torch_available() else ()
+ pipeline_model_mapping = (
+ {
+ "depth-estimation": DepthProForDepthEstimation,
+ "image-feature-extraction": DepthProModel,
+ }
+ if is_torch_available()
+ else {}
+ )
+
+ test_pruning = False
+ test_resize_embeddings = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = DepthProModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=DepthProConfig, has_text_modality=False, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ @unittest.skip(reason="DepthPro does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ def test_model_get_set_embeddings(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
+ x = model.get_output_embeddings()
+ self.assertTrue(x is None or isinstance(x, nn.Linear))
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_for_depth_estimation(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_depth_estimation(*config_and_inputs)
+
+ def test_for_fov(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_fov(*config_and_inputs)
+
+ def test_training(self):
+ for model_class in self.all_model_classes:
+ if model_class.__name__ == "DepthProForDepthEstimation":
+ continue
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ if model_class.__name__ in MODEL_MAPPING_NAMES.values():
+ continue
+
+ model = model_class(config)
+ model.to(torch_device)
+ model.train()
+ inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ loss = model(**inputs).loss
+ loss.backward()
+
+ def test_training_gradient_checkpointing(self):
+ for model_class in self.all_model_classes:
+ if model_class.__name__ == "DepthProForDepthEstimation":
+ continue
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.use_cache = False
+ config.return_dict = True
+
+ if model_class.__name__ in MODEL_MAPPING_NAMES.values() or not model_class.supports_gradient_checkpointing:
+ continue
+ model = model_class(config)
+ model.to(torch_device)
+ model.gradient_checkpointing_enable()
+ model.train()
+ inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ loss = model(**inputs).loss
+ loss.backward()
+
+ @unittest.skip(
+ reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+ )
+ def test_training_gradient_checkpointing_use_reentrant(self):
+ pass
+
+ @unittest.skip(
+ reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+ )
+ def test_training_gradient_checkpointing_use_reentrant_false(self):
+ pass
+
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ non_uniform_init_parms = [
+ # these encoders are vision transformers
+ # any layer outside these encoders is either Conv2d or ConvTranspose2d
+ # which use kaiming initialization
+ "patch_encoder",
+ "image_encoder",
+ "fov_model.encoder",
+ ]
+ if param.requires_grad:
+ if any(x in name for x in non_uniform_init_parms):
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+ else:
+ self.assertTrue(
+ -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ # this started when switched from normal initialization to kaiming_normal intialization
+ # maybe because the magnitude of offset values from ViT-encoders increases when followed by many convolution layers
+ def test_batching_equivalence(self, atol=1e-4, rtol=1e-4):
+ super().test_batching_equivalence(atol=atol, rtol=rtol)
+
+ @slow
+ def test_model_from_pretrained(self):
+ model_path = "apple/DepthPro-hf"
+ model = DepthProModel.from_pretrained(model_path)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+@require_torch
+@require_vision
+@slow
+class DepthProModelIntegrationTest(unittest.TestCase):
+ def test_inference_depth_estimation(self):
+ model_path = "apple/DepthPro-hf"
+ image_processor = DepthProImageProcessor.from_pretrained(model_path)
+ model = DepthProForDepthEstimation.from_pretrained(model_path).to(torch_device)
+ config = model.config
+
+ image = prepare_img()
+ inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ # verify the predicted depth
+ n_fusion_blocks = len(config.intermediate_hook_ids) + len(config.scaled_images_ratios)
+ out_size = config.image_model_config.image_size // config.image_model_config.patch_size
+ expected_depth_size = 2 ** (n_fusion_blocks + 1) * out_size
+
+ expected_shape = torch.Size((1, expected_depth_size, expected_depth_size))
+ self.assertEqual(outputs.predicted_depth.shape, expected_shape)
+
+ expected_slice = torch.tensor(
+ [[1.0582, 1.1225, 1.1335], [1.1154, 1.1398, 1.1486], [1.1434, 1.1500, 1.1643]]
+ ).to(torch_device)
+ torch.testing.assert_close(outputs.predicted_depth[0, :3, :3], expected_slice, atol=1e-4, rtol=1e-4)
+
+ # verify the predicted fov
+ expected_shape = torch.Size((1,))
+ self.assertEqual(outputs.field_of_view.shape, expected_shape)
+
+ expected_slice = torch.tensor([47.2459]).to(torch_device)
+ torch.testing.assert_close(outputs.field_of_view, expected_slice, atol=1e-4, rtol=1e-4)
+
+ def test_post_processing_depth_estimation(self):
+ model_path = "apple/DepthPro-hf"
+ image_processor = DepthProImageProcessor.from_pretrained(model_path)
+ model = DepthProForDepthEstimation.from_pretrained(model_path)
+
+ image = prepare_img()
+ inputs = image_processor(images=image, return_tensors="pt")
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ outputs = image_processor.post_process_depth_estimation(
+ outputs,
+ target_sizes=[[image.height, image.width]],
+ )
+ predicted_depth = outputs[0]["predicted_depth"]
+ expected_shape = torch.Size((image.height, image.width))
+ self.assertTrue(predicted_depth.shape == expected_shape)