Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions src/boltz/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,39 @@ def cli() -> None:
is_flag=True,
help=" to dump the s and z embeddings into a npz file. Default is False.",
)
@click.option(
"--write_intermediates",
is_flag=True,
help=(
"Whether to dump intermediate trunk representations after MSA and "
"pairformer during inference. Default is False."
),
)
@click.option(
"--capture_pairformer_layers",
is_flag=True,
help=(
"Whether to also dump per-layer pairformer states when writing "
"intermediates. Can produce very large files."
),
)
@click.option(
"--capture_pairformer_layer_stride",
type=int,
default=1,
help=(
"When capturing pairformer layers, save every Nth layer "
"(and always the last). Default is 1."
),
)
@click.option(
"--capture_mlp_layers",
is_flag=True,
help=(
"Whether to also dump per-layer transition MLP intermediates "
"when writing intermediates. Can produce very large files."
),
)
def predict( # noqa: C901, PLR0915, PLR0912
data: str,
out_dir: str,
Expand Down Expand Up @@ -1077,6 +1110,10 @@ def predict( # noqa: C901, PLR0915, PLR0912
num_subsampled_msa: int = 1024,
no_kernels: bool = False,
write_embeddings: bool = False,
write_intermediates: bool = False,
capture_pairformer_layers: bool = False,
capture_pairformer_layer_stride: int = 1,
capture_mlp_layers: bool = False,
) -> None:
"""Run predictions with Boltz."""
# If cpu, write a friendly warning
Expand Down Expand Up @@ -1156,6 +1193,18 @@ def predict( # noqa: C901, PLR0915, PLR0912
msg = f"Method {method} not supported. Supported: {method_names}"
raise ValueError(msg)

if capture_pairformer_layers and not write_intermediates:
msg = "--capture_pairformer_layers requires --write_intermediates."
raise ValueError(msg)
if capture_mlp_layers and not write_intermediates:
msg = "--capture_mlp_layers requires --write_intermediates."
raise ValueError(msg)
if (
write_intermediates or capture_pairformer_layers or capture_mlp_layers
) and model != "boltz2":
msg = "Intermediate capture is currently supported only for Boltz-2."
raise ValueError(msg)

# Process inputs
ccd_path = cache / "ccd.pkl"
mol_dir = cache / "mols"
Expand Down Expand Up @@ -1304,6 +1353,10 @@ def predict( # noqa: C901, PLR0915, PLR0912
"write_confidence_summary": True,
"write_full_pae": write_full_pae,
"write_full_pde": write_full_pde,
"capture_intermediates": write_intermediates,
"capture_pairformer_layers": capture_pairformer_layers,
"capture_pairformer_layer_stride": capture_pairformer_layer_stride,
"capture_transition_mlp": capture_mlp_layers,
}

steering_args = BoltzSteeringParams()
Expand Down
70 changes: 63 additions & 7 deletions src/boltz/model/layers/pairformer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union

import torch
from torch import Tensor, nn
Expand Down Expand Up @@ -71,7 +71,8 @@ def forward(
use_kernels: bool = False,
use_cuequiv_mul: bool = False,
use_cuequiv_attn: bool = False,
) -> tuple[Tensor, Tensor]:
capture_transition_mlp: bool = False,
) -> Union[tuple[Tensor, Tensor], tuple[Tensor, Tensor, dict[str, Tensor]]]:
# Compute pairwise stack
dropout = get_dropout_mask(self.dropout, z, self.training)
z = z + dropout * self.tri_mul_out(
Expand Down Expand Up @@ -99,17 +100,36 @@ def forward(
use_kernels=use_cuequiv_attn or use_kernels,
)

z = z + self.transition_z(z)
if capture_transition_mlp:
z_delta, z_transition = self.transition_z(z, return_intermediates=True)
z = z + z_delta
else:
z = z + self.transition_z(z)

# Compute sequence stack
with torch.autocast("cuda", enabled=False):
s_normed = self.pre_norm_s(s.float())
s = s.float() + self.attention(
s=s_normed, z=z.float(), mask=mask.float(), k_in=s_normed
)
s = s + self.transition_s(s)
if capture_transition_mlp:
s_delta, s_transition = self.transition_s(s, return_intermediates=True)
s = s + s_delta
else:
s = s + self.transition_s(s)
s = self.s_post_norm(s)

if capture_transition_mlp:
transition_state = {
"z_transition_x_norm": z_transition["x_norm"],
"z_transition_fc1": z_transition["fc1"],
"z_transition_fc2": z_transition["fc2"],
"z_transition_hidden": z_transition["hidden"],
"s_transition_x_norm": s_transition["x_norm"],
"s_transition_fc1": s_transition["fc1"],
"s_transition_fc2": s_transition["fc2"],
"s_transition_hidden": s_transition["hidden"],
}
return s, z, transition_state
return s, z


Expand Down Expand Up @@ -160,7 +180,9 @@ def forward(
mask: Tensor,
pair_mask: Tensor,
use_kernels: bool = False,
) -> tuple[Tensor, Tensor]:
capture_layers: bool = False,
capture_transition_mlp: bool = False,
) -> Union[tuple[Tensor, Tensor], tuple[Tensor, Tensor, list[dict[str, Tensor]]]]:
"""Perform the forward pass.

Parameters
Expand All @@ -185,7 +207,19 @@ def forward(
else:
chunk_size_tri_attn = None

layer_states = []
for layer in self.layers:
transition_state = {}
if (
capture_transition_mlp
and self.activation_checkpointing
and self.training
):
msg = (
"capture_transition_mlp is not supported with activation "
"checkpointing during training."
)
raise ValueError(msg)
if self.activation_checkpointing and self.training:
s, z = torch.utils.checkpoint.checkpoint(
layer,
Expand All @@ -197,7 +231,29 @@ def forward(
use_kernels,
)
else:
s, z = layer(s, z, mask, pair_mask, chunk_size_tri_attn, use_kernels)
layer_out = layer(
s,
z,
mask,
pair_mask,
chunk_size_tri_attn,
use_kernels,
capture_transition_mlp=capture_transition_mlp,
)
if capture_transition_mlp:
s, z, transition_state = layer_out
else:
s, z = layer_out
if capture_layers or capture_transition_mlp:
state = {}
if capture_layers:
state["s"] = s
state["z"] = z
if capture_transition_mlp:
state.update(transition_state)
layer_states.append(state)
if capture_layers or capture_transition_mlp:
return s, z, layer_states
return s, z


Expand Down
66 changes: 49 additions & 17 deletions src/boltz/model/layers/transition.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional
from typing import Optional, Union

import torch
from torch import Tensor, nn

import boltz.model.layers.initialize as init
Expand Down Expand Up @@ -44,7 +45,12 @@ def __init__(
init.lecun_normal_init_(self.fc2.weight)
init.final_init_(self.fc3.weight)

def forward(self, x: Tensor, chunk_size: int = None) -> Tensor:
def forward(
    self,
    x: Tensor,
    chunk_size: Optional[int] = None,
    return_intermediates: bool = False,
) -> Union[Tensor, tuple[Tensor, dict[str, Tensor]]]:
    """Apply the gated transition MLP: fc3(silu(fc1(norm(x))) * fc2(norm(x))).

    Parameters
    ----------
    x : Tensor
        Input tensor of shape (..., dim).
    chunk_size : Optional[int]
        When given (and not training), the hidden dimension is processed
        in slices of this size to bound peak memory.
    return_intermediates : bool
        When True, additionally return a dict holding the post-norm input
        and the fc1 / fc2 / gated-hidden activations.

    Returns
    -------
    Union[Tensor, tuple[Tensor, dict[str, Tensor]]]
        The transition output, optionally paired with the intermediates
        dict (keys: "x_norm", "fc1", "fc2", "hidden").

    """
    normed = self.norm(x)

    if chunk_size is None or self.training:
        # Dense path: one matmul per projection over the full hidden dim.
        a = self.fc1(normed)
        b = self.fc2(normed)
        gated = self.silu(a) * b
        result = self.fc3(gated)
        if not return_intermediates:
            return result
        return result, {
            "x_norm": normed,
            "fc1": a,
            "fc2": b,
            "hidden": gated,
        }

    # Chunked path: slice the hidden dimension so only `chunk_size`-wide
    # activations are materialized at once; partial fc3 products are summed.
    result = None
    a_parts = []
    b_parts = []
    gated_parts = []
    for start in range(0, self.hidden, chunk_size):
        stop = start + chunk_size
        a_chunk = normed @ self.fc1.weight[start:stop, :].T
        b_chunk = normed @ self.fc2.weight[start:stop, :].T
        gated_chunk = self.silu(a_chunk) * b_chunk
        partial = gated_chunk @ self.fc3.weight[:, start:stop].T
        result = partial if result is None else result + partial
        if return_intermediates:
            # NOTE: keeping every chunk defeats the memory saving of
            # chunking; intended only for capture/debug runs.
            a_parts.append(a_chunk)
            b_parts.append(b_chunk)
            gated_parts.append(gated_chunk)

    if not return_intermediates:
        return result
    return result, {
        "x_norm": normed,
        "fc1": torch.cat(a_parts, dim=-1),
        "fc2": torch.cat(b_parts, dim=-1),
        "hidden": torch.cat(gated_parts, dim=-1),
    }
Loading