From f598a2c3cd4de18c27d3a3dfd18b8101c19e6b84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ph=C3=BAc=20L=C3=AA=20Kh=E1=BA=AFc?= Date: Mon, 15 Aug 2022 23:52:10 +0700 Subject: [PATCH] Reduce peak memory usage when freezing parameters. --- big_vision/optax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/big_vision/optax.py b/big_vision/optax.py index db7b4fc..5c6a909 100644 --- a/big_vision/optax.py +++ b/big_vision/optax.py @@ -68,7 +68,7 @@ def create_schedule(mult=1.0, **kw): # Removes weight decay updates. Note that weight decay already has an # independent mask (which cannot be combined easily with a second mask), # so instead we multiply updates for frozen params with zero. - optax.masked(optax.scale(0.0), frozen_mask) + optax.masked(optax.set_to_zero(), frozen_mask) ] # Gradient clipping.