Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,11 @@ cython_debug/

# Boltz prediction outputs
# All result files generated from a boltz prediction call
boltz_results_*/
boltz_results_*/

# Local datasets and caches
data/
natives/
cache/
mhc_one_sample/
val_cif_output/
113 changes: 113 additions & 0 deletions scripts/precompute_masks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from __future__ import annotations

import argparse
from pathlib import Path

import numpy as np


def compute_masks(data: np.lib.npyio.NpzFile) -> tuple[np.ndarray, np.ndarray]:
    """Return (alignment_mask, rmsd_mask) inferred from raw NPZ contents.

    Both masks are flat boolean arrays over atoms, concatenated in chain
    order (masked-out chains contribute nothing). The alignment mask marks
    every atom of the largest entity by residue count (assumed MHC/heavy
    chain); the rmsd mask marks every atom of the smallest entity (assumed
    peptide). NOTE(review): this size heuristic will need revisiting.
    """
    chain_records = data["chains"]
    keep = data["mask"].astype(bool)

    # Tally residues per entity across the chains that survive the mask.
    residues_per_entity: dict[int, int] = {}
    for record, kept in zip(chain_records, keep):
        if not kept:
            continue
        eid = int(record["entity_id"])
        residues_per_entity[eid] = (
            residues_per_entity.get(eid, 0) + int(record["res_num"])
        )

    # No valid chains at all: both masks are empty.
    if not residues_per_entity:
        return np.zeros(0, dtype=bool), np.zeros(0, dtype=bool)

    # Heuristic: peptide = fewest residues; heavy = most residues among the
    # remaining entities (falls back to the peptide when it is the only one).
    peptide_entity = min(residues_per_entity, key=residues_per_entity.get)
    remaining = {
        eid: count
        for eid, count in residues_per_entity.items()
        if eid != peptide_entity
    }
    heavy_entity = (
        max(remaining, key=remaining.get) if remaining else peptide_entity
    )

    # Emit one all-True / all-False segment per valid chain, sized by its
    # atom count, then concatenate into the flat per-atom masks.
    align_parts: list[np.ndarray] = []
    rmsd_parts: list[np.ndarray] = []
    for record, kept in zip(chain_records, keep):
        if not kept:
            continue
        eid = int(record["entity_id"])
        n_atoms = int(record["atom_num"])
        align_parts.append(np.full(n_atoms, eid == heavy_entity, dtype=bool))
        rmsd_parts.append(np.full(n_atoms, eid == peptide_entity, dtype=bool))

    if not align_parts:
        return np.zeros(0, dtype=bool), np.zeros(0, dtype=bool)
    return np.concatenate(align_parts), np.concatenate(rmsd_parts)


def process_structure(npz_path: Path, dry_run: bool = False) -> None:
    """Compute alignment/rmsd masks for one NPZ and write them back in place.

    Parameters
    ----------
    npz_path:
        Path to a processed structure ``.npz`` archive.
    dry_run:
        When True, only print mask statistics; the file is not modified.
    """
    with np.load(npz_path) as data:
        alignment_mask, rmsd_mask = compute_masks(data)
        # BUG FIX: materialize every array while the archive is still open.
        # NpzFile reads lazily from the underlying zip, which is closed when
        # the `with` block exits, so deferring `data[key]` until after the
        # block (as the original code did) raises at that point.
        updated = {key: data[key] for key in data.files}

    if dry_run:
        print(
            f"{npz_path.name}: align atoms={alignment_mask.sum()} "
            f"(len={alignment_mask.shape[0]}), "
            f"rmsd atoms={rmsd_mask.sum()} (len={rmsd_mask.shape[0]})"
        )
        return

    updated["alignment_mask"] = alignment_mask
    updated["rmsd_mask"] = rmsd_mask

    # Write to a sibling temp file, then atomically replace the original so
    # a crash mid-write cannot leave a truncated archive behind.
    tmp_path = npz_path.with_suffix(".tmp.npz")
    np.savez_compressed(tmp_path, **updated)
    tmp_path.replace(npz_path)


