diff --git a/src/boltz/main.py b/src/boltz/main.py index 4a3750fec..9e24df1b6 100644 --- a/src/boltz/main.py +++ b/src/boltz/main.py @@ -1039,6 +1039,39 @@ def cli() -> None: is_flag=True, help=" to dump the s and z embeddings into a npz file. Default is False.", ) +@click.option( + "--write_intermediates", + is_flag=True, + help=( + "Whether to dump intermediate trunk representations after MSA and " + "pairformer during inference. Default is False." + ), +) +@click.option( + "--capture_pairformer_layers", + is_flag=True, + help=( + "Whether to also dump per-layer pairformer states when writing " + "intermediates. Can produce very large files." + ), +) +@click.option( + "--capture_pairformer_layer_stride", + type=int, + default=1, + help=( + "When capturing pairformer layers, save every Nth layer " + "(and always the last). Default is 1." + ), +) +@click.option( + "--capture_mlp_layers", + is_flag=True, + help=( + "Whether to also dump per-layer transition MLP intermediates " + "when writing intermediates. Can produce very large files." + ), +) def predict( # noqa: C901, PLR0915, PLR0912 data: str, out_dir: str, @@ -1077,6 +1110,10 @@ def predict( # noqa: C901, PLR0915, PLR0912 num_subsampled_msa: int = 1024, no_kernels: bool = False, write_embeddings: bool = False, + write_intermediates: bool = False, + capture_pairformer_layers: bool = False, + capture_pairformer_layer_stride: int = 1, + capture_mlp_layers: bool = False, ) -> None: """Run predictions with Boltz.""" # If cpu, write a friendly warning @@ -1156,6 +1193,18 @@ def predict( # noqa: C901, PLR0915, PLR0912 msg = f"Method {method} not supported. Supported: {method_names}" raise ValueError(msg) + if capture_pairformer_layers and not write_intermediates: + msg = "--capture_pairformer_layers requires --write_intermediates." + raise ValueError(msg) + if capture_mlp_layers and not write_intermediates: + msg = "--capture_mlp_layers requires --write_intermediates." + raise ValueError(msg) + if ( + write_intermediates or capture_pairformer_layers or capture_mlp_layers + ) and model != "boltz2": + msg = "Intermediate capture is currently supported only for Boltz-2." + raise ValueError(msg) + # Process inputs ccd_path = cache / "ccd.pkl" mol_dir = cache / "mols" @@ -1304,6 +1353,10 @@ def predict( # noqa: C901, PLR0915, PLR0912 "write_confidence_summary": True, "write_full_pae": write_full_pae, "write_full_pde": write_full_pde, + "capture_intermediates": write_intermediates, + "capture_pairformer_layers": capture_pairformer_layers, + "capture_pairformer_layer_stride": capture_pairformer_layer_stride, + "capture_transition_mlp": capture_mlp_layers, } steering_args = BoltzSteeringParams() diff --git a/src/boltz/model/layers/pairformer.py b/src/boltz/model/layers/pairformer.py index 7edadbfe9..f26dcbcc6 100644 --- a/src/boltz/model/layers/pairformer.py +++ b/src/boltz/model/layers/pairformer.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch from torch import Tensor, nn @@ -71,7 +71,8 @@ def forward( use_kernels: bool = False, use_cuequiv_mul: bool = False, use_cuequiv_attn: bool = False, - ) -> tuple[Tensor, Tensor]: + capture_transition_mlp: bool = False, + ) -> Union[tuple[Tensor, Tensor], tuple[Tensor, Tensor, dict[str, Tensor]]]: # Compute pairwise stack dropout = get_dropout_mask(self.dropout, z, self.training) z = z + dropout * self.tri_mul_out( @@ -99,7 +100,11 @@ def forward( use_kernels=use_cuequiv_attn or use_kernels, ) - z = z + self.transition_z(z) + if capture_transition_mlp: + z_delta, z_transition = self.transition_z(z, return_intermediates=True) + z = z + z_delta + else: + z = z + self.transition_z(z) # Compute sequence stack with torch.autocast("cuda", enabled=False): @@ -107,9 +112,24 @@ def forward( s = s.float() + self.attention( s=s_normed, z=z.float(), mask=mask.float(), k_in=s_normed ) - s = s + self.transition_s(s) + if capture_transition_mlp: + s_delta, s_transition = self.transition_s(s, return_intermediates=True) + s = s + s_delta + else: + s = s + self.transition_s(s) s = self.s_post_norm(s) - + if capture_transition_mlp: + transition_state = { + "z_transition_x_norm": z_transition["x_norm"], + "z_transition_fc1": z_transition["fc1"], + "z_transition_fc2": z_transition["fc2"], + "z_transition_hidden": z_transition["hidden"], + "s_transition_x_norm": s_transition["x_norm"], + "s_transition_fc1": s_transition["fc1"], + "s_transition_fc2": s_transition["fc2"], + "s_transition_hidden": s_transition["hidden"], + } + return s, z, transition_state return s, z @@ -160,7 +180,9 @@ def forward( mask: Tensor, pair_mask: Tensor, use_kernels: bool = False, - ) -> tuple[Tensor, Tensor]: + capture_layers: bool = False, + capture_transition_mlp: bool = False, + ) -> Union[tuple[Tensor, Tensor], tuple[Tensor, Tensor, list[dict[str, Tensor]]]]: """Perform the forward pass. Parameters @@ -185,7 +207,19 @@ def forward( else: chunk_size_tri_attn = None + layer_states = [] for layer in self.layers: + transition_state = {} + if ( + capture_transition_mlp + and self.activation_checkpointing + and self.training + ): + msg = ( + "capture_transition_mlp is not supported with activation " + "checkpointing during training." + ) + raise ValueError(msg) if self.activation_checkpointing and self.training: s, z = torch.utils.checkpoint.checkpoint( layer, @@ -197,7 +231,29 @@ def forward( use_kernels, ) else: - s, z = layer(s, z, mask, pair_mask, chunk_size_tri_attn, use_kernels) + layer_out = layer( + s, + z, + mask, + pair_mask, + chunk_size_tri_attn, + use_kernels, + capture_transition_mlp=capture_transition_mlp, + ) + if capture_transition_mlp: + s, z, transition_state = layer_out + else: + s, z = layer_out + if capture_layers or capture_transition_mlp: + state = {} + if capture_layers: + state["s"] = s + state["z"] = z + if capture_transition_mlp: + state.update(transition_state) + layer_states.append(state) + if capture_layers or capture_transition_mlp: + return s, z, layer_states return s, z diff --git a/src/boltz/model/layers/transition.py b/src/boltz/model/layers/transition.py index 8bab80937..d482392d3 100644 --- a/src/boltz/model/layers/transition.py +++ b/src/boltz/model/layers/transition.py @@ -1,5 +1,6 @@ -from typing import Optional +from typing import Optional, Union +import torch from torch import Tensor, nn import boltz.model.layers.initialize as init @@ -44,7 +45,12 @@ def __init__( init.lecun_normal_init_(self.fc2.weight) init.final_init_(self.fc3.weight) - def forward(self, x: Tensor, chunk_size: int = None) -> Tensor: + def forward( + self, + x: Tensor, + chunk_size: Optional[int] = None, + return_intermediates: bool = False, + ) -> Union[Tensor, tuple[Tensor, dict[str, Tensor]]]: """Perform a forward pass. Parameters @@ -59,20 +65,46 @@ def forward(self, x: Tensor, chunk_size: int = None) -> Tensor: """ x = self.norm(x) + x_norm = x if chunk_size is None or self.training: - x = self.silu(self.fc1(x)) * self.fc2(x) - x = self.fc3(x) - return x - else: - # Compute in chunks - for i in range(0, self.hidden, chunk_size): - fc1_slice = self.fc1.weight[i : i + chunk_size, :] - fc2_slice = self.fc2.weight[i : i + chunk_size, :] - fc3_slice = self.fc3.weight[:, i : i + chunk_size] - x_chunk = self.silu((x @ fc1_slice.T)) * (x @ fc2_slice.T) - if i == 0: - x_out = x_chunk @ fc3_slice.T - else: - x_out = x_out + x_chunk @ fc3_slice.T - return x_out + fc1 = self.fc1(x) + fc2 = self.fc2(x) + hidden = self.silu(fc1) * fc2 + out = self.fc3(hidden) + if return_intermediates: + return out, { + "x_norm": x_norm, + "fc1": fc1, + "fc2": fc2, + "hidden": hidden, + } + return out + + # Compute in chunks + fc1_chunks = [] + fc2_chunks = [] + hidden_chunks = [] + for i in range(0, self.hidden, chunk_size): + fc1_slice = self.fc1.weight[i : i + chunk_size, :] + fc2_slice = self.fc2.weight[i : i + chunk_size, :] + fc3_slice = self.fc3.weight[:, i : i + chunk_size] + fc1_chunk = x @ fc1_slice.T + fc2_chunk = x @ fc2_slice.T + x_chunk = self.silu(fc1_chunk) * fc2_chunk + out = x_chunk @ fc3_slice.T if i == 0 else out + x_chunk @ fc3_slice.T + if return_intermediates: + fc1_chunks.append(fc1_chunk) + fc2_chunks.append(fc2_chunk) + hidden_chunks.append(x_chunk) + if return_intermediates: + fc1 = torch.cat(fc1_chunks, dim=-1) + fc2 = torch.cat(fc2_chunks, dim=-1) + hidden = torch.cat(hidden_chunks, dim=-1) + return out, { + "x_norm": x_norm, + "fc1": fc1, + "fc2": fc2, + "hidden": hidden, + } + return out diff --git a/src/boltz/model/models/boltz2.py b/src/boltz/model/models/boltz2.py index d42f3400c..a4213d0fd 100644 --- a/src/boltz/model/models/boltz2.py +++ b/src/boltz/model/models/boltz2.py @@ -408,13 +408,47 @@ def forward( max_parallel_samples: Optional[int] = None, run_confidence_sequentially: bool = False, ) -> dict[str, Tensor]: + capture_intermediates = bool( + (not self.training) + and self.predict_args + and self.predict_args.get("capture_intermediates", False) + ) + capture_pairformer_layers = bool( + capture_intermediates + and self.predict_args + and self.predict_args.get("capture_pairformer_layers", False) + ) + capture_transition_mlp = bool( + capture_intermediates + and self.predict_args + and self.predict_args.get("capture_transition_mlp", False) + ) + capture_layer_stride = int( + self.predict_args.get("capture_pairformer_layer_stride", 1) + if self.predict_args + else 1 + ) + capture_layer_stride = max(1, capture_layer_stride) + + def _to_cpu_snapshot(t: Tensor) -> Tensor: + t = t.detach().cpu() + if t.is_floating_point(): + return t.to(torch.float16) + return t + + intermediates = {} if capture_intermediates else None + with torch.set_grad_enabled( self.training and self.structure_prediction_training ): s_inputs = self.input_embedder(feats) + if capture_intermediates: + intermediates["s_inputs"] = _to_cpu_snapshot(s_inputs) # Initialize the sequence embeddings s_init = self.s_init(s_inputs) + if capture_intermediates: + intermediates["s_init"] = _to_cpu_snapshot(s_init) # Initialize pairwise embeddings z_init = ( @@ -427,6 +461,8 @@ def forward( if self.bond_type_feature: z_init = z_init + self.token_bonds_type(feats["type_bonds"].long()) z_init = z_init + self.contact_conditioning(feats) + if capture_intermediates: + intermediates["z_init"] = _to_cpu_snapshot(z_init) # Perform rounds of the pairwise stack s = torch.zeros_like(s_init) @@ -453,6 +489,9 @@ def forward( # Apply recycling s = s_init + self.s_recycle(self.s_norm(s)) z = z_init + self.z_recycle(self.z_norm(z)) + if capture_intermediates: + intermediates[f"s_recycle_{i}"] = _to_cpu_snapshot(s) + intermediates[f"z_recycle_{i}"] = _to_cpu_snapshot(z) # Compute pairwise stack if self.use_templates: @@ -473,6 +512,8 @@ def forward( z = z + msa_module( z, s_inputs, feats, use_kernels=self.use_kernels ) + if capture_intermediates: + intermediates[f"z_after_msa_{i}"] = _to_cpu_snapshot(z) # Revert to uncompiled version for validation if self.is_pairformer_compiled and not self.training: @@ -480,13 +521,54 @@ def forward( else: pairformer_module = self.pairformer_module - s, z = pairformer_module( - s, - z, - mask=mask, - pair_mask=pair_mask, - use_kernels=self.use_kernels, - ) + if capture_pairformer_layers or capture_transition_mlp: + s, z, layer_states = pairformer_module( + s, + z, + mask=mask, + pair_mask=pair_mask, + use_kernels=self.use_kernels, + capture_layers=capture_pairformer_layers, + capture_transition_mlp=capture_transition_mlp, + ) + num_layers = len(layer_states) + for layer_idx, layer_state in enumerate(layer_states): + if ( + (layer_idx % capture_layer_stride) == 0 + or layer_idx == num_layers - 1 + ): + if capture_pairformer_layers: + intermediates[ + f"s_pairformer_{i}_layer_{layer_idx}" + ] = _to_cpu_snapshot(layer_state["s"]) + intermediates[ + f"z_pairformer_{i}_layer_{layer_idx}" + ] = _to_cpu_snapshot(layer_state["z"]) + if capture_transition_mlp: + for key, value in layer_state.items(): + if key in {"s", "z"}: + continue + intermediates[ + ( + f"{key}_pairformer_{i}" + f"_layer_{layer_idx}" + ) + ] = _to_cpu_snapshot(value) + else: + s, z = pairformer_module( + s, + z, + mask=mask, + pair_mask=pair_mask, + use_kernels=self.use_kernels, + ) + if capture_intermediates: + intermediates[f"s_after_pairformer_{i}"] = ( + _to_cpu_snapshot(s) + ) + intermediates[f"z_after_pairformer_{i}"] = ( + _to_cpu_snapshot(z) + ) pdistogram = self.distogram_module(z) dict_out = { @@ -719,6 +801,22 @@ def forward( } ) + if capture_intermediates: + intermediates["capture_pairformer_layers"] = torch.tensor( + capture_pairformer_layers + ) + intermediates["capture_transition_mlp"] = torch.tensor( + capture_transition_mlp + ) + intermediates["capture_pairformer_layer_stride"] = torch.tensor( + capture_layer_stride + ) + intermediates["recycling_steps"] = torch.tensor(recycling_steps) + intermediates["token_count"] = torch.tensor( + feats["token_pad_mask"].shape[1] + ) + dict_out["intermediates"] = intermediates + return dict_out def get_true_coordinates( @@ -1104,6 +1202,8 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> d pred_dict["ligand_iptm"] = out["ligand_iptm"] pred_dict["protein_iptm"] = out["protein_iptm"] pred_dict["pair_chains_iptm"] = out["pair_chains_iptm"] + if "intermediates" in out: + pred_dict["intermediates"] = out["intermediates"] if self.affinity_prediction: pred_dict["affinity_pred_value"] = out["affinity_pred_value"] pred_dict["affinity_probability_binary"] = out[