Official implementation of LDC-MTL: Balancing Multi-Task Learning through Scalable Loss Discrepancy Control. Paper link
Multi-task learning (MTL) has been widely adopted for its ability to learn multiple tasks simultaneously. While existing gradient manipulation methods often yield more balanced solutions than simple scalarization-based approaches, they typically incur a significant computational overhead that grows with the number of tasks, since each task's gradient must be computed and stored separately.
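For intuition, here is a minimal PyTorch sketch (not code from this repository; the model and losses are toy placeholders) contrasting the two regimes: scalarization needs a single backward pass over the weighted loss sum, while gradient manipulation methods perform one backward pass per task to obtain the per-task gradients they later combine.

```python
import torch

# Toy stand-ins (hypothetical): a shared linear "backbone" and K = 4 regression tasks.
model = torch.nn.Linear(16, 4)
x, y = torch.randn(8, 16), torch.randn(8, 4)
pred = model(x)
losses = [((pred[:, k] - y[:, k]) ** 2).mean() for k in range(4)]

# Scalarization (e.g., LS): one backward pass, regardless of the number of tasks.
sum(losses).backward(retain_graph=True)

# Gradient manipulation (e.g., MGDA-style): one backward pass per task to collect
# the K per-task gradients -- the source of the O(K) time/memory overhead.
task_grads = []
for loss in losses:
    grads = torch.autograd.grad(loss, model.parameters(), retain_graph=True)
    task_grads.append(torch.cat([g.flatten() for g in grads]))
```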
Our bilevel loss balancing pipeline for multi-task learning. First, task losses are normalized by an initial loss normalization module. Then, the lower-level problem optimizes the model parameters, while the upper-level problem adjusts the task weights to control the loss discrepancy.
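As a rough illustration only (a hedged sketch of the pipeline's shape, not the actual LDC-MTL update rules; `normalize_losses`, `lower_level_step`, and the weighting scheme are assumptions for exposition):

```python
import torch

def normalize_losses(losses, initial_losses):
    # Assumed form of "initial loss normalization": scale each task loss by its
    # value at the start of training so all tasks begin on a comparable scale.
    return [loss / init.clamp_min(1e-8) for loss, init in zip(losses, initial_losses)]

def lower_level_step(model_opt, normalized_losses, weights):
    # Lower level: update the model parameters on a weighted sum of the
    # normalized task losses.
    total = sum(w * loss for w, loss in zip(weights, normalized_losses))
    model_opt.zero_grad()
    total.backward()
    model_opt.step()
    return total.detach()

# The upper level (omitted here) adjusts `weights` so that the normalized losses
# stay close to one another, i.e., it controls the loss discrepancy.
```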
Loss trajectories on a toy 2-task learning problem and runtime comparison of different MTL methods over 50,000 steps. ★ on the Pareto front denotes the convergence points. Although FAMO achieves more balanced results than LS and MGDA, it converges to different points on the Pareto front. Our method reaches the same balanced point with a computational cost comparable to that of simple Linear Scalarization (LS).
Relative training-time comparison among well-performing approaches, with LS taken as the reference (1×) method.
Results on the Cityscapes (2-task) dataset.
Results on the CelebA (40-task), QM9 (11-task), and NYU-v2 (3-task) datasets.
Performance is evaluated on 4 datasets:
- Image-level Classification. The CelebA dataset contains 40 tasks.
- Regression. The QM9 dataset contains 11 tasks and can be downloaded automatically via PyTorch Geometric (see the snippet after this list).
- Dense Prediction. The NYU-v2 dataset contains 3 tasks and the Cityscapes dataset (UPDATE: the small version) contains 2 tasks.
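For instance, the QM9 download happens automatically the first time the dataset class is instantiated (the `root` path below is just an example location):

```python
from torch_geometric.datasets import QM9

# Downloads and preprocesses QM9 on first use; subsequent runs reuse the cache.
dataset = QM9(root="data/QM9")
print(len(dataset), dataset[0])
```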
Create the environment:
conda create -n mtl python=3.9.7
conda activate mtl
python -m pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
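Optionally, a quick check that the pinned CUDA builds are active (assuming a CUDA 11.3-compatible driver is installed):

```python
import torch
import torchvision

# Expected: 1.12.1+cu113 / 0.13.1+cu113, and True on a machine with a visible GPU.
print(torch.__version__, torchvision.__version__)
print(torch.cuda.is_available())
```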
Then, install the repo:
git clone https://github.com/OptMN-Lab/LDC-MTL.git
cd LDC-MTL
python -m pip install -e .




