Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
6b5740e
Add CLIP model
Jun 8, 2023
b476263
Add bpe vocabulary dict
Jun 8, 2023
e7a9a7a
Add prompt templates for language models
Jun 8, 2023
3e56ded
Add backbone for CAT-Seg
Jun 8, 2023
eb34f1e
Add CAT-Seg decoder head
Jun 8, 2023
b0d9e45
Add CAT-Seg aggregator (neck)
Jun 8, 2023
40e675d
Support CLIP image encoder finetune
Jun 9, 2023
40cd281
Add CAT-Seg r101 training config
Jun 9, 2023
9849bb4
Fix coco-stuff164k typos
Jun 9, 2023
e2d278a
Fix yapf format
Jun 9, 2023
5089e34
Refactor CAT-Seg configs
Jun 9, 2023
0e1fec8
Fix feature extractor input transform
Jun 9, 2023
38f8258
Refactor aggregator & update config
Jun 9, 2023
7488480
Support slide inference
Jun 10, 2023
04a5270
Add README and model index
Jun 11, 2023
fa31087
Update configs & support vitg and vith
Jun 11, 2023
9c2a9d9
Enhance CLIP weights huggingface downloading
Jun 11, 2023
068a23d
Fix descriptions of classes
Jun 11, 2023
ecbcf5e
Fix docstring converge error
Jun 11, 2023
9d7fc3a
Update optional dependencies
Jun 11, 2023
659af17
Fix open_clip dependency
Jun 11, 2023
8925512
Update ViTH results
Jun 11, 2023
db2614a
Add regex dependency
Jun 11, 2023
0950ed9
Support ade20k and pascal-context-59
Jun 11, 2023
145f935
Update reproduction results
Jun 12, 2023
86d70d7
Fix redundant kwargs
Jun 13, 2023
f85fdac
Enhance open_clip weights loading
Jun 13, 2023
1db3679
Add Unit Tests
Jun 13, 2023
8740a56
Fix cat-seg head unit test
Jun 13, 2023
876906d
Reduce the test batch size
Jun 13, 2023
a113e9f
Fix over memory test error
Jun 13, 2023
bd3fe9a
Add unit test pseudo data
SheffieldCao Jun 13, 2023
3e3f8dd
Fix unit test configs
Jun 13, 2023
717d9ba
Merge branch 'support-cat-seg' of https://github.com/SheffieldCao/mms…
Jun 13, 2023
3e7dd26
Skip unit test for lower version torch
Jun 13, 2023
0d8a0a5
Skip unit test on windows due to limited memory
Jun 13, 2023
32985b2
Skip unit tests on cpu
Jun 14, 2023
f2afafd
Enhance backbone class embedding
Jun 14, 2023
d0aa6de
Skip unit test on cpu
Jun 14, 2023
a359a2c
Add type hints and reference
Jul 17, 2023
4bdade3
Delete relative position embedding
Jul 17, 2023
9e97075
Sync with dev1.x
SheffieldCao Jul 18, 2023
572752f
Fix inference pooling size
SheffieldCao Jul 18, 2023
d3824c1
Resolve conflicts with branch dev-1.x
SheffieldCao Aug 2, 2023
654eff1
Merge branch 'dev-1.x' into support-cat-seg
SheffieldCao Aug 2, 2023
796e4c0
Move to Project support
SheffieldCao Aug 8, 2023
1f1c66e
Rebase the configs
SheffieldCao Aug 8, 2023
e7a75f3
--refactor=support build
xiexinch Aug 9, 2023
9b32cb7
--fix=fix linear attn
xiexinch Aug 9, 2023
854d47f
--other=remove config
xiexinch Aug 9, 2023
4141077
--other=update model link
xiexinch Aug 9, 2023
1ed9947
--other=restore mmseg package
xiexinch Aug 9, 2023
8ca9250
fix
xiexinch Aug 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions projects/CAT-Seg/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# CAT-Seg

