Merged
170 commits
2986dc2
implement config and model building blocks
geetu040 Nov 3, 2024
1728a2f
refactor model architechture
geetu040 Nov 9, 2024
11ce50c
update model outputs
geetu040 Nov 12, 2024
27e9593
update init param to include use_fov_model
geetu040 Nov 16, 2024
e74a7f5
update param name in config
geetu040 Nov 16, 2024
8c2460b
fix hidden_states and attentions outputs for fov
geetu040 Nov 16, 2024
55f6ed3
sort config
geetu040 Nov 16, 2024
b25dffb
complete minor todos
geetu040 Nov 16, 2024
c225deb
update patching
geetu040 Nov 16, 2024
176932d
update config for encoder
geetu040 Nov 16, 2024
dcec522
fix config
geetu040 Nov 16, 2024
0384d2f
use correct defaults in config
geetu040 Nov 16, 2024
85e4f86
update merge for compatibility with different image size
geetu040 Nov 17, 2024
00e4aa3
restructure encoder for custom configuration
geetu040 Nov 21, 2024
6be242c
make fov model compatible with custom config
geetu040 Nov 21, 2024
0189108
replace word "decoder" with "fusion"
geetu040 Nov 21, 2024
7614e1a
weight conversion script
geetu040 Nov 24, 2024
7d323ce
fix fov squeeze
geetu040 Nov 25, 2024
6aaa59e
update conversion script (without test)
geetu040 Nov 25, 2024
263b773
upload ruff image processing
geetu040 Nov 25, 2024
17e5487
create fast image processing
geetu040 Nov 26, 2024
a8dd704
use torch interpolation for image processing
geetu040 Nov 26, 2024
261bbaf
complete post_process_depth_estimation
geetu040 Nov 26, 2024
a4b3556
config: fix imports and sort args
geetu040 Nov 26, 2024
f13c632
apply inference in weight conversion
geetu040 Nov 26, 2024
387ddd8
use mllama script instead for weight conversion
geetu040 Nov 27, 2024
9b67f9d
clean weight conversion script
geetu040 Nov 27, 2024
617c872
add depth-pro status in other files
geetu040 Nov 27, 2024
6e1c512
fill docstring in config
geetu040 Nov 27, 2024
12ee607
formatting
geetu040 Nov 27, 2024
d0a8733
more formatting
geetu040 Nov 27, 2024
e6b385a
formatting with ruff
geetu040 Nov 27, 2024
267e50f
formatting with style
geetu040 Nov 27, 2024
a1ec997
fix copied classes
geetu040 Nov 27, 2024
3c656f2
add examples; update weight convert script
geetu040 Nov 27, 2024
f6f6d3d
fix using check_table.py and isort
geetu040 Nov 29, 2024
b4575d0
fix config docstring
geetu040 Nov 29, 2024
c8d8a9e
add depth pro to sdpa docs
geetu040 Nov 29, 2024
77873de
undo unintentional changes in configuration_gemma.py
geetu040 Nov 29, 2024
5f2378d
minor fixes
geetu040 Nov 30, 2024
d51d0b1
test image processing
geetu040 Nov 30, 2024
082b055
fixes and tests
geetu040 Dec 2, 2024
16a3917
more fixes
geetu040 Dec 2, 2024
2408ec5
use output states from image_encoder instead
geetu040 Dec 3, 2024
be0c2a3
Revert "use output states from image_encoder instead"
geetu040 Dec 4, 2024
efed39f
make embeddings dynamic
geetu040 Dec 4, 2024
c3b14fb
reshape output hidden states and attentions as part of computation graph
geetu040 Dec 4, 2024
7cf2485
fix ruff formating
geetu040 Dec 4, 2024
0aa451d
fix docstring failure
geetu040 Dec 4, 2024
160afbf
use num_fov_head_layers in tests
geetu040 Dec 4, 2024
9d2be26
update doc
geetu040 Dec 4, 2024
e208459
check consistency with config
geetu040 Dec 4, 2024
0415722
ruff formatting
geetu040 Dec 4, 2024
402eedf
merge branch main
geetu040 Dec 4, 2024
f4e7404
update test case
geetu040 Dec 5, 2024
2c1cc10
fix ruff formatting
geetu040 Dec 5, 2024
4d94396
Merge branch 'main' into depth-pro
geetu040 Dec 5, 2024
871b80d
add tests for fov
geetu040 Dec 6, 2024
0ff0655
use interpolation in postprocess
geetu040 Dec 6, 2024
befa6cd
run and fix slow tests locally
geetu040 Dec 6, 2024
db16fe6
Merge branch 'main' into depth-pro
geetu040 Dec 6, 2024
99ac5e8
use scaled_images_features for image and fov encoder
geetu040 Dec 12, 2024
ebb62dd
return fused_hidden_states in fusion stage
geetu040 Dec 12, 2024
46c88e8
fix example
geetu040 Dec 12, 2024
2431358
fix ruff
geetu040 Dec 12, 2024
fd38841
Merge branch 'main' into depth-pro
geetu040 Dec 12, 2024
d9d3a49
fix copyright license for all files
geetu040 Dec 21, 2024
8f4c61f
add __all__ for each file
geetu040 Dec 21, 2024
8960535
minor fixes
geetu040 Dec 21, 2024
1ac1b84
return list in post_process_depth_estimation
geetu040 Dec 21, 2024
27bff69
minor fixes
geetu040 Dec 21, 2024
a69b5af
fix "ruff check"
geetu040 Dec 21, 2024
365a71d
update upsample and projection
geetu040 Dec 21, 2024
c009468
major changes: (image size and merge optimization)
geetu040 Dec 24, 2024
7bed369
Merge branch 'main' into depth-pro
geetu040 Dec 24, 2024
1563f06
fix push_to_hub option in weights conversion
geetu040 Dec 24, 2024
e194ae4
remove image_size in weights conversion
geetu040 Dec 24, 2024
a4889f2
major changes in the architecture
geetu040 Jan 14, 2025
be5087b
Merge branch "main"
geetu040 Jan 14, 2025
9e09a6f
placeholder for unused config attributes
geetu040 Jan 14, 2025
bf159b2
improve docs amid review
geetu040 Jan 14, 2025
fb41687
minor change in docs
geetu040 Jan 14, 2025
7fbb53e
further optimize merge
geetu040 Jan 15, 2025
558836c
fix formatting
geetu040 Jan 15, 2025
ed77f78
remove unused patch/batch convertion functions
geetu040 Jan 24, 2025
5bc4b31
use original F.interpolate
geetu040 Jan 24, 2025
628ff09
improve function naming
geetu040 Jan 24, 2025
e2996b6
minor chages
geetu040 Jan 24, 2025
8cb5c7a
rearchitect upsample block for improved modularity
geetu040 Jan 24, 2025
1ba3a4a
update upsample keys in weight conversion
geetu040 Jan 24, 2025
83706b8
improve padding in merge_patches
geetu040 Jan 25, 2025
004cdc2
use double-loop for merge
geetu040 Jan 25, 2025
922b3de
update comments
geetu040 Jan 25, 2025
0f01b08
create feature_extractor, reduce some forward code
geetu040 Jan 25, 2025
4d871a7
introduce config.use_mask_token in dinov2
geetu040 Jan 26, 2025
85f7e3a
minor fixes
geetu040 Jan 26, 2025
c0127d7
minor fixes for onnx
geetu040 Jan 26, 2025
1898459
update __init__ to latest format
geetu040 Jan 26, 2025
bcf1bf3
remove DepthProConfig.to_dict()
geetu040 Jan 26, 2025
09bffc3
major changes in backbone
geetu040 Jan 26, 2025
0936897
Merge branch 'main' into depth-pro
geetu040 Jan 26, 2025
c26dc99
update config in weight conversion
geetu040 Jan 26, 2025
5fb0bb7
formatting
geetu040 Jan 26, 2025
d741890
converted model is fp32
geetu040 Jan 26, 2025
03f137d
improve naming and docs for feature_extractor->reconstruct_feature_maps
geetu040 Jan 28, 2025
2b8ee8f
minor fixes; amid review
geetu040 Jan 28, 2025
774617a
create intermediate vars in func call
geetu040 Jan 28, 2025
b6d15ff
use torch.testing.assert_close
geetu040 Jan 28, 2025
425d63e
use ModuleList instead of Sequential and ModuleDict
geetu040 Jan 28, 2025
f415ee6
update docs
geetu040 Jan 28, 2025
2777305
Merge branch 'main' into depth-pro
geetu040 Jan 28, 2025
1a2dd3a
include fov in integraiton tests
geetu040 Jan 30, 2025
4cfebae
update docs
geetu040 Jan 30, 2025
9062767
improve initialization of convolution layers
geetu040 Jan 30, 2025
fcba6bd
fix unused fov keys
geetu040 Jan 30, 2025
56cd570
update tests
geetu040 Jan 30, 2025
e64d39a
Merge branch 'main' into depth-pro
geetu040 Jan 30, 2025
26b1391
ruff format
geetu040 Jan 30, 2025
8914549
Merge branch 'main' into depth-pro
geetu040 Jan 31, 2025
01247f8
fix test, amid kaimming initialization
geetu040 Jan 31, 2025
0b7e77f
add depthpro to toctree
geetu040 Jan 31, 2025
20b277d
add residual layer to _no_split_modules
geetu040 Jan 31, 2025
ff0e408
architecture rework
geetu040 Feb 1, 2025
1522c53
Update src/transformers/models/depth_pro/image_processing_depth_pro.py
geetu040 Feb 1, 2025
131817a
Update src/transformers/models/depth_pro/image_processing_depth_pro_f…
geetu040 Feb 1, 2025
72a1f0c
update docs
geetu040 Feb 1, 2025
aed7e3d
improve merge_patches
geetu040 Feb 1, 2025
405bee3
use flatten with fov_output
geetu040 Feb 1, 2025
a8528da
ruff formatting
geetu040 Feb 1, 2025
aed655c
Merge branch 'main' into depth-pro
geetu040 Feb 1, 2025
31383e1
update resources section in docs
geetu040 Feb 3, 2025
641cb84
fix typo "final_kernal_size"
geetu040 Feb 3, 2025
6af8a11
fix output typehint for DepthProDepthEstimator
geetu040 Feb 3, 2025
abd5307
residual operation in 2 steps
geetu040 Feb 3, 2025
8dc2751
use image_size instead of global patch_size in interpolation
geetu040 Feb 3, 2025
2f88694
replace all Sequential with ModuleList
geetu040 Feb 3, 2025
208ee26
update fov
geetu040 Feb 3, 2025
bc63511
update heads
geetu040 Feb 3, 2025
e33a531
fix and update conversion script for heads
geetu040 Feb 3, 2025
8c0e81a
ruff formatting
geetu040 Feb 3, 2025
524dda6
remove float32 conversion
geetu040 Feb 3, 2025
029dd9d
Merge branch 'main' into depth-pro
geetu040 Feb 3, 2025
a87d26a
use "Fov" instead of "FOV" in class names
geetu040 Feb 4, 2025
5fccbff
use "Fov" instead of "FOV" in config docs
geetu040 Feb 4, 2025
24f1413
remove prune_heads
geetu040 Feb 4, 2025
a3dab18
update fusion stage
geetu040 Feb 4, 2025
48eb534
use device in examples
geetu040 Feb 4, 2025
39ea929
Merge branch 'main' into depth-pro
geetu040 Feb 4, 2025
26db9ec
Merge branch 'main' into depth-pro
geetu040 Feb 5, 2025
ba37c91
update processor
geetu040 Feb 5, 2025
949ecb9
ruff fixes
geetu040 Feb 5, 2025
0e2861d
add do_rescale in image_processor_dict
geetu040 Feb 5, 2025
a6efedb
skip test: test_fast_is_faster_than_slow
geetu040 Feb 5, 2025
4d8f927
ruff formatting
geetu040 Feb 5, 2025
dd8de27
DepthProImageProcessorFast in other files
geetu040 Feb 5, 2025
75215ed
Merge branch 'main' into depth-pro
geetu040 Feb 5, 2025
ffb3a82
Merge branch 'main' into depth-pro
geetu040 Feb 5, 2025
5caa0bd
revert antialias removal
geetu040 Feb 5, 2025
3ae1134
add antialias in BaseImageProcessorFast
geetu040 Feb 5, 2025
8372ad9
Revert "revert antialias removal"
geetu040 Feb 5, 2025
666f3b7
Revert "add antialias in BaseImageProcessorFast"
geetu040 Feb 5, 2025
41180e3
update processor for grouping and antialias
geetu040 Feb 5, 2025
1265b12
try test_fast_is_faster_than_slow without "skip" or "flanky"
geetu040 Feb 5, 2025
86c4604
Merge branch 'main' into depth-pro
geetu040 Feb 5, 2025
4dc850f
update checkpoint
geetu040 Feb 5, 2025
b7f32b9
Merge branch 'main' into depth-pro
geetu040 Feb 5, 2025
592648c
update checkpoint
geetu040 Feb 6, 2025
162f141
use @is_flanky for processor test
geetu040 Feb 6, 2025
3a62d63
Merge branch 'main' into depth-pro
geetu040 Feb 6, 2025
4b76239
update checkpoint to "apple/DepthPro-hf"
geetu040 Feb 7, 2025
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -653,6 +653,8 @@
title: Depth Anything
- local: model_doc/depth_anything_v2
title: Depth Anything V2
- local: model_doc/depth_pro
title: DepthPro
- local: model_doc/deta
title: DETA
- local: model_doc/detr
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -123,6 +123,7 @@ Flax), PyTorch, and/or TensorFlow.
| [DeiT](model_doc/deit) | ✅ | ✅ | ❌ |
| [DePlot](model_doc/deplot) | ✅ | ❌ | ❌ |
| [Depth Anything](model_doc/depth_anything) | ✅ | ❌ | ❌ |
| [DepthPro](model_doc/depth_pro) | ✅ | ❌ | ❌ |
| [DETA](model_doc/deta) | ✅ | ❌ | ❌ |
| [DETR](model_doc/detr) | ✅ | ❌ | ❌ |
| [DialoGPT](model_doc/dialogpt) | ✅ | ✅ | ✅ |
183 changes: 183 additions & 0 deletions docs/source/en/model_doc/depth_pro.md
Collaborator: super detailed, amazing work!

@@ -0,0 +1,183 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# DepthPro

## Overview

The DepthPro model was proposed in [Depth Pro: Sharp Monocular Metric Depth in Less Than a Second](https://arxiv.org/abs/2410.02073) by Aleksei Bochkovskii, Amaël Delaunoy, Hugo Germain, Marcel Santos, Yichao Zhou, Stephan R. Richter, Vladlen Koltun.

DepthPro is a foundation model for zero-shot metric monocular depth estimation, designed to generate high-resolution depth maps with remarkable sharpness and fine-grained details. It employs a multi-scale Vision Transformer (ViT)-based architecture, where images are downsampled, divided into patches, and processed using a shared Dinov2 encoder. The extracted patch-level features are merged, upsampled, and refined using a DPT-like fusion stage, enabling precise depth estimation.

The abstract from the paper is the following:

*We present a foundation model for zero-shot metric monocular depth estimation. Our model, Depth Pro, synthesizes high-resolution depth maps with unparalleled sharpness and high-frequency details. The predictions are metric, with absolute scale, without relying on the availability of metadata such as camera intrinsics. And the model is fast, producing a 2.25-megapixel depth map in 0.3 seconds on a standard GPU. These characteristics are enabled by a number of technical contributions, including an efficient multi-scale vision transformer for dense prediction, a training protocol that combines real and synthetic datasets to achieve high metric accuracy alongside fine boundary tracing, dedicated evaluation metrics for boundary accuracy in estimated depth maps, and state-of-the-art focal length estimation from a single image. Extensive experiments analyze specific design choices and demonstrate that Depth Pro outperforms prior work along multiple dimensions.*

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/depth_pro_teaser.png"
alt="drawing" width="600"/>

<small> DepthPro Outputs. Taken from the <a href="https://github.com/apple/ml-depth-pro" target="_blank">official code</a>. </small>

This model was contributed by [geetu040](https://github.com/geetu040). The original code can be found [here](https://github.com/apple/ml-depth-pro).

## Usage Tips

The DepthPro model processes an input image by first downsampling it at multiple scales and splitting each scaled version into patches. These patches are then encoded using a shared Vision Transformer (ViT)-based Dinov2 patch encoder, while the full image is processed by a separate image encoder. The extracted patch features are merged into feature maps, upsampled, and fused using a DPT-like decoder to generate the final depth estimate. If enabled, an additional Field of View (FOV) encoder processes the image to estimate the camera's field of view, which aids depth accuracy.

```py
>>> import requests
>>> from PIL import Image
>>> import torch
>>> from transformers import DepthProImageProcessorFast, DepthProForDepthEstimation

>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> image_processor = DepthProImageProcessorFast.from_pretrained("apple/DepthPro-hf")
>>> model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf").to(device)

>>> inputs = image_processor(images=image, return_tensors="pt").to(device)

>>> with torch.no_grad():
... outputs = model(**inputs)

>>> post_processed_output = image_processor.post_process_depth_estimation(
... outputs, target_sizes=[(image.height, image.width)],
... )

>>> field_of_view = post_processed_output[0]["field_of_view"]
>>> focal_length = post_processed_output[0]["focal_length"]
>>> depth = post_processed_output[0]["predicted_depth"]
>>> depth = (depth - depth.min()) / (depth.max() - depth.min())  # normalize to [0, 1] for visualization
>>> depth = depth * 255.
>>> depth = depth.detach().cpu().numpy()
>>> depth = Image.fromarray(depth.astype("uint8"))
```
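
Beyond the depth map, the post-processed output also exposes the model's field-of-view and focal-length estimates, retrieved above. A short follow-up sketch continuing the session (the output filename is illustrative):

```py
>>> print(f"estimated field of view: {field_of_view.item():.2f}")
>>> print(f"estimated focal length: {focal_length.item():.2f}")
>>> depth.save("depth_visualization.png")  # grayscale depth visualization; path is illustrative
```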

### Architecture and Configuration

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/depth_pro_architecture.png"
alt="drawing" width="600"/>

<small> DepthPro architecture. Taken from the <a href="https://arxiv.org/abs/2410.02073" target="_blank">original paper</a>. </small>

The `DepthProForDepthEstimation` model uses a `DepthProEncoder` to encode the input image and a `FeatureFusionStage` to fuse the output features from the encoder.

The `DepthProEncoder` further uses two encoders:
- `patch_encoder`
  - The input image is scaled at multiple ratios, as specified in the `scaled_images_ratios` configuration.
  - Each scaled image is split into smaller **patches** of size `patch_size` with overlapping areas determined by `scaled_images_overlap_ratios`.
  - These patches are processed by the **`patch_encoder`**.
- `image_encoder`
  - The input image is also rescaled to `patch_size` and processed by the **`image_encoder`**.

Both encoders can be configured via `patch_model_config` and `image_model_config` respectively, each of which is a separate `Dinov2Model` by default.

Outputs from both encoders (`last_hidden_state`) and selected intermediate states (`hidden_states`) from the **`patch_encoder`** are fused by a `DPT`-based `FeatureFusionStage` for depth estimation.
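
For illustration, a minimal sketch of a custom encoder configuration (the ratio values below are illustrative choices, not the pretrained defaults; `patch_size` and the Dinov2 sub-configs are left at their default values):

```py
>>> from transformers import DepthProConfig, DepthProForDepthEstimation

>>> # Illustrative values: encode the image at two scales with different patch overlaps.
>>> config = DepthProConfig(
...     scaled_images_ratios=[0.5, 1.0],
...     scaled_images_overlap_ratios=[0.5, 0.25],
... )
>>> model = DepthProForDepthEstimation(config)
```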

### Field-of-View (FOV) Prediction

The network is supplemented with a focal length estimation head. A small convolutional head ingests frozen features from the depth estimation network and task-specific features from a separate ViT image encoder to predict the horizontal angular field-of-view.

The `use_fov_model` parameter in `DepthProConfig` controls whether **FOV prediction** is enabled. By default, it is set to `False` to conserve memory and computation. When enabled, the **FOV encoder** is instantiated based on the `fov_model_config` parameter, which defaults to a `Dinov2Model`. The `use_fov_model` parameter can also be passed when initializing the `DepthProForDepthEstimation` model.

The pretrained model at checkpoint `apple/DepthPro-hf` uses the FOV encoder. To use the pretrained model without the FOV encoder, set `use_fov_model=False` when loading it, which saves computation.
```py
>>> from transformers import DepthProForDepthEstimation
>>> model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf", use_fov_model=False)
```

To instantiate a new model with the FOV encoder, set `use_fov_model=True` in the config.
```py
>>> from transformers import DepthProConfig, DepthProForDepthEstimation
>>> config = DepthProConfig(use_fov_model=True)
>>> model = DepthProForDepthEstimation(config)
```

Alternatively, set `use_fov_model=True` when initializing the model, which overrides the value in the config.
```py
>>> from transformers import DepthProConfig, DepthProForDepthEstimation
>>> config = DepthProConfig()
>>> model = DepthProForDepthEstimation(config, use_fov_model=True)
```

### Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

```py
import torch
from transformers import DepthProForDepthEstimation

model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf", attn_implementation="sdpa", torch_dtype=torch.float16)
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).

On a local benchmark (A100-40GB, PyTorch 2.3.0, Ubuntu 22.04) with the `google/vit-base-patch16-224` model in `float32`, we saw the following speedups during inference.

| Batch size | Average inference time (ms), eager mode | Average inference time (ms), SDPA | Speedup, SDPA / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 7 | 6 | 1.17 |
| 2 | 8 | 6 | 1.33 |
| 4 | 8 | 6 | 1.33 |
| 8 | 8 | 6 | 1.33 |
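
A minimal sketch of how such a comparison can be timed, assuming a CUDA device; the 1536x1536 input resolution, iteration count, and warmup scheme are illustrative assumptions:

```py
import time

import torch
from transformers import DepthProForDepthEstimation

device = torch.device("cuda")

def average_inference_time(attn_implementation, batch_size=1, n_iters=10):
    model = DepthProForDepthEstimation.from_pretrained(
        "apple/DepthPro-hf",
        attn_implementation=attn_implementation,
        torch_dtype=torch.float16,
    ).to(device).eval()
    # random input at the model's expected resolution (assumed 1536x1536 here)
    pixel_values = torch.randn(batch_size, 3, 1536, 1536, dtype=torch.float16, device=device)
    with torch.no_grad():
        model(pixel_values=pixel_values)  # warmup
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters):
            model(pixel_values=pixel_values)
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters

t_eager = average_inference_time("eager")
t_sdpa = average_inference_time("sdpa")
print(f"eager: {t_eager * 1e3:.1f} ms, sdpa: {t_sdpa * 1e3:.1f} ms, speedup: {t_eager / t_sdpa:.2f}x")
```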

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with DepthPro:

- Research Paper: [Depth Pro: Sharp Monocular Metric Depth in Less Than a Second](https://arxiv.org/pdf/2410.02073)
- Official Implementation: [apple/ml-depth-pro](https://github.com/apple/ml-depth-pro)
- DepthPro Inference Notebook: [DepthPro Inference](https://github.com/qubvel/transformers-notebooks/blob/main/notebooks/DepthPro_inference.ipynb)
- DepthPro for Super Resolution and Image Segmentation
- Read blog on Medium: [Depth Pro: Beyond Depth](https://medium.com/@raoarmaghanshakir040/depth-pro-beyond-depth-9d822fc557ba)
- Code on Github: [geetu040/depthpro-beyond-depth](https://github.com/geetu040/depthpro-beyond-depth)

If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.

## DepthProConfig

[[autodoc]] DepthProConfig

## DepthProImageProcessor

[[autodoc]] DepthProImageProcessor
- preprocess
- post_process_depth_estimation

## DepthProImageProcessorFast

[[autodoc]] DepthProImageProcessorFast
- preprocess
- post_process_depth_estimation

## DepthProModel

[[autodoc]] DepthProModel
- forward

## DepthProForDepthEstimation

[[autodoc]] DepthProForDepthEstimation
- forward
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -244,6 +244,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [data2vec_vision](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecVisionModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
* [DepthPro](https://huggingface.co/docs/transformers/model_doc/depth_pro#transformers.DepthProModel)
* [DiffLlama](https://huggingface.co/docs/transformers/model_doc/diffllama#transformers.DiffLlamaModel)
* [Dinov2](https://huggingface.co/docs/transformers/en/model_doc/dinov2)
* [Dinov2_with_registers](https://huggingface.co/docs/transformers/en/model_doc/dinov2)
18 changes: 18 additions & 0 deletions src/transformers/__init__.py
@@ -400,6 +400,7 @@
"models.deprecated.vit_hybrid": ["ViTHybridConfig"],
"models.deprecated.xlm_prophetnet": ["XLMProphetNetConfig"],
"models.depth_anything": ["DepthAnythingConfig"],
"models.depth_pro": ["DepthProConfig"],
"models.detr": ["DetrConfig"],
"models.dialogpt": [],
"models.diffllama": ["DiffLlamaConfig"],
@@ -1236,6 +1237,7 @@
_import_structure["models.deprecated.efficientformer"].append("EfficientFormerImageProcessor")
_import_structure["models.deprecated.tvlt"].append("TvltImageProcessor")
_import_structure["models.deprecated.vit_hybrid"].extend(["ViTHybridImageProcessor"])
_import_structure["models.depth_pro"].extend(["DepthProImageProcessor", "DepthProImageProcessorFast"])
_import_structure["models.detr"].extend(["DetrFeatureExtractor", "DetrImageProcessor"])
_import_structure["models.donut"].extend(["DonutFeatureExtractor", "DonutImageProcessor"])
_import_structure["models.dpt"].extend(["DPTFeatureExtractor", "DPTImageProcessor"])
@@ -1313,6 +1315,7 @@
_import_structure["models.convnext"].append("ConvNextImageProcessorFast")
_import_structure["models.deformable_detr"].append("DeformableDetrImageProcessorFast")
_import_structure["models.deit"].append("DeiTImageProcessorFast")
_import_structure["models.depth_pro"].append("DepthProImageProcessorFast")
_import_structure["models.detr"].append("DetrImageProcessorFast")
_import_structure["models.llava"].append("LlavaImageProcessorFast")
_import_structure["models.llava_next"].append("LlavaNextImageProcessorFast")
@@ -2180,6 +2183,13 @@
"DepthAnythingPreTrainedModel",
]
)
_import_structure["models.depth_pro"].extend(
[
"DepthProForDepthEstimation",
"DepthProModel",
"DepthProPreTrainedModel",
]
)
_import_structure["models.detr"].extend(
[
"DetrForObjectDetection",
@@ -5494,6 +5504,7 @@
XLMProphetNetConfig,
)
from .models.depth_anything import DepthAnythingConfig
from .models.depth_pro import DepthProConfig
from .models.detr import DetrConfig
from .models.diffllama import DiffLlamaConfig
from .models.dinat import DinatConfig
@@ -6362,6 +6373,7 @@
from .models.deprecated.efficientformer import EfficientFormerImageProcessor
from .models.deprecated.tvlt import TvltImageProcessor
from .models.deprecated.vit_hybrid import ViTHybridImageProcessor
from .models.depth_pro import DepthProImageProcessor, DepthProImageProcessorFast
from .models.detr import DetrFeatureExtractor, DetrImageProcessor
from .models.donut import DonutFeatureExtractor, DonutImageProcessor
from .models.dpt import DPTFeatureExtractor, DPTImageProcessor
@@ -6455,6 +6467,7 @@
from .models.convnext import ConvNextImageProcessorFast
from .models.deformable_detr import DeformableDetrImageProcessorFast
from .models.deit import DeiTImageProcessorFast
from .models.depth_pro import DepthProImageProcessorFast
from .models.detr import DetrImageProcessorFast
from .models.llava import LlavaImageProcessorFast
from .models.llava_next import LlavaNextImageProcessorFast
@@ -7173,6 +7186,11 @@
DepthAnythingForDepthEstimation,
DepthAnythingPreTrainedModel,
)
from .models.depth_pro import (
DepthProForDepthEstimation,
DepthProModel,
DepthProPreTrainedModel,
)
from .models.detr import (
DetrForObjectDetection,
DetrForSegmentation,
3 changes: 2 additions & 1 deletion src/transformers/image_processing_utils_fast.py
@@ -283,6 +283,7 @@ def resize(
image: "torch.Tensor",
size: SizeDict,
interpolation: "F.InterpolationMode" = None,
antialias: bool = True,
**kwargs,
) -> "torch.Tensor":
"""
@@ -324,7 +325,7 @@
"Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
f" {size}."
)
return F.resize(image, new_size, interpolation=interpolation)
return F.resize(image, new_size, interpolation=interpolation, antialias=antialias)

def rescale(
self,
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -74,6 +74,7 @@
deit,
deprecated,
depth_anything,
depth_pro,
detr,
dialogpt,
diffllama,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -91,6 +91,7 @@
("deformable_detr", "DeformableDetrConfig"),
("deit", "DeiTConfig"),
("depth_anything", "DepthAnythingConfig"),
("depth_pro", "DepthProConfig"),
("deta", "DetaConfig"),
("detr", "DetrConfig"),
("diffllama", "DiffLlamaConfig"),
@@ -414,6 +415,7 @@
("deplot", "DePlot"),
("depth_anything", "Depth Anything"),
("depth_anything_v2", "Depth Anything V2"),
("depth_pro", "DepthPro"),
("deta", "DETA"),
("detr", "DETR"),
("dialogpt", "DialoGPT"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/image_processing_auto.py
@@ -74,6 +74,7 @@
("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")),
("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")),
("depth_anything", ("DPTImageProcessor",)),
("depth_pro", ("DepthProImageProcessor", "DepthProImageProcessorFast")),
("deta", ("DetaImageProcessor",)),
("detr", ("DetrImageProcessor", "DetrImageProcessorFast")),
("dinat", ("ViTImageProcessor", "ViTImageProcessorFast")),
3 changes: 3 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -89,6 +89,7 @@
("decision_transformer", "DecisionTransformerModel"),
("deformable_detr", "DeformableDetrModel"),
("deit", "DeiTModel"),
("depth_pro", "DepthProModel"),
("deta", "DetaModel"),
("detr", "DetrModel"),
("diffllama", "DiffLlamaModel"),
@@ -597,6 +598,7 @@
("data2vec-vision", "Data2VecVisionModel"),
("deformable_detr", "DeformableDetrModel"),
("deit", "DeiTModel"),
("depth_pro", "DepthProModel"),
("deta", "DetaModel"),
("detr", "DetrModel"),
("dinat", "DinatModel"),
@@ -916,6 +918,7 @@
[
# Model for depth estimation mapping
("depth_anything", "DepthAnythingForDepthEstimation"),
("depth_pro", "DepthProForDepthEstimation"),
("dpt", "DPTForDepthEstimation"),
("glpn", "GLPNForDepthEstimation"),
("zoedepth", "ZoeDepthForDepthEstimation"),