How to determine whether a Linen Module is still initializing? #652

@marcvanzee

We currently check whether a module is initialized using self.has_variable, for instance in our definition of MultiHeadDotProductAttention:

...
# detect if we're initializing by absence of existing cache data.
is_initialized = self.has_variable('cache', 'cached_key')
...
if is_initialized:
  ...

I think this looks a bit strange to the average user. In the old API we were using self.is_initialized, and I think we should consider bringing this back in some form or another.
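
To make the `has_variable` pattern concrete, here is a minimal pure-Python sketch of detecting initialization by the absence of a variable. `ToyScope` and `cached_step` are hypothetical names invented for illustration; they are not Flax's API and only model what the check does, under the assumption that a variable is created lazily on first use.

```python
class ToyScope:
    """Hypothetical stand-in for a module's variable store (not Flax's API)."""

    def __init__(self):
        # Maps collection name -> {variable name -> value}.
        self.collections = {}

    def has_variable(self, col, name):
        return name in self.collections.get(col, {})

    def variable(self, col, name, init_fn):
        # Create the variable on first use, otherwise return the stored value.
        store = self.collections.setdefault(col, {})
        if name not in store:
            store[name] = init_fn()
        return store[name]


def cached_step(scope, key):
    # Check for the cache *before* creating it, mirroring the attention code.
    is_initialized = scope.has_variable('cache', 'cached_key')
    cached = scope.variable('cache', 'cached_key', list)
    if is_initialized:
        # Only write to the cache once it already exists.
        cached.append(key)
    return is_initialized


scope = ToyScope()
first = cached_step(scope, 'k0')   # cache absent: treated as initializing
second = cached_step(scope, 'k1')  # cache present: treated as initialized
```

The check is tied to one specific variable in one collection, which is what makes it general but also what makes it read oddly as a test for "is this module initialized".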

Below are some remarks that different users made offline:

From @levskaya : "The reason we did it this way is that it's strictly more general if you imagine having separate variable collections that might have a more complicated initialization than just a single-shot "init" function. That said, everyone seems to hate it, so I'm not sure the generality is worth not having an is_initializing module-global state variable..."

From @jheek : "You need to be careful with the definition of initializing. The simplest one is: we called Module.init and not Module.apply. But there are other cases. If I call the same Module in init for the second time, it's already fully initialized, so should self.is_initializing return True or False?" Something like this:

def __call__(self, x):
  fc = Dense(5)
  x = fc(x) # this will init kernel, bias...
  x = fc(x) # already initialized. 
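
The ambiguity can be sketched in plain Python. `ToyDense` and the `initializing` argument are hypothetical and only model the two candidate definitions (a global "we are inside init" flag versus a per-variable existence check); they do not reflect how Flax implements either, and the layer's actual output computation is omitted.

```python
class ToyDense:
    """Hypothetical layer that lazily creates its parameters (not Flax's Dense)."""

    def __init__(self, features):
        self.features = features
        self.params = {}

    def has_variable(self, name):
        return name in self.params

    def __call__(self, x, initializing):
        # Record both candidate answers to "are we initializing?".
        by_flag = initializing                         # global: we are inside init
        by_absence = not self.has_variable('kernel')   # local: params not created yet
        if by_absence:
            # Placeholder parameters; real initializers are out of scope here.
            self.params['kernel'] = [[0.0] * self.features for _ in x]
            self.params['bias'] = [0.0] * self.features
        return by_flag, by_absence


def init(x):
    fc = ToyDense(5)
    first = fc(x, initializing=True)   # creates kernel and bias
    second = fc(x, initializing=True)  # parameters already exist
    return first, second


first, second = init([1.0, 2.0, 3.0])
```

On the second call the two definitions disagree: the global flag still reports initializing, while the variable-based check reports that the layer is already initialized.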

To summarize: a boolean on Module that indicates whether the module is initialized seems useful, but there are some tricky cases we should think through first (see the example above).

Labels

Priority: P1 - soon. Response within 5 business days. Resolution within 30 days. (Assignee required)