
vijayabhaskar-ev/dreamer_v4


Dreamer V4

A from-scratch PyTorch implementation of DreamerV4 (Hafner, Yan & Lillicrap, DeepMind, 2025; arXiv:2509.24527) — a model-based reinforcement learning agent that learns by imagining trajectories inside a learned world model. This repo covers all three training phases of the paper:

  1. Tokenizer — a masked autoencoder that compresses 64×64 RGB frames into a small set of latent tokens.
  2. Dynamics model — a block-causal transformer trained with a flow-matching objective (with bootstrap loss and curriculum) to predict future latents.
  3. Imagination + RL — agent tokens, value/policy/reward/continue heads, λ-returns, and PMPO (preference-based MPO) policy optimization rolled out inside the frozen world model.

The implementation targets CUDA. See "What's noteworthy in this implementation" below for the design decisions that shape the codebase.


Status

| Phase | Component | Status |
|---|---|---|
| 1 | Tokenizer model + losses + masking | Complete |
| 1 | Tokenizer trainer + entrypoint + eval | Complete |
| 2 | Dynamics transformer + flow matching + embeddings | Complete |
| 2 | Bootstrap loss + warmup→ramp→full curriculum | Complete |
| 2 | Agent tokens + MTP heads (reward / continue / policy) | Complete |
| 2 | Trainer + entrypoint + autoregressive eval | Complete |
| 3 | imagine_rollout (Euler denoise + sliding context) | Complete |
| 3 | λ-returns, advantages, value loss, PMPO policy loss | Complete |
| 3 | End-to-end imagination trainer + entrypoint | In progress |
| n/a | CUDA distributed training | Not implemented |

Architecture at a glance

   raw frames (B,T,3,64,64)
            │
            ▼
   ┌──────────────────────┐
   │  Tokenizer  (MAE)    │   spatial 8×8 patches → 32 latent tokens / frame
   │  encoder ─► z_clean  │   bottleneck: latent_dim = 128, tanh
   │  decoder ─► recon    │   loss = masked MSE + 0.2 · masked LPIPS
   └──────────┬───────────┘
              │ frozen after Phase 1
              ▼
   ┌──────────────────────┐
   │  Dynamics            │   12-layer transformer, embed_dim 512, GQA
   │  (flow matching)     │   spatial attn every layer, temporal every 4
   │  z_noised, τ, d, a   │   block-causal across time, sliding window C=16
   │  ─► ẑ                │   loss = flow MSE + bootstrap (curriculum-mixed)
   │  + agent tokens ─► h │   asymmetric mask: agent sees all, world ignores agent
   └──────────┬───────────┘
              │ frozen after Phase 2
              ▼
   ┌──────────────────────┐
   │  Heads (on h)        │   reward: 255 symexp twohot bins
   │                      │   continue: Bernoulli
   │                      │   value: 255 symexp twohot bins  (Phase 3)
   │                      │   policy: diagonal Gaussian       (Phase 3)
   └──────────┬───────────┘
              │
              ▼
   ┌──────────────────────┐
   │  Imagination loop    │   H = 15 steps, K = 4 Euler denoise sub-steps
   │  (Phase 3)           │   λ-returns (γ=0.997, λ=0.95)
   │                      │   PMPO policy loss (α=0.5, β=0.3, reverse KL)
   └──────────────────────┘
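The reward and value heads in the diagram predict a categorical distribution over 255 bins. The plain-Python sketch below shows one common variant of the DreamerV3-style symlog/symexp twohot encoding. The bin range [-20, 20] in symlog space, the interpolation in symlog space and every function name are assumptions for illustration, not the code in heads.py.

```python
import math

# Hypothetical constants: 255 bins evenly spaced in symlog space (DreamerV3-style range)
NUM_BINS = 255
LOW, HIGH = -20.0, 20.0
STEP = (HIGH - LOW) / (NUM_BINS - 1)
BINS = [LOW + i * STEP for i in range(NUM_BINS)]


def symlog(x: float) -> float:
    # sign(x) * ln(1 + |x|): compresses large magnitudes
    return math.copysign(math.log1p(abs(x)), x)


def symexp(y: float) -> float:
    # Inverse of symlog: sign(y) * (exp(|y|) - 1)
    return math.copysign(math.expm1(abs(y)), y)


def twohot(x: float) -> list:
    """Soft target over two neighbouring bins whose expected bin equals symlog(x)."""
    y = min(max(symlog(x), LOW), HIGH)
    lo = min(int((y - LOW) // STEP), NUM_BINS - 2)      # left neighbour bin
    w_hi = min(max((y - BINS[lo]) / STEP, 0.0), 1.0)    # weight on the right neighbour
    target = [0.0] * NUM_BINS
    target[lo] = 1.0 - w_hi
    target[lo + 1] = w_hi
    return target


def decode(probs: list) -> float:
    # Expected bin value in symlog space, mapped back to the raw scale
    return symexp(sum(p * b for p, b in zip(probs, BINS)))


def cross_entropy(logits: list, target: list) -> float:
    # Categorical cross-entropy of softmax(logits) against the soft twohot target
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return -sum(t * (l - log_z) for t, l in zip(target, logits))
```

Encoding a target and decoding it round-trips to the original value, for example `decode(twohot(1000.0))` is approximately `1000.0`. The symlog compression lets the same 255 bins cover both small and very large rewards or returns.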

Repository layout

dreamer_v4/
├── tokenizer/                 Phase 1 — masked autoencoder
│   ├── tokenizer.py           encoder + decoder + latent bottleneck
│   ├── layers.py              RoPE, GQA, QK-norm, soft-capped flex_attention
│   ├── masking.py             tube masking (spatial mask shared across T)
│   ├── losses.py              MSE + optional LPIPS
│   ├── trainer.py             training loop, metrics, checkpointing
│   ├── train_tokenizer.py     CLI entrypoint
│   └── config.py
├── dynamics/                  Phase 2 — flow-matching world model
│   ├── dynamic_model.py       transformer + register tokens + agent tokens
│   ├── dynamic_block.py       spatial + (periodic) temporal + FF block
│   ├── flow_matching.py       add_noise, sample_tau_and_d (on-device RNG)
│   ├── embedding.py           action / agent / (τ,d) embeddings
│   ├── trainer.py             curriculum, MTP, on-device grad clip
│   ├── train_dynamics.py      CLI entrypoint
│   ├── evaluate_dynamics.py   K-step denoise + autoregressive rollout + GIFs
│   └── config.py
├── imagination/               Phase 3 — RL inside the world model
│   ├── rollout.py             imagine_rollout: H-step Euler + sliding buffer
│   ├── algorithms.py          λ-returns, advantages, value loss, PMPO
│   ├── trainer.py             (WIP) imagination training loop
│   ├── train_imagination.py   (WIP) CLI entrypoint
│   └── config.py
├── heads.py                   reward / continue / policy heads + symlog twohot
├── device_utils.py            CUDA / CPU device abstraction
├── _env_setup.py              must-be-first import: env vars for inductor / wandb / tmp
├── generate_dataset.py        dm_control → .npz episodes
├── mock_data.py               synthetic moving-square videos (for smoke tests)
├── test_imagination_algorithms.py   λ-returns / advantage / loss unit tests
├── test_gradient_isolation.py       proves agent_out detach blocks dynamics grads
└── requirements.txt

Phase 1 — Tokenizer

A spatio-temporal masked autoencoder that compresses video frames into a small set of latent tokens used by the dynamics model.

  • Patch embedding (tokenizer/layers.py): 8×8 spatial patches projected to embed_dim = 512.
  • Latent tokens (tokenizer/tokenizer.py): 32 learnable tokens per frame cross-attend to patches under a block-causal mask (latent at time t may not see future patches).
  • Encoder / decoder: 8 transformer layers each, GQA with 2 KV heads, RMSNorm pre-norm, RoPE, QK-norm, optional soft-capped attention (paper §3.4 spec, tanh at 30; off by default to match Hansen's verified reference), drop-path.
  • Periodic temporal attention: every 4 layers, to keep compute bounded.
  • Bottleneck: latent_dim = 128, tanh activation.
  • Masking (tokenizer/masking.py): per-sample mask probability ∼ Uniform[0, 0.9] with tube consistency — the same spatial pattern is masked across all frames in a clip, which prevents temporal flickering in reconstructions.
  • Loss (tokenizer/losses.py): masked-only MSE + 0.2 × masked LPIPS, computed on a hybrid reconstruction (predictions at masked positions, targets at unmasked ones). Both terms are RMS-EMA-normalized per paper §3.1 before weighting; LPIPS is disabled automatically if its pretrained weights are unavailable.
  • Encode-only path (tokenizer/tokenizer.py): skips the decoder during dynamics training (~50% less compute).
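Tube masking and the RMS-EMA loss normalization can be sketched in plain Python as follows. This is a minimal sketch, not the code in tokenizer/masking.py or tokenizer/losses.py. It assumes each patch is hidden independently with the sampled probability, and the EMA decay 0.99 and epsilon are illustrative guesses. All names (`tube_mask`, `RMSNormalizer`, `tokenizer_loss`) are hypothetical.

```python
import random


def tube_mask(batch, frames, num_patches, max_prob=0.9, rng=None):
    """mask[b][t][p] is True where patch p of frame t in sample b is hidden.

    Each sample draws its own mask probability ~ Uniform[0, max_prob] and a single
    spatial pattern that is reused for every frame (a "tube" through time).
    Hypothetical helper, not the repo's API.
    """
    rng = rng or random.Random()
    masks = []
    for _ in range(batch):
        ratio = rng.uniform(0.0, max_prob)
        spatial = [rng.random() < ratio for _ in range(num_patches)]
        masks.append([list(spatial) for _ in range(frames)])  # same pattern on every frame
    return masks


class RMSNormalizer:
    """Divide a loss by a running (EMA) estimate of its root-mean-square.

    In a real trainer this statistic would be kept out of the autograd graph.
    Decay and eps values are assumptions.
    """

    def __init__(self, decay=0.99, eps=1e-8):
        self.decay, self.eps = decay, eps
        self.mean_sq = None  # EMA of loss ** 2

    def __call__(self, loss):
        sq = loss * loss
        if self.mean_sq is None:
            self.mean_sq = sq
        else:
            self.mean_sq = self.decay * self.mean_sq + (1.0 - self.decay) * sq
        return loss / (self.mean_sq ** 0.5 + self.eps)


# Hypothetical combination mirroring "MSE + 0.2 x LPIPS", each term normalized first
mse_norm, lpips_norm = RMSNormalizer(), RMSNormalizer()


def tokenizer_loss(mse, lpips):
    return mse_norm(mse) + 0.2 * lpips_norm(lpips)
```

Normalizing before weighting keeps the 0.2 coefficient meaningful even when the raw MSE and LPIPS values differ by orders of magnitude.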

Training:

python -m tokenizer.train_tokenizer \
  --dataset offline --data-path data/episodes.npz \
  --epochs 100 --batch-size 32

Phase 2 — Dynamics (flow matching)

A block-causal latent transformer trained to denoise corrupted latents, given previous latents, actions, and the (τ, d) flow-matching parameters.

  • Model (dynamics/dynamic_model.py): 12 transformer blocks, embed_dim 512, 8 heads, GQA (2 KV heads), 4 register tokens, sliding-window causal mask of length C = 16.
  • Flow matching (dynamics/flow_matching.py):
    • z_noised = (1 - τ) · noise + τ · z_clean
    • The signal level τ and the step size d = 1/2^k are sampled on-device, so the training step never waits on the host.
    • Inference uses a small fixed step count K_inference = 4 for fast rollouts.
  • Bootstrap loss + curriculum (dynamics/trainer.py): warmup (flow-only) → ramp (bootstrap mixed in gradually, 0 → 1) → full. The curriculum mix is a persistent on-device tensor updated in place via .fill_(), so changing its value does not trigger a new graph (recompile) each step under torch.compile.
  • Sequence-length alternation: 85% short batches (T₁) and 15% long (T₂), implemented with a single fixed-shape graph plus a per-frame loss mask, so the trainer compiles one graph variant instead of two.
  • Agent tokens (Phase 2 finetuning): per-frame learnable tokens with an asymmetric spatial mask — agent tokens attend to everything; world-model tokens cannot attend to agent tokens. This is what lets Phase 3 train the policy without contaminating the frozen world model.
  • Multi-token prediction (MTP) heads: reward / continue / policy share a backbone but emit one output layer per temporal offset.
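The corruption and (τ, d) sampling described above can be sketched in plain Python. This is a minimal host-side sketch; the repo samples on-device. The assumptions are that k is drawn uniformly from {0, ..., k_max} and that τ lies on the grid {0, d, 2d, ..., 1 − d} reachable with step size d, as in shortcut-style training. `k_max` and both function names are hypothetical.

```python
import random


def add_noise(z_clean, noise, tau):
    # z_noised = (1 - tau) * noise + tau * z_clean  (tau = 0: pure noise, tau = 1: clean)
    return [(1.0 - tau) * n + tau * z for z, n in zip(z_clean, noise)]


def sample_tau_and_d(k_max, rng):
    """Draw a step size d = 1/2^k and a signal level tau on the grid {0, d, ..., 1 - d}.

    Hypothetical helper; uniform k and the tau grid are assumptions.
    """
    k = rng.randint(0, k_max)             # k in {0, ..., k_max}
    num_steps = 2 ** k
    d = 1.0 / num_steps
    tau = rng.randint(0, num_steps - 1) * d
    return tau, d
```

Keeping τ on multiples of d means that two consecutive steps of size d always land on a level reachable by one step of size 2d. That property is what the bootstrap term exploits.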

Training:

python -m dynamics.train_dynamics \
  --tokenizer-ckpt checkpoints/tokenizer/final.pt \
  --dataset offline --data-path data/episodes.npz \
  --curriculum-warmup-steps 5000 --curriculum-ramp-steps 15000
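The schedule controlled by these two flags, together with the per-frame loss mask behind the fixed-shape batches, can be sketched as follows. The linear ramp shape is an assumption ("gradually mix"), and the defaults simply reuse the flag values from the command above. Exactly how the mix weight enters the total loss is not specified here. `bootstrap_mix`, `frame_loss_mask` and `masked_mean` are hypothetical names.

```python
def bootstrap_mix(step, warmup_steps=5000, ramp_steps=15000):
    """Weight of the bootstrap term: 0 in warmup, then a linear ramp from 0 to 1, then 1."""
    if step < warmup_steps:
        return 0.0
    return min(1.0, (step - warmup_steps) / ramp_steps)


def frame_loss_mask(t_max, t_used):
    # Every batch is built with t_max frames; frames at index >= t_used contribute no loss
    return [1.0 if t < t_used else 0.0 for t in range(t_max)]


def masked_mean(per_frame_loss, mask):
    # Average only over frames that count, so short and long batches share one shape
    total = sum(loss * m for loss, m in zip(per_frame_loss, mask))
    return total / max(sum(mask), 1.0)
```

In a strictly block-causal model, padded later frames cannot influence the outputs of earlier frames. Masking their loss is therefore equivalent to training on the shorter clip, while the compiled graph keeps a single shape.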

Evaluation produces autoregressive rollout GIFs and per-τ-bin reconstruction metrics:

python -m dynamics.evaluate_dynamics \
  --dynamics-ckpt checkpoints/dynamics/final.pt \
  --tokenizer-ckpt checkpoints/tokenizer/final.pt
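The K-step denoise used by the evaluation script and by imagine_rollout can be sketched in plain Python. The sketch assumes the network predicts the clean latent ẑ, as the architecture diagram suggests. It converts that prediction into the velocity of the straight noise-to-data path and takes K equal Euler steps. `predict_clean` is a hypothetical stand-in for the dynamics model; context latents and actions are omitted.

```python
from typing import Callable, List

Latent = List[float]


def euler_denoise(predict_clean: Callable[[Latent, float, float], Latent],
                  noise: Latent, k: int = 4) -> Latent:
    """Integrate from tau = 0 (pure noise) to tau = 1 (clean) in k equal Euler steps.

    predict_clean(z, tau, d) is a hypothetical stand-in for the dynamics network.
    """
    d = 1.0 / k
    z = list(noise)
    tau = 0.0
    for _ in range(k):
        z_hat = predict_clean(z, tau, d)
        # Velocity implied by z_hat on the path z = (1 - tau) * noise + tau * z_clean
        z = [zi + d * (zh - zi) / (1.0 - tau) for zi, zh in zip(z, z_hat)]
        tau += d
    return z
```

With a perfect predictor, each iterate lies exactly on the corruption path at level τ, and the final step returns the predicted clean latent. Errors in ẑ are what make small K a speed/quality trade-off.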

Phase 3 — Imagination + PMPO RL

Algorithms and the rollout primitive are complete and unit-tested; the end-to-end trainer is in progress.

  • imagine_rollout (imagination/rollout.py):

    1. Denoise z_{t+1} with K = 4 Euler steps using a context buffer corrupted to τ_ctx = 0.1.
    2. Run dynamics with use_agent_tokens=True to obtain the agent hidden state h_{t+1}.
    3. Sample a_{t+1} ∼ policy_head(h_{t+1}) (reparameterized).
    4. Predict reward, continue, value from h_{t+1}.
    5. Slide the context buffer; detach before storing so the autograd graph stays compact.
  • compute_lambda_returns (imagination/algorithms.py): TD(λ) returns with γ = 0.997 and λ = 0.95. The returns are detached at the source, so downstream misuse cannot leak gradients through the value targets.

  • compute_advantages = λ_returns − values[:, :H].

  • value_loss: reuses the reward head's symexp twohot encoding (255 bins) and categorical cross-entropy.

  • pmpo_policy_loss — the paper's preference-based MPO objective:

    • Partition the batch by sign(advantage) into D⁺ (good) and D⁻ (bad).
    • Loss = (1 − α)/|D⁻| · Σ_{D⁻} ln π_θ(a|s) − α/|D⁺| · Σ_{D⁺} ln π_θ(a|s) + β · KL(π_θ ‖ π_prior) (reverse KL).
    • α = 0.5, β = 0.3.
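The λ-return recursion and the advantage definition can be sketched for a single imagined trajectory in plain Python. This is a minimal sketch using the DreamerV3-style recursion R_t = r_t + γ c_t ((1 − λ) V_{t+1} + λ R_{t+1}), bootstrapped with R_H = V_H. The index convention is an assumption: `rewards[t]` and `continues[t]` belong to the step from t to t+1, and `values` has H + 1 entries, consistent with `values[:, :H]` above. In PyTorch the result would be detached or computed under `torch.no_grad()`.

```python
def lambda_returns(rewards, continues, values, gamma=0.997, lam=0.95):
    """TD(lambda) returns for one trajectory (hypothetical helper, not the repo's API).

    rewards, continues: length H; values: V(s_0) .. V(s_H), length H + 1.
    """
    horizon = len(rewards)
    assert len(continues) == horizon and len(values) == horizon + 1
    returns = [0.0] * horizon
    next_return = values[horizon]  # bootstrap the tail with the last value
    for t in reversed(range(horizon)):
        discount = gamma * continues[t]  # continue = 0 cuts the return at episode end
        next_return = rewards[t] + discount * ((1 - lam) * values[t + 1] + lam * next_return)
        returns[t] = next_return
    return returns


def advantages(returns, values):
    # A_t = R^lambda_t - V(s_t) over the first H steps
    return [r - v for r, v in zip(returns, values[:len(returns)])]
```

Setting λ = 1 and γ = 1 recovers the plain sum of future rewards plus the final value. Setting λ = 0 gives the one-step TD target r_t + γ c_t V_{t+1}.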

Tests live at the repo root:

python test_imagination_algorithms.py   # λ-returns / advantage / loss algebra
python test_gradient_isolation.py       # confirms detach(agent_out) blocks grads
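The PMPO objective described above can be sketched on a flat batch in plain Python, together with the closed-form KL between diagonal Gaussians that a Gaussian policy head would need. This is a minimal sketch, not the code in imagination/algorithms.py. Two choices are assumptions: dropping samples with zero advantage and averaging the KL term over the batch. Function names are hypothetical.

```python
import math


def gaussian_kl(mu_p, std_p, mu_q, std_q):
    """KL(p || q) between diagonal Gaussians, summed over action dimensions."""
    return sum(
        math.log(sq / sp) + (sp ** 2 + (mp - mq) ** 2) / (2.0 * sq ** 2) - 0.5
        for mp, sp, mq, sq in zip(mu_p, std_p, mu_q, std_q)
    )


def pmpo_policy_loss(log_probs, advantages, kl_to_prior, alpha=0.5, beta=0.3):
    """PMPO loss (hypothetical helper).

    log_probs[i]:   ln pi_theta(a_i | s_i)
    advantages[i]:  only the sign is used to split the batch into D+ and D-
    kl_to_prior[i]: KL(pi_theta(.|s_i) || pi_prior(.|s_i)); must be non-empty
    """
    good = [lp for lp, a in zip(log_probs, advantages) if a > 0]  # D+
    bad = [lp for lp, a in zip(log_probs, advantages) if a < 0]   # D-
    loss = 0.0
    if bad:
        loss += (1 - alpha) * sum(bad) / len(bad)    # push down likelihood of bad actions
    if good:
        loss -= alpha * sum(good) / len(good)        # push up likelihood of good actions
    loss += beta * sum(kl_to_prior) / len(kl_to_prior)  # stay close to the frozen prior
    return loss
```

Because only the sign of each advantage matters, the update is insensitive to the scale of the returns. This is the practical difference from weighting samples by a softmax over advantages.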

Installation

git clone https://github.com/vijayabhaskar-ev/dreamer_v4.git
cd dreamer_v4
pip install -r requirements.txt
# optional, for dm_control envs and LPIPS:
pip install dm_control lpips

Generate a dataset (or skip and use mock_data.py for smoke tests):

python generate_dataset.py --domain cheetah --task run --episodes 2000 --seq-len 50

What's noteworthy in this implementation

  • Asymmetric agent-token mask in dynamics/dynamic_model.py — the cleanest way to add an agent stream to a pretrained world model without leaking agent state into world-model predictions.
  • On-device flow-matching RNG in dynamics/flow_matching.py — keeps the training step on the accelerator with no per-step host sync.
  • Fixed-shape, mask-everything training path in dynamics/trainer.py: the short/long batch split is collapsed into a single graph plus a loss mask, keeping torch.compile warm.
  • PMPO in imagination/algorithms.py — preference-based MPO with reverse KL toward a frozen prior, partitioned by advantage sign rather than via softmax weights.
  • Operational hardening: _env_setup.py redirects /tmp and caps inductor compile workers.
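The masking rules behind the first bullet can be illustrated with a single boolean mask over flattened (frame, token) positions. The rules are: block-causal in time with a sliding window, agent tokens read everything in that window, and world-model tokens never read agent tokens. The repo factorizes attention into spatial and temporal layers, so this flattened mask is only an illustration. Applying the agent restriction across frames as well is an assumption. `build_mask` is a hypothetical name.

```python
def build_mask(num_frames, world_per_frame, agent_per_frame, window=16):
    """allowed[q][k] is True when query token q may attend to key token k.

    Hypothetical helper; tokens are flattened frame by frame, world tokens first.
    """
    tokens = []  # (frame index, is_agent) for each position
    for t in range(num_frames):
        tokens += [(t, False)] * world_per_frame + [(t, True)] * agent_per_frame
    allowed = []
    for tq, q_agent in tokens:
        row = []
        for tk, k_agent in tokens:
            causal = 0 <= tq - tk < window   # current or past frame inside the window
            hidden = k_agent and not q_agent  # world tokens never read agent tokens
            row.append(causal and not hidden)
        allowed.append(row)
    return allowed
```

Because no world-token row has a True entry in an agent column, world-token outputs are identical whether or not agent tokens are present. That is why agent tokens can be added to a pretrained world model without changing its predictions.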

References

  • Hafner, D., Yan, W., Lillicrap, T. Training Agents Inside of Scalable World Models (DreamerV4). arXiv:2509.24527, Sep 29, 2025.
  • Hafner, D., Pasukonis, J., Ba, J., Lillicrap, T. Mastering Diverse Domains through World Models (DreamerV3). arXiv:2301.04104, Jan 2023; published as Mastering Diverse Control Tasks Through World Models, Nature, Apr 2, 2025, DOI:10.1038/s41586-025-08744-2.
  • Lipman, Y. et al. Flow Matching for Generative Modeling. ICLR 2023.
  • Abdolmaleki, A. et al. Maximum a Posteriori Policy Optimisation. ICLR 2018.

License

MIT — see LICENSE.

About

From-scratch PyTorch implementation of DreamerV4 (Hafner et al., 2025): masked-autoencoder tokenizer, block-causal flow-matching dynamics with bootstrap curriculum, agent-token finetuning, and PMPO imagination RL. Hardened for TPU v4 / torch_xla with fixed-shape graphs, on-device RNG, and bounded compile-cache footprint.
