
[BUG] Head._task_weights = defaultdict() silently behaves like a plain dict #813

@shaun0927

Description

Bug description

PR #802 (commit ab7207cf) changed self._task_weights = defaultdict(lambda: 1.0) to self._task_weights = defaultdict() in transformers4rec/torch/model/base.py:

self._task_weights = defaultdict()

A defaultdict created with no factory argument has its default_factory set to None, and in that case it behaves exactly like a plain dict: missing keys raise KeyError. The previous defaultdict(lambda: 1.0) returned the documented default of 1.0 for tasks without an explicit weight.
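The mechanism can be checked with the standard library alone (nothing below is Transformers4Rec-specific). A defaultdict only fills in a missing key when default_factory is set; defaultdict() leaves it as None, so the lookup falls through to KeyError. The sketch also shows a side effect of the old version: indexing a missing key inserts it. The key names are illustrative.

```python
from collections import defaultdict

old = defaultdict(lambda: 1.0)
new = defaultdict()

# defaultdict() stores no factory, so default_factory is None.
print(new.default_factory)  # None

# With a factory, a missing key is created on first access and the
# factory's return value (1.0) is stored under that key.
print(old["task_a"])  # 1.0
print(dict(old))      # {'task_a': 1.0}

# Without a factory, the lookup raises KeyError exactly as a plain dict would.
try:
    new["task_a"]
except KeyError as exc:
    print("KeyError:", exc)  # KeyError: 'task_a'

# .get() never consults the factory, so it behaves the same on both
# and does not insert the key.
print(old.get("task_b", 1.0), new.get("task_b", 1.0))  # 1.0 1.0
```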

Only one call site was updated to use .get(name, 1.0). Any code that indexes the mapping directly (e.g. future code, or external subclasses that read head._task_weights[task_name]) will now raise KeyError where it previously returned 1.0.
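A minimal sketch of the two access patterns, assuming a hypothetical HeadStandIn class that mirrors only the _task_weights attribute (it is not the real transformers4rec Head). The helper names weight_via_get and weight_via_index and the task name "next-item" are also hypothetical.

```python
from collections import defaultdict


class HeadStandIn:
    """Hypothetical stand-in for Head; only mirrors the _task_weights attribute."""

    def __init__(self, factory_restored: bool):
        # True mimics the pre-#802 attribute, False mimics the current one.
        self._task_weights = defaultdict(lambda: 1.0) if factory_restored else defaultdict()


def weight_via_get(head, task_name):
    # Pattern used by the one call site that was patched.
    return head._task_weights.get(task_name, 1.0)


def weight_via_index(head, task_name):
    # Pattern used by any direct-indexing caller (e.g. an external subclass).
    return head._task_weights[task_name]


for restored in (True, False):
    head = HeadStandIn(factory_restored=restored)
    print("factory restored:", restored)
    print("  get:", weight_via_get(head, "next-item"))
    try:
        print("  index:", weight_via_index(head, "next-item"))
    except KeyError as exc:
        print("  index: KeyError", exc)
```

With the factory restored, both helpers return 1.0; with defaultdict(), only weight_via_get does, and weight_via_index raises KeyError.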

Steps/Code to reproduce bug

from collections import defaultdict

old = defaultdict(lambda: 1.0)
new = defaultdict()

old["anything"]   # -> 1.0
new["anything"]   # KeyError: 'anything'

Expected behavior

Either restore the lambda: 1.0 factory, so defaultdict conveys its documented meaning, or switch to plain dict() with consistent .get(name, 1.0) access at every call site. Using defaultdict() is misleading because the name promises a default factory it no longer has.
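The second option could look roughly like this: a plain dict plus a single accessor, so that every call site applies the same default instead of repeating .get(name, 1.0). This is a sketch only; HeadStandIn, task_weight and DEFAULT_TASK_WEIGHT are hypothetical names, not existing Transformers4Rec API.

```python
from typing import Dict, Optional

# Hypothetical constant; 1.0 is the default weight documented in this issue.
DEFAULT_TASK_WEIGHT = 1.0


class HeadStandIn:
    """Hypothetical stand-in for Head illustrating the plain-dict alternative."""

    def __init__(self, task_weights: Optional[Dict[str, float]] = None):
        # A plain dict makes it explicit that missing keys are NOT filled in.
        self._task_weights: Dict[str, float] = dict(task_weights or {})

    def task_weight(self, task_name: str) -> float:
        # Single accessor so every call site applies the same default.
        return self._task_weights.get(task_name, DEFAULT_TASK_WEIGHT)


head = HeadStandIn({"click": 2.0})
print(head.task_weight("click"))     # 2.0
print(head.task_weight("purchase"))  # 1.0
print(head._task_weights)            # {'click': 2.0}
```

Unlike the defaultdict version, .get() never inserts keys, so _task_weights only ever holds weights that were set explicitly.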

Environment details

  • Transformers4Rec: main @ 8bf122f5 (regression from PR #802 "Sec pic fix", commit ab7207cf)

Additional context

Preferred fix, which is closest to the original behavior: restore the factory.

self._task_weights = defaultdict(lambda: 1.0)

Happy to send a one-line PR.
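One consideration when restoring the factory, offered as an assumption because this issue does not say whether Head objects are ever pickled: a defaultdict pickles its default_factory, and the standard pickle module cannot serialize a lambda. If picklability matters, functools.partial(float, 1.0) yields the same 1.0 default and does survive a round trip. A sketch comparing the two:

```python
import pickle
from collections import defaultdict
from functools import partial

# Factory from the proposed fix: returns 1.0 but is a lambda.
lambda_weights = defaultdict(lambda: 1.0)

# Equivalent factory built from picklable parts: partial(float, 1.0)() == 1.0.
partial_weights = defaultdict(partial(float, 1.0))

for name, weights in (("lambda", lambda_weights), ("partial", partial_weights)):
    try:
        restored = pickle.loads(pickle.dumps(weights))
        print(name, "round-trips; missing key ->", restored["any_task"])
    except (pickle.PicklingError, AttributeError) as exc:
        # Functions are pickled by reference, and a lambda cannot be looked up by name.
        print(name, "cannot be pickled:", type(exc).__name__)
```

Both factories return 1.0 for a missing task; only the partial-based one survives a pickle round trip.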