def main() -> None:
    """CLI entry point: add alignment/rmsd masks to every NPZ in a directory."""
    parser = argparse.ArgumentParser(
        description="Inject alignment/rmsd masks into structure NPZ files."
    )
    parser.add_argument(
        "structures_dir",
        type=Path,
        help="Directory containing processed structure NPZ files.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Report mask statistics without writing files.",
    )
    args = parser.parse_args()

    # Sort for deterministic processing order across runs.
    targets = sorted(args.structures_dir.glob("*.npz"))
    if not targets:
        raise SystemExit(f"No .npz files found under {args.structures_dir}")

    for target in targets:
        process_structure(target, dry_run=args.dry_run)

    summary = (
        "Dry run complete; no files modified."
        if args.dry_run
        else f"Processed {len(targets)} structures."
    )
    print(summary)


# Allow running this module directly as a standalone script.
if __name__ == "__main__":
    main()
34 changes: 19 additions & 15 deletions scripts/train/configs/structure.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ trainer:
devices: 1
precision: 32
gradient_clip_val: 10.0
max_epochs: -1
max_epochs: 0
accumulate_grad_batches: 128 # to adjust depending on the number of devices

# Optional set wandb here
Expand All @@ -12,26 +12,27 @@ trainer:
# project: boltz
# entity: boltz

output: SET_PATH_HERE
pretrained: PATH_TO_CHECKPOINT_FILE
resume: null
output: /storage/alvand/boltz/output
pretrained: /storage/alvand/boltz/data/boltz1.ckpt
resume: /storage/alvand/boltz/data/boltz1.ckpt
disable_checkpoint: false
matmul_precision: null
save_top_k: -1

data:
datasets:
- _target_: boltz.data.module.training.DatasetConfig
target_dir: PATH_TO_TARGETS_DIR
msa_dir: PATH_TO_MSA_DIR
target_dir: /storage/alvand/boltz/mhc_one_sample/processed_structures/
msa_dir: /storage/alvand/boltz/mhc_one_sample/processed_msa/
prob: 1.0
manifest_path: /storage/alvand/boltz/mhc_one_sample/manifest.json
sampler:
_target_: boltz.data.sample.cluster.ClusterSampler
cropper:
_target_: boltz.data.crop.boltz.BoltzCropper
min_neighborhood: 0
max_neighborhood: 40
split: ./scripts/train/assets/validation_ids.txt
split: /storage/alvand/boltz/mhc_one_sample/validation_ids.txt

filters:
- _target_: boltz.data.filter.dynamic.size.SizeFilter
Expand All @@ -48,16 +49,16 @@ data:
featurizer:
_target_: boltz.data.feature.featurizer.BoltzFeaturizer

symmetries: PATH_TO_SYMMETRY_FILE
symmetries: /storage/alvand/boltz/data/symmetry.pkl
max_tokens: 512
max_atoms: 4608
max_seqs: 2048
pad_to_max_tokens: true
pad_to_max_atoms: true
pad_to_max_seqs: true
samples_per_epoch: 100000
pad_to_max_tokens: false
pad_to_max_atoms: false
pad_to_max_seqs: false
samples_per_epoch: 1
batch_size: 1
num_workers: 4
num_workers: 1
random_seed: 42
pin_memory: true
overfit: null
Expand All @@ -75,7 +76,7 @@ data:
compute_constraint_features: false

model:
_target_: boltz.model.model.Boltz1
_target_: boltz.model.models.boltz1.Boltz1
atom_s: 128
atom_z: 16
token_s: 384
Expand Down Expand Up @@ -161,8 +162,11 @@ model:
recycling_steps: 3
sampling_steps: 200
diffusion_samples: 5
symmetry_correction: true
symmetry_correction: false
run_confidence_sequentially: false
val_cif_out_dir: /storage/alvand/boltz/val_cif_output
write_cif_for_validation: false # whether to write cif files during validation
debug_peptide_mask_info: false # Keep this false unless debugging

diffusion_process_args:
sigma_min: 0.0004
Expand Down
3 changes: 2 additions & 1 deletion scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class TrainConfig:
matmul_precision: Optional[str] = None
find_unused_parameters: Optional[bool] = False
save_top_k: Optional[int] = 1
validation_only: bool = False
validation_only: bool = True
debug: bool = False
strict_loading: bool = True
load_confidence_from_trunk: Optional[bool] = False
Expand Down Expand Up @@ -222,6 +222,7 @@ def save_config_to_wandb() -> None:
model_module.strict_loading = False

if cfg.validation_only:
print(f'====== Running validation on {cfg.resume}===========')
trainer.validate(
model_module,
datamodule=data_module,
Expand Down
Loading