LDC-MTL

Official implementation of LDC-MTL: Balancing Multi-Task Learning through Scalable Loss Discrepancy Control. Paper link

Multi-task learning (MTL) has been widely adopted for its ability to simultaneously learn multiple tasks. While existing gradient manipulation methods often yield more balanced solutions than simple scalarization-based approaches, they typically incur a significant computational overhead of $\mathcal{O}(K)$ in both time and memory, where $K$ is the number of tasks. In this paper, we propose LDC-MTL, a simple and scalable loss discrepancy control approach for MTL, formulated from a bilevel optimization perspective. Our method incorporates three key components: (i) a coarse loss pre-normalization, (ii) a bilevel formulation for fine-grained loss discrepancy control, and (iii) a scalable first-order bilevel algorithm that requires only $\mathcal{O}(1)$ time and memory. Theoretically, we prove that LDC-MTL guarantees convergence not only to a stationary point of the bilevel problem with loss discrepancy control but also to an $\epsilon$-accurate Pareto stationary point for all $K$ loss functions under mild conditions. Extensive experiments on diverse multi-task datasets demonstrate the superior performance of LDC-MTL in both accuracy and efficiency.
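
For orientation, the display below sketches the general shape of a loss-discrepancy-control bilevel problem. The discrepancy measure $D$ and the weighting map $w(W)$ are placeholders for illustration only; the paper defines the exact upper- and lower-level objectives.

$$
\min_{W} \; D\big(\ell_1(x^*(W)), \ldots, \ell_K(x^*(W))\big)
\quad \text{s.t.} \quad
x^*(W) \in \arg\min_{x} \; \sum_{k=1}^{K} w_k(W)\,\ell_k(x),
$$

where $\ell_k$ denotes the (pre-normalized) loss of task $k$.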


Our bilevel loss balancing pipeline for multi-task learning. First, task losses are normalized through an initial loss normalization module. Then, the lower-level problem optimizes the model parameters $x^t$ by minimizing the weighted sum of task losses, and the upper-level problem optimizes the router parameters $W^t$ for task balancing.
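
As a rough illustration of this pipeline (not the repo's actual code), the sketch below runs one lower-level and one upper-level step with a softmax router over task weights; the names `model`, `router_logits`, `task_losses`, and the variance-based discrepancy are assumptions made for the example.

# Hedged sketch of one bilevel loss-balancing step; NOT the repo's exact code.
# The softmax router and the variance-based discrepancy are illustrative choices.
import torch

K = 3                                                  # number of tasks
model = torch.nn.Linear(16, K)                         # stand-in for shared parameters x
router_logits = torch.zeros(K, requires_grad=True)     # router parameters W
opt_x = torch.optim.SGD(model.parameters(), lr=1e-2)
opt_w = torch.optim.SGD([router_logits], lr=1e-3)

def task_losses(batch):
    # Placeholder returning a (K,)-vector of per-task losses.
    return (model(batch) ** 2).mean(dim=0)

batch = torch.randn(32, 16)

# (i) Coarse pre-normalization: record initial losses once and divide by them.
with torch.no_grad():
    init = task_losses(batch).clamp_min(1e-8)

# Lower level: update model parameters x under the current (detached) weights.
weights = torch.softmax(router_logits, dim=0)
opt_x.zero_grad()
(weights.detach() @ (task_losses(batch) / init)).backward()
opt_x.step()

# Upper level: update router W to reduce the discrepancy among normalized
# losses (here the variance of the weighted losses, an assumed measure).
opt_w.zero_grad()
weights = torch.softmax(router_logits, dim=0)
(weights * (task_losses(batch).detach() / init)).var().backward()
opt_w.step()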


The loss trajectories of a toy 2-task learning problem and the runtime comparison of different MTL methods for 50,000 steps. ★ on the Pareto front denotes the convergence points. Although FAMO achieves more balanced results than LS and MGDA, it converges to different points on the Pareto front. Our method reaches the same balanced point with a computational cost comparable to that of simple Linear Scalarization (LS).


Training-time comparison among well-performing approaches, with LS taken as the reference method for standard time.


Results on Cityscapes (2-task) dataset.


Results on CelebA (40-task), QM9 (11-task), and NYU-v2 (3-task) datasets.


Datasets

Performance is evaluated on 4 datasets:

  • Image-level Classification. The CelebA dataset contains 40 tasks.
  • Regression. The QM9 dataset contains 11 tasks and can be downloaded automatically from PyTorch Geometric (see the snippet after this list).
  • Dense Prediction. The NYU-v2 dataset contains 3 tasks and the Cityscapes dataset (UPDATE: the small version) contains 2 tasks.
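
For reference, a minimal way to fetch QM9 via PyTorch Geometric; the root directory below is an arbitrary choice, not a path the repo requires.

# QM9 download via PyTorch Geometric; "data/QM9" is an arbitrary local path.
from torch_geometric.datasets import QM9

dataset = QM9(root="data/QM9")   # downloads and preprocesses on first call
print(len(dataset))              # number of molecular graphs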

Setup Environment

Create the environment:

conda create -n mtl python=3.9.7
conda activate mtl
python -m pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113

Then, install the repo:

https://github.com/OptMN-Lab/LDC-MTL.git
cd LDC-MTL
python -m pip install -e .
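
Optionally, you can verify that the CUDA build of PyTorch is active; this is a generic check, not part of the repo's instructions.

# Optional sanity check that the CUDA 11.3 wheels are in use.
import torch
print(torch.__version__)          # expect 1.12.1+cu113
print(torch.cuda.is_available())  # True if your driver supports CUDA 11.3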