> [CAT-Seg: Cost Aggregation for Open-Vocabulary Semantic Segmentation](https://arxiv.org/abs/2303.11797)

## Introduction

<!-- [ALGORITHM] -->

<a href="https://github.com/KU-CVLAB/CAT-Seg">Official Repo</a>

<a href="https://github.com/SheffieldCao/mmsegmentation/blob/support-cat-seg/mmseg/models/necks/cat_aggregator.py">Code Snippet</a>

## Abstract

<!-- [ABSTRACT] -->

Existing works on open-vocabulary semantic segmentation have utilized large-scale vision-language models, such as CLIP, to leverage their exceptional open-vocabulary recognition capabilities. However, the problem of transferring these capabilities learned from image-level supervision to the pixel-level task of segmentation and addressing arbitrary unseen categories at inference makes this task challenging. To address these issues, we aim to attentively relate objects within an image to given categories by leveraging relational information among class categories and visual semantics through aggregation, while also adapting the CLIP representations to the pixel-level task. However, we observe that direct optimization of the CLIP embeddings can harm its open-vocabulary capabilities. In this regard, we propose an alternative approach to optimize the imagetext similarity map, i.e. the cost map, using a novel cost aggregation-based method. Our framework, namely CATSeg, achieves state-of-the-art performance across all benchmarks. We provide extensive ablation studies to validate our choices. [Project page](https://ku-cvlab.github.io/CAT-Seg).

<!-- [IMAGE] -->

<div align=center >
<img alt="CAT-Seg" src="https://github.com/open-mmlab/mmsegmentation/assets/49406546/d54674bb-52ae-4a20-a168-e25d041111e8"/>
CAT-Seg model structure
</div>

## Usage

CAT-Seg model training needs pretrained `CLIP` model. We have implemented `ViT-B` and `ViT-L` based `CLIP` model. To further use `ViT-bigG` or `ViT-H` ones, you need additional dependencies. Please install [open_clip](https://github.com/mlfoundations/open_clip) first. The pretrained `CLIP` model state dicts are loaded from [Huggingface-OpenCLIP](https://huggingface.co/models?library=open_clip). **If you come up with `ConnectionError` when downloading CLIP weights**, you can manually download them from the given repo and use `custom_clip_weights=/path/to/you/folder` of backbone in config file. Related tools are as shown in [requirements/optional.txt](requirements/optional.txt):

```shell
pip install ftfy==6.0.1
pip install huggingface-hub
pip install regex
```

In addition to the necessary [data preparation](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md), you also need class texts for clip text encoder. Please download the class text json file first [cls_texts](https://github.com/open-mmlab/mmsegmentation/files/11714914/cls_texts.zip) and arrange the folder as follows:

```none
mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│ ├── VOCdevkit
│ │ ├── VOC2012
│ │ ├── VOC2010
│ │ ├── VOCaug
│ ├── ade
│ ├── coco_stuff164k
│ ├── coco.json
│ ├── pc59.json
│ ├── pc459.json
│ ├── ade150.json
│ ├── ade847.json
│ ├── voc20b.json
│ ├── voc20.json
```

```shell
# setup PYTHONPATH
export PYTHONPATH=`pwd`:$PYTHONPATH
# run evaluation
mim test mmsegmentation ${CONFIG} --checkpoint ${CHECKPOINT} --launcher pytorch --gpus=8
```

## Results and models

### ADE20K-150-ZeroShot

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | Device | mIoU | mIoU(ms+flip) | config | download |
| ------- | ------------- | --------- | ------- | -------: | -------------- | ------- | ---- | ------------: | ------------------------------------------------------------------------------------------: | --------------------------------------------------------------------------------------------------------------------------------------------- |
| CAT-Seg | R-101 & ViT-B | 384x384 | 80000 | - | - | RTX3090 | 27.2 | - | [config](./configs/cat_seg/catseg_vitb-r101_4xb1-warmcoslr2e-4-adamw-80k_ade20k-384x384.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/cat_seg/catseg_vitb-r101_4xb1-warmcoslr2e-4-adamw-80k_ade20k-384x384-54194d72.pth) |

Note:

- All experiments of CAT-Seg are implemented with 4 RTX3090 GPUs, except the last one with pretrained ViT-bigG CLIP model (GPU Memory insufficient, you may need A100).
- Due to the feature size bottleneck of the CLIP image encoder, the inference and testing can only be done under `slide` mode, the inference time is longer since the test size is much more bigger that training size of `(384, 384)`.
- The ResNet backbones utilized in CAT-Seg models are standard `ResNet` rather than `ResNetV1c`.
- The zero-shot segmentation results on PASCAL VOC and ADE20K are from the original paper. Our results are coming soon. We appreatiate your contribution!
- In additional to zero-shot segmentation performance results, we also provided the evaluation results on the `val2017` set of **COCO-stuff164k** for reference, which is the training dataset of CAT-Seg. The testing was done **without TTA**.
- The number behind the dataset name is the category number for segmentation evaluation (except training data **COCO-stuff 164k**). **PASCAL VOC-20b** defines the "background" as classes present in **PASCAL-Context-59** but not in **PASCAL VOC-20**.

## Citation

```bibtex
@inproceedings{cheng2021mask2former,
title={CAT-Seg: Cost Aggregation for Open-Vocabulary Semantic Segmentation},
author={Seokju Cho and Heeseong Shin and Sunghwan Hong and Seungjun An and Seungjun Lee and Anurag Arnab and Paul Hongsuck Seo and Seungryong Kim},
journal={CVPR},
year={2023}
}
```
2 changes: 2 additions & 0 deletions projects/CAT-Seg/cat_seg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .models import * # noqa: F401,F403
from .utils import * # noqa: F401,F403
10 changes: 10 additions & 0 deletions projects/CAT-Seg/cat_seg/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .cat_aggregator import (AggregatorLayer, CATSegAggregator,
ClassAggregateLayer, SpatialAggregateLayer)
from .cat_head import CATSegHead
from .clip_ovseg import CLIPOVCATSeg

__all__ = [
'AggregatorLayer', 'CATSegAggregator', 'ClassAggregateLayer',
'SpatialAggregateLayer', 'CATSegHead', 'CLIPOVCATSeg'
]
Loading