This repository contains the implementation of MAGR, a novel approach to Continual Action Quality Assessment (CAQA). MAGR leverages a Manifold Projector (MP) and Intra-Inter-Joint Graph Regularization (IIJ-GR) to address feature deviation and regressor confusion across incremental sessions, adapting to real-world complexities while safeguarding user privacy.
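For intuition only, here is a purely conceptual sketch of what a manifold projector does: it maps features stored under an older backbone onto the current feature manifold before they are replayed. All names and the architecture below are illustrative, not the released implementation; see the code and paper for the actual MP design and losses.

```python
import torch
import torch.nn as nn

# Conceptual sketch only: a residual MLP that translates features saved in
# earlier sessions onto the current backbone's feature manifold before replay.
class ManifoldProjector(nn.Module):
    def __init__(self, feat_dim: int = 1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.ReLU(),
            nn.Linear(feat_dim, feat_dim),
        )

    def forward(self, old_feat: torch.Tensor) -> torch.Tensor:
        # Predict the feature deviation and add it back (residual projection).
        return old_feat + self.net(old_feat)
```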
- torch==1.12.0
- torchvision==0.13.0
- torchvideotransforms
- tqdm
- numpy
- scipy
- quadprog
pip install -r requirements.txt
To get started with the experiments, follow the steps below to prepare the datasets:
- Download the MTL-AQA dataset from the MTL-AQA repository.
- Organize the dataset in the following structure (a sanity-check snippet follows the tree):
$DATASET_ROOT
├── MTL-AQA/
│   ├── new
│   │   ├── 01
│   │   ├── ...
│   │   └── 26
│   ├── info
│   │   ├── final_annotations_dict_with_dive_number
│   │   ├── test_split_0.pkl
│   │   └── train_split_0.pkl
│   └── model_rgb.pth
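As a quick sanity check after arranging the files, you can inspect the annotation splits. This is a sketch; DATASET_ROOT is a placeholder for your own path:

```python
import os
import pickle

DATASET_ROOT = "/path/to/datasets"  # placeholder: set to your own root
info_dir = os.path.join(DATASET_ROOT, "MTL-AQA", "info")

# The split files are typically lists of sample keys.
with open(os.path.join(info_dir, "train_split_0.pkl"), "rb") as f:
    train_split = pickle.load(f)
print(len(train_split), "training samples")
```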
- Download the AQA-7 dataset:
mkdir AQA-Seven && cd AQA-Seven
wget http://rtis.oit.unlv.edu/datasets/AQA-7.zip
unzip AQA-7.zip
- Organize the dataset as follows (a verification snippet follows the tree):
$DATASET_ROOT
├── Seven/
│   ├── diving-out
│   │   ├── 001
│   │   │   ├── img_00001.jpg
│   │   │   └── ...
│   │   ├── ...
│   │   └── 370
│   ├── gym_vault-out
│   │   ├── 001
│   │   │   ├── img_00001.jpg
│   │   │   └── ...
│   │   └── ...
│   ├── ...
│   └── Split_4
│       ├── split_4_test_list.mat
│       └── split_4_train_list.mat
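Similarly, you can verify the AQA-7 split files with scipy (already listed in requirements.txt). Paths below are placeholders:

```python
import os
from scipy.io import loadmat

DATASET_ROOT = "/path/to/datasets"  # placeholder: set to your own root
split_dir = os.path.join(DATASET_ROOT, "Seven", "Split_4")

# Load the .mat split lists and inspect their variable names.
train_mat = loadmat(os.path.join(split_dir, "split_4_train_list.mat"))
test_mat = loadmat(os.path.join(split_dir, "split_4_test_list.mat"))
print(sorted(train_mat.keys()), sorted(test_mat.keys()))
```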
Contact the corresponding author of the JDM-MSA paper to obtain access to the dataset. You may need to complete a form before using this dataset for academic research.
Please download the pre-trained I3D model and place it at weights/model_rgb.pth.
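A minimal check (a sketch, not part of the pipeline) that the weights are in place and loadable:

```python
import torch

# Load the checkpoint on CPU just to confirm the file is valid.
state = torch.load("weights/model_rgb.pth", map_location="cpu")
print(f"Loaded checkpoint with {len(state)} entries")
```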
We provide commands for training our model in both Distributed Data Parallel (DDP) and Data Parallel (DP) modes. Here's a breakdown of each command:
- In Distributed Data Parallel (DDP) mode:
torchrun \
--nproc_per_node 2 --master_port 29505 main.py \
--config ./configs/{config file name}.yaml \
--model {model name: joint/adam/fea_gr} \
--dataset {dataset name: class-mtl/class-aqa/class-jdm} \
--batch_size 5 --minibatch_size 3 --n_tasks 5 --n_epochs 50 \
--fewshot True --buffer_size 50 \
--gpus 0 1
- In Data Parallel (DP) mode:
python main.py \
--config ./configs/{config file name}.yaml \
--model {model name: joint/adam/fea_gr} \
--dataset {dataset name: class-mtl/class-aqa/class-jdm} \
--batch_size 5 --minibatch_size 3 --n_tasks 5 --n_epochs 50 \
--fewshot True --buffer_size 50 \
--gpus 0 1
Note: We highly recommend using the DDP mode for training, as we have not tested the effectiveness of the DP mode.
To perform testing, use the same configuration as training and add the --phase test option:
torchrun \
--nproc_per_node 2 --master_port 29503 main.py \
--config ./configs/mtl.yaml \
--model fea_gr --dataset class-mtl \
--batch_size 5 --minibatch_size 3 \
--n_tasks 5 --n_epochs 50 --gpus {gpu id} \
--base_pretrain True --fewshot True \
--buffer_size 50 --phase test \
--exp_name {exp name}
This repository provides an example of training and testing. The logs for this example can be found under outputs/ubuntu-fscl/class-mtl/fewshot_from_scratch. The model is trained from scratch; the training and testing logs are located at outputs/ubuntu-fscl/class-mtl/fewshot_from_scratch/logs/train-20250510-132848.log and outputs/ubuntu-fscl/class-mtl/fewshot_from_scratch/logs/test-20250510-183020.log.
Explanation of Training Log
The last row of "Rho (overall)" in the log is the overall Spearman's Rank Correlation Coefficient (SRCC, denoted as rho_avg in our paper). Since SRCC is sensitive to sample size, we do not average each session's coefficient; instead, all seen testing samples are accumulated to compute a single overall coefficient, which serves as the primary metric. In addition, rho_aft in the log represents the "forgetting" metric, while rho_fwt corresponds to the "fwt" (forward transfer) metric.
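For clarity, here is a small sketch (illustrative names, toy data) of how this pooled SRCC differs from averaging per-session coefficients:

```python
import numpy as np
from scipy import stats

# Toy per-session predictions and ground-truth scores.
preds_per_session = [np.random.rand(30), np.random.rand(30)]
labels_per_session = [np.random.rand(30), np.random.rand(30)]

# Pool all seen test samples, then compute a single correlation.
all_preds = np.concatenate(preds_per_session)
all_labels = np.concatenate(labels_per_session)
rho_overall, _ = stats.spearmanr(all_preds, all_labels)
print(f"Rho (overall): {rho_overall:.4f}")
```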
Note: The evaluation results may slightly differ from the testing results due to the use of the reparameterization trick.
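For readers unfamiliar with the trick, a generic illustration (not the exact code in this repo) of why sampling makes repeated evaluations differ slightly:

```python
import torch

# Generic reparameterization: z = mu + sigma * eps, with eps ~ N(0, I).
# The random eps makes two evaluation passes produce slightly different scores.
mu, log_var = torch.zeros(4), torch.zeros(4)
z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)
```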
Training a robust base session model with --base_pretrain True allows you to reuse its weights for further experiments, significantly reducing training time. To do so, place the base model at weights/{class/domain}-{mtl/fd}{number_of_tasks}, e.g., weights/class-mtl05.pth.
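For example, a tiny helper (hypothetical, merely matching the naming pattern above) to compose and load the base weights:

```python
import torch

def base_weight_path(setting="class", dataset="mtl", n_tasks=5):
    # Hypothetical helper following the README naming pattern,
    # e.g. weights/class-mtl05.pth for 5 tasks on MTL-AQA.
    return f"weights/{setting}-{dataset}{n_tasks:02d}.pth"

state_dict = torch.load(base_weight_path(), map_location="cpu")
```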
This repository is built upon mammoth; many thanks to its authors.
If you have any questions about the code or need further assistance, feel free to reach out.
