diff --git a/.gitignore b/.gitignore index 3d20fc11a..a1f5a545c 100644 --- a/.gitignore +++ b/.gitignore @@ -163,4 +163,11 @@ cython_debug/ # Boltz prediction outputs # All result files generated from a boltz prediction call -boltz_results_*/ \ No newline at end of file +boltz_results_*/ + +# Local datasets and caches +data/ +natives/ +cache/ +mhc_one_sample/ +val_cif_output/ diff --git a/scripts/precompute_masks.py b/scripts/precompute_masks.py new file mode 100644 index 000000000..0fdf91bc3 --- /dev/null +++ b/scripts/precompute_masks.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import argparse +from pathlib import Path + +import numpy as np + + +def compute_masks(data: np.lib.npyio.NpzFile) -> tuple[np.ndarray, np.ndarray]: + """Return (alignment_mask, rmsd_mask) inferred from raw NPZ contents.""" + chains = data["chains"] + chain_mask = data["mask"].astype(bool) + + alignment_segments: list[np.ndarray] = [] + rmsd_segments: list[np.ndarray] = [] + + entity_counts: dict[int, int] = {} + for idx, chain in enumerate(chains): + if not chain_mask[idx]: + continue + entity_id = int(chain["entity_id"]) + entity_counts[entity_id] = entity_counts.get(entity_id, 0) + int(chain["res_num"]) + + if not entity_counts: + return np.zeros(0, dtype=bool), np.zeros(0, dtype=bool) + + # For now, assume the peptide entity is the smallest one by residue count, and MHC/heavy is the largest. + # Will need to revisit. 
+ peptide_entity = min(entity_counts, key=entity_counts.get) + heavy_candidates = {k: v for k, v in entity_counts.items() if k != peptide_entity} + heavy_entity = ( + max(heavy_candidates, key=heavy_candidates.get) + if heavy_candidates + else peptide_entity + ) + + for idx, chain in enumerate(chains): + if not chain_mask[idx]: + continue + entity_id = int(chain["entity_id"]) + atom_num = int(chain["atom_num"]) + + align_segment = np.zeros(atom_num, dtype=bool) + rmsd_segment = np.zeros(atom_num, dtype=bool) + + if entity_id == heavy_entity: + align_segment[:] = True + if entity_id == peptide_entity: + rmsd_segment[:] = True + + alignment_segments.append(align_segment) + rmsd_segments.append(rmsd_segment) + + alignment_mask = ( + np.concatenate(alignment_segments) if alignment_segments else np.zeros(0, dtype=bool) + ) + rmsd_mask = ( + np.concatenate(rmsd_segments) if rmsd_segments else np.zeros(0, dtype=bool) + ) + + return alignment_mask, rmsd_mask + + +def process_structure(npz_path: Path, dry_run: bool = False) -> None: + with np.load(npz_path) as data: + alignment_mask, rmsd_mask = compute_masks(data) + + if dry_run: + print( + f"{npz_path.name}: align atoms={alignment_mask.sum()} " + f"(len={alignment_mask.shape[0]}), " + f"rmsd atoms={rmsd_mask.sum()} (len={rmsd_mask.shape[0]})" + ) + return + + updated = {key: data[key] for key in data.files} + updated["alignment_mask"] = alignment_mask + updated["rmsd_mask"] = rmsd_mask + + tmp_path = npz_path.with_suffix(".tmp.npz") + np.savez_compressed(tmp_path, **updated) + tmp_path.replace(npz_path) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Inject alignment/rmsd masks into structure NPZ files.") + parser.add_argument( + "structures_dir", + type=Path, + help="Directory containing processed structure NPZ files.", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Report mask statistics without writing files.", + ) + args = parser.parse_args() + + npz_paths = 
sorted(args.structures_dir.glob("*.npz")) + if not npz_paths: + raise SystemExit(f"No .npz files found under {args.structures_dir}") + + for npz_path in npz_paths: + process_structure(npz_path, dry_run=args.dry_run) + + if args.dry_run: + print("Dry run complete; no files modified.") + else: + print(f"Processed {len(npz_paths)} structures.") + + +if __name__ == "__main__": + main() diff --git a/scripts/train/configs/structure.yaml b/scripts/train/configs/structure.yaml index 6591f386a..44484e69f 100644 --- a/scripts/train/configs/structure.yaml +++ b/scripts/train/configs/structure.yaml @@ -3,7 +3,7 @@ trainer: devices: 1 precision: 32 gradient_clip_val: 10.0 - max_epochs: -1 + max_epochs: 0 accumulate_grad_batches: 128 # to adjust depending on the number of devices # Optional set wandb here @@ -12,9 +12,9 @@ trainer: # project: boltz # entity: boltz -output: SET_PATH_HERE -pretrained: PATH_TO_CHECKPOINT_FILE -resume: null +output: /storage/alvand/boltz/output +pretrained: /storage/alvand/boltz/data/boltz1.ckpt +resume: /storage/alvand/boltz/data/boltz1.ckpt disable_checkpoint: false matmul_precision: null save_top_k: -1 @@ -22,16 +22,17 @@ save_top_k: -1 data: datasets: - _target_: boltz.data.module.training.DatasetConfig - target_dir: PATH_TO_TARGETS_DIR - msa_dir: PATH_TO_MSA_DIR + target_dir: /storage/alvand/boltz/mhc_one_sample/processed_structures/ + msa_dir: /storage/alvand/boltz/mhc_one_sample/processed_msa/ prob: 1.0 + manifest_path: /storage/alvand/boltz/mhc_one_sample/manifest.json sampler: _target_: boltz.data.sample.cluster.ClusterSampler cropper: _target_: boltz.data.crop.boltz.BoltzCropper min_neighborhood: 0 max_neighborhood: 40 - split: ./scripts/train/assets/validation_ids.txt + split: /storage/alvand/boltz/mhc_one_sample/validation_ids.txt filters: - _target_: boltz.data.filter.dynamic.size.SizeFilter @@ -48,16 +49,16 @@ data: featurizer: _target_: boltz.data.feature.featurizer.BoltzFeaturizer - symmetries: PATH_TO_SYMMETRY_FILE + symmetries: 
/storage/alvand/boltz/data/symmetry.pkl max_tokens: 512 max_atoms: 4608 max_seqs: 2048 - pad_to_max_tokens: true - pad_to_max_atoms: true - pad_to_max_seqs: true - samples_per_epoch: 100000 + pad_to_max_tokens: false + pad_to_max_atoms: false + pad_to_max_seqs: false + samples_per_epoch: 1 batch_size: 1 - num_workers: 4 + num_workers: 1 random_seed: 42 pin_memory: true overfit: null @@ -75,7 +76,7 @@ data: compute_constraint_features: false model: - _target_: boltz.model.model.Boltz1 + _target_: boltz.model.models.boltz1.Boltz1 atom_s: 128 atom_z: 16 token_s: 384 @@ -161,8 +162,11 @@ model: recycling_steps: 3 sampling_steps: 200 diffusion_samples: 5 - symmetry_correction: true + symmetry_correction: false run_confidence_sequentially: false + val_cif_out_dir: /storage/alvand/boltz/val_cif_output + write_cif_for_validation: false # whether to write cif files during validation + debug_peptide_mask_info: false # Keep this false unless debugging diffusion_process_args: sigma_min: 0.0004 diff --git a/scripts/train/train.py b/scripts/train/train.py index f83966bdd..1cbe70dbc 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -71,7 +71,7 @@ class TrainConfig: matmul_precision: Optional[str] = None find_unused_parameters: Optional[bool] = False save_top_k: Optional[int] = 1 - validation_only: bool = False + validation_only: bool = True debug: bool = False strict_loading: bool = True load_confidence_from_trunk: Optional[bool] = False @@ -222,6 +222,7 @@ def save_config_to_wandb() -> None: model_module.strict_loading = False if cfg.validation_only: + print(f'====== Running validation on {cfg.resume}===========') trainer.validate( model_module, datamodule=data_module, diff --git a/src/boltz/data/module/training.py b/src/boltz/data/module/training.py index 36583b6cf..6e3e6c92c 100644 --- a/src/boltz/data/module/training.py +++ b/src/boltz/data/module/training.py @@ -82,6 +82,75 @@ class Dataset: featurizer: BoltzFeaturizer +def _prepare_structure(data: 
np.lib.npyio.NpzFile) -> tuple[Structure, Optional[np.ndarray], Optional[np.ndarray]]: + chains = data["chains"] + if "cyclic_period" not in chains.dtype.names: + new_dtype = chains.dtype.descr + [("cyclic_period", "i4")] + new_chains = np.empty(chains.shape, dtype=new_dtype) + for name in chains.dtype.names: + new_chains[name] = chains[name] + new_chains["cyclic_period"] = 0 + chains = new_chains + + structure = Structure( + atoms=data["atoms"], + bonds=data["bonds"], + residues=data["residues"], + chains=chains, + connections=data["connections"].astype(Connection), + interfaces=data["interfaces"], + mask=data["mask"], + ) + + alignment_mask = data.get("alignment_mask") + rmsd_mask = data.get("rmsd_mask") + + if alignment_mask is not None: + alignment_mask = alignment_mask.astype(bool, copy=False) + if rmsd_mask is not None: + rmsd_mask = rmsd_mask.astype(bool, copy=False) + + chain_mask = data["mask"].astype(bool) + kept_atom_total = sum( + int(chain["atom_num"]) for idx, chain in enumerate(data["chains"]) if chain_mask[idx] + ) + + def _reshape_mask(mask: Optional[np.ndarray]) -> Optional[np.ndarray]: + if mask is None: + return None + if mask.shape[0] == kept_atom_total: + return mask + + segments = [] + for idx, chain in enumerate(data["chains"]): + if not chain_mask[idx]: + continue + start = int(chain["atom_idx"]) + end = start + int(chain["atom_num"]) + segments.append(mask[start:end]) + if any(seg.size == 0 for seg in segments): + return None + return np.concatenate(segments) + + alignment_mask = _reshape_mask(alignment_mask) + rmsd_mask = _reshape_mask(rmsd_mask) + + structure_clean = structure.remove_invalid_chains() + + if alignment_mask is not None and alignment_mask.shape[0] != structure_clean.atoms.shape[0]: + print( + f"Alignment mask length mismatch for structure ({alignment_mask.shape[0]} vs {structure_clean.atoms.shape[0]}). Discarding mask." 
+ ) + alignment_mask = None + if rmsd_mask is not None and rmsd_mask.shape[0] != structure_clean.atoms.shape[0]: + print( + f"rmsd mask length mismatch for structure ({rmsd_mask.shape[0]} vs {structure_clean.atoms.shape[0]}). Discarding mask." + ) + rmsd_mask = None + + return structure_clean, alignment_mask, rmsd_mask + + def load_input(record: Record, target_dir: Path, msa_dir: Path) -> Input: """Load the given input data. @@ -101,34 +170,9 @@ def load_input(record: Record, target_dir: Path, msa_dir: Path) -> Input: """ # Load the structure - structure = np.load(target_dir / "structures" / f"{record.id}.npz") - - # In order to add cyclic_period to chains if it does not exist - # Extract the chains array - chains = structure["chains"] - # Check if the field exists - if "cyclic_period" not in chains.dtype.names: - # Create a new dtype with the additional field - new_dtype = chains.dtype.descr + [("cyclic_period", "i4")] - # Create a new array with the new dtype - new_chains = np.empty(chains.shape, dtype=new_dtype) - # Copy over existing fields - for name in chains.dtype.names: - new_chains[name] = chains[name] - # Set the new field to 0 - new_chains["cyclic_period"] = 0 - # Replace old chains array with new one - chains = new_chains - - structure = Structure( - atoms=structure["atoms"], - bonds=structure["bonds"], - residues=structure["residues"], - chains=chains, # chains var accounting for missing cyclic_period - connections=structure["connections"].astype(Connection), - interfaces=structure["interfaces"], - mask=structure["mask"], - ) + structure_path = target_dir / "structures" / f"{record.id}.npz" + npz_data = np.load(structure_path) + structure, alignment_mask, rmsd_mask = _prepare_structure(npz_data) msas = {} for chain in record.chains: @@ -138,7 +182,13 @@ def load_input(record: Record, target_dir: Path, msa_dir: Path) -> Input: msa = np.load(msa_dir / f"{msa_id}.npz") msas[chain.chain_id] = MSA(**msa) - return Input(structure, msas) + return Input( + 
structure=structure, + msa=msas, + record=record, + alignment_mask=alignment_mask, + rmsd_mask=rmsd_mask, + ) def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]: @@ -171,12 +221,18 @@ def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]: "amino_acids_symmetries", "ligand_symmetries", ]: - # Check if all have the same shape - shape = values[0].shape - if not all(v.shape == shape for v in values): - values, _ = pad_to_max(values, 0) + first_value = values[0] + if isinstance(first_value, torch.Tensor): + # Check if all have the same shape + shape = first_value.shape + if not all(v.shape == shape for v in values): + values, _ = pad_to_max(values, 0) + else: + values = torch.stack(values, dim=0) else: - values = torch.stack(values, dim=0) + # Keep list for non-tensor entries (e.g. record ids) + collated[key] = values + continue # Stack the values collated[key] = values @@ -470,6 +526,21 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: print(f"Featurizer failed on {record.id} with error {e}. 
Skipping.") return self.__getitem__(0) + features["record_id"] = record.id + features["record"] = record + if input_data.alignment_mask is not None: + features["alignment_mask"] = torch.from_numpy( + input_data.alignment_mask.astype(np.bool_) + ) + if input_data.rmsd_mask is not None: + features["rmsd_mask"] = torch.from_numpy( + input_data.rmsd_mask.astype(np.bool_) + ) + features["base_structure"] = input_data.structure + features["structure_path"] = str( + (dataset.target_dir / "structures" / f"{record.id}.npz").resolve() + ) + return features def __len__(self) -> int: diff --git a/src/boltz/data/types.py b/src/boltz/data/types.py index 1ce26b558..ac3aef24c 100644 --- a/src/boltz/data/types.py +++ b/src/boltz/data/types.py @@ -705,6 +705,8 @@ class Input: residue_constraints: Optional[ResidueConstraints] = None templates: Optional[dict[str, StructureV2]] = None extra_mols: Optional[dict[str, Mol]] = None + alignment_mask: Optional[np.ndarray] = None + rmsd_mask: Optional[np.ndarray] = None #################################################################################################### diff --git a/src/boltz/data/write/writer.py b/src/boltz/data/write/writer.py index 984be2ae5..915b98a7e 100644 --- a/src/boltz/data/write/writer.py +++ b/src/boltz/data/write/writer.py @@ -1,7 +1,8 @@ import json +import os from dataclasses import asdict, replace from pathlib import Path -from typing import Literal +from typing import Literal, Optional import numpy as np import torch @@ -66,9 +67,12 @@ def write_on_batch_end( # Get the predictions coords = prediction["coords"] - coords = coords.unsqueeze(0) - + if coords.ndim == 3: + coords = coords.unsqueeze(0) pad_masks = prediction["masks"] + structure_paths = prediction.get("structure_paths") + provided_structures = prediction.get("structures") + custom_filenames = prediction.get("filenames") # Get ranking if "confidence_score" in prediction: @@ -79,22 +83,39 @@ def write_on_batch_end( idx_to_rank = {i: i for i in 
range(len(records))} # Iterate over the records - for record, coord, pad_mask in zip(records, coords, pad_masks): + for rec_idx, (record, coord, pad_mask) in enumerate(zip(records, coords, pad_masks)): # Load the structure - path = self.data_dir / f"{record.id}.npz" - if self.boltz2: - structure: StructureV2 = StructureV2.load(path) + if provided_structures is not None and rec_idx < len(provided_structures): + structure = provided_structures[rec_idx] + chain_map = { + int(chain["asym_id"]): int(chain["asym_id"]) + for chain in structure.chains + } else: - structure: Structure = Structure.load(path) + if structure_paths is not None and rec_idx < len(structure_paths): + path = Path(structure_paths[rec_idx]) + else: + path = self.data_dir / f"{record.id}.npz" + if self.boltz2: + structure: StructureV2 = StructureV2.load(path) + else: + structure: Structure = Structure.load(path) + + # Compute chain map with masked removed, to be used later + chain_map = {} + for i, mask in enumerate(structure.mask): + if mask: + chain_map[len(chain_map)] = i + + # Remove masked chains completely + structure = structure.remove_invalid_chains() - # Compute chain map with masked removed, to be used later - chain_map = {} - for i, mask in enumerate(structure.mask): - if mask: - chain_map[len(chain_map)] = i + custom_names = None + if custom_filenames is not None and rec_idx < len(custom_filenames): + custom_names = custom_filenames[rec_idx] - # Remove masked chains completely - structure = structure.remove_invalid_chains() + ext_map = {"mmcif": "cif", "pdb": "pdb"} + out_ext = ext_map.get(self.output_format, "npz") for model_idx in range(coord.shape[0]): # Get model coord @@ -156,24 +177,29 @@ def write_on_batch_end( plddts = prediction["plddt"][model_idx] # Create path name - outname = f"{record.id}_model_{idx_to_rank[model_idx]}" + if custom_names and model_idx < len(custom_names): + name = custom_names[model_idx] + if Path(name).suffix: + filename = name + else: + filename = 
f"{name}.{out_ext}" + else: + filename = f"{record.id}_model_{idx_to_rank[model_idx]}.{out_ext}" + output_path = struct_dir / filename # Save the structure if self.output_format == "pdb": - path = struct_dir / f"{outname}.pdb" - with path.open("w") as f: + with output_path.open("w") as f: f.write( to_pdb(new_structure, plddts=plddts, boltz2=self.boltz2) ) elif self.output_format == "mmcif": - path = struct_dir / f"{outname}.cif" - with path.open("w") as f: + with output_path.open("w") as f: f.write( to_mmcif(new_structure, plddts=plddts, boltz2=self.boltz2) ) else: - path = struct_dir / f"{outname}.npz" - np.savez_compressed(path, **asdict(new_structure)) + np.savez_compressed(output_path, **asdict(new_structure)) if self.boltz2 and record.affinity and idx_to_rank[model_idx] == 0: path = struct_dir / f"pre_affinity_{record.id}.npz" @@ -267,6 +293,26 @@ def on_predict_epoch_end( print(f"Number of failed examples: {self.failed}") # noqa: T201 +def atomic_save_cif(filepath: Path, content: str) -> bool: + """ + Safely write CIF content to disk (atomic rename + fsync). 
+ """ + tmp_path = filepath.parent / f".{filepath.name}.tmp" + try: + filepath.parent.mkdir(parents=True, exist_ok=True) + with tmp_path.open("w") as handle: + handle.write(content) + handle.flush() + os.fsync(handle.fileno()) + tmp_path.replace(filepath) + return True + except Exception as exc: # noqa: BLE001 + print(f"Error writing file {filepath}: {exc}") + if tmp_path.exists(): + tmp_path.unlink() + return False + + class BoltzAffinityWriter(BasePredictionWriter): """Custom writer for predictions.""" diff --git a/src/boltz/model/loss/validation.py b/src/boltz/model/loss/validation.py index 00d1aa7c3..d957ee726 100644 --- a/src/boltz/model/loss/validation.py +++ b/src/boltz/model/loss/validation.py @@ -8,6 +8,9 @@ ) from boltz.model.loss.diffusion import weighted_rigid_align +from dataclasses import dataclass +from torch import Tensor + def factored_lddt_loss( true_atom_coords, @@ -895,6 +898,8 @@ def weighted_minimum_rmsd( multiplicity=1, nucleotide_weight=5.0, ligand_weight=10.0, + alignment_mask=None, + rmsd_mask=None, ): """Compute rmsd of the aligned atom coordinates. @@ -922,6 +927,27 @@ def weighted_minimum_rmsd( atom_mask = feats["atom_resolved_mask"] atom_mask = atom_mask.repeat_interleave(multiplicity, 0) + target_len = pred_atom_coords.shape[0] + base_len = feats["coords"].shape[0] + + def _prepare_mask(mask, fallback): + if mask is None: + prepared = fallback + else: + if mask.shape[0] == target_len: + prepared = mask + elif mask.shape[0] == base_len and base_len * multiplicity == target_len: + prepared = mask.repeat_interleave(multiplicity, 0) + else: + raise ValueError( + "Unexpected mask shape: " + f"{mask.shape[0]} (expected {target_len} or {base_len})." 
+ ) + return prepared.to(dtype=fallback.dtype, device=pred_atom_coords.device) + + align_mask = _prepare_mask(alignment_mask, atom_mask) + calc_mask = _prepare_mask(rmsd_mask, atom_mask) + align_weights = atom_coords.new_ones(atom_coords.shape[:2]) atom_type = ( torch.bmm( @@ -945,14 +971,15 @@ def weighted_minimum_rmsd( with torch.no_grad(): atom_coords_aligned_ground_truth = weighted_rigid_align( - atom_coords, pred_atom_coords, align_weights, mask=atom_mask + atom_coords, pred_atom_coords, align_weights, mask=align_mask ) # weighted MSE loss of denoised atom positions mse_loss = ((pred_atom_coords - atom_coords_aligned_ground_truth) ** 2).sum(dim=-1) + denom = torch.sum(align_weights * calc_mask, dim=-1) rmsd = torch.sqrt( - torch.sum(mse_loss * align_weights * atom_mask, dim=-1) - / torch.sum(align_weights * atom_mask, dim=-1) + torch.sum(mse_loss * align_weights * calc_mask, dim=-1) + / torch.clamp(denom, min=1e-8) ) best_rmsd = torch.min(rmsd.reshape(-1, multiplicity), dim=1).values @@ -967,6 +994,8 @@ def weighted_minimum_rmsd_single( mol_type, nucleotide_weight=5.0, ligand_weight=10.0, + alignment_mask=None, + rmsd_mask=None, ): """Compute rmsd of the aligned atom coordinates. 
@@ -1012,14 +1041,23 @@ def weighted_minimum_rmsd_single( ) with torch.no_grad(): + mask_align = alignment_mask if alignment_mask is not None else atom_mask atom_coords_aligned_ground_truth = weighted_rigid_align( - atom_coords, pred_atom_coords, align_weights, mask=atom_mask + atom_coords, pred_atom_coords, align_weights, mask=mask_align ) # weighted MSE loss of denoised atom positions + calc_mask = rmsd_mask if rmsd_mask is not None else atom_mask mse_loss = ((pred_atom_coords - atom_coords_aligned_ground_truth) ** 2).sum(dim=-1) + denom = torch.sum(align_weights * calc_mask, dim=-1) rmsd = torch.sqrt( - torch.sum(mse_loss * align_weights * atom_mask, dim=-1) - / torch.sum(align_weights * atom_mask, dim=-1) + torch.sum(mse_loss * align_weights * calc_mask, dim=-1) + / torch.clamp(denom, min=1e-8) ) return rmsd, atom_coords_aligned_ground_truth, align_weights + +@dataclass +class SampleMetrics: + sample_idx: int + rmsd_whole: float + rmsd_masked: float diff --git a/src/boltz/model/models/boltz1.py b/src/boltz/model/models/boltz1.py index 51889b882..b5966b08d 100644 --- a/src/boltz/model/models/boltz1.py +++ b/src/boltz/model/models/boltz1.py @@ -1,6 +1,11 @@ import gc +import math import random +from pathlib import Path from typing import Any, Optional +from typing import Iterable + +import numpy as np import torch import torch._dynamo @@ -9,22 +14,27 @@ from torchmetrics import MeanMetric import boltz.model.layers.initialize as init +from boltz.data.types import Connection, Structure from boltz.data import const from boltz.data.feature.symmetry import ( minimum_lddt_symmetry_coords, minimum_symmetry_coords, ) +from lightning_fabric.utilities.apply_func import move_data_to_device from boltz.model.loss.confidence import confidence_loss from boltz.model.loss.distogram import distogram_loss from boltz.model.loss.validation import ( + SampleMetrics, compute_pae_mae, compute_pde_mae, compute_plddt_mae, factored_lddt_loss, factored_token_lddt_dist_loss, 
weighted_minimum_rmsd, + weighted_minimum_rmsd_single, ) from boltz.model.modules.confidence import ConfidenceModule +from boltz.data.write.writer import BoltzWriter from boltz.model.modules.diffusion import AtomDiffusion from boltz.model.modules.encoders import RelativePositionEncoder from boltz.model.modules.trunk import ( @@ -141,6 +151,7 @@ def __init__( # noqa: PLR0915, C901, PLR0912 self.steering_args = steering_args self.use_kernels = use_kernels + self.validation_writer: Optional[BoltzWriter] = None self.nucleotide_rmsd_weight = nucleotide_rmsd_weight self.ligand_rmsd_weight = ligand_rmsd_weight @@ -269,6 +280,27 @@ def setup(self, stage: str) -> None: ): self.use_kernels = False + def transfer_batch_to_device( + self, batch: dict[str, Any], device: torch.device, dataloader_idx: int + ) -> dict[str, Any]: + """Move tensor entries to device while keeping CPU-only metadata intact.""" + if batch is None: + return batch + + cpu_only_keys = {"record", "base_structure"} + cpu_entries = {} + tensor_entries = {} + + for key, value in batch.items(): + if key in cpu_only_keys: + cpu_entries[key] = value + else: + tensor_entries[key] = value + + moved = move_data_to_device(tensor_entries, device) + moved.update(cpu_entries) + return moved + def forward( self, feats: dict[str, Tensor], @@ -276,7 +308,7 @@ def forward( num_sampling_steps: Optional[int] = None, multiplicity_diffusion_train: int = 1, diffusion_samples: int = 1, - max_parallel_samples: Optional[int] = None, + max_parallel_samples: Optional[int] = 1, run_confidence_sequentially: bool = False, ) -> dict[str, Tensor]: dict_out = {} @@ -406,7 +438,12 @@ def get_true_coordinates( diffusion_samples, symmetry_correction, lddt_minimization=True, + alignment_mask: Optional[Tensor] = None, + rmsd_mask: Optional[Tensor] = None, ): + masked_rmsds = None + best_masked_rmsds = None + if symmetry_correction: min_coords_routine = ( minimum_lddt_symmetry_coords @@ -437,6 +474,8 @@ def get_true_coordinates( 
best_rmsds.append(best_rmsd) true_coords = torch.cat(true_coords, dim=0) true_coords_resolved_mask = torch.cat(true_coords_resolved_mask, dim=0) + masked_rmsds = None + best_masked_rmsds = None else: true_coords = ( batch["coords"].squeeze(1).repeat_interleave(diffusion_samples, 0) @@ -453,7 +492,50 @@ def get_true_coordinates( ligand_weight=self.ligand_rmsd_weight, ) - return true_coords, rmsds, best_rmsds, true_coords_resolved_mask + align_mask = alignment_mask + if align_mask is not None: + target_len = out["sample_atom_coords"].shape[0] + if align_mask.shape[0] != target_len: + if align_mask.shape[0] * diffusion_samples == target_len: + align_mask = align_mask.repeat_interleave(diffusion_samples, 0) + else: + raise ValueError( + "Alignment mask has unexpected first dimension " + f"{align_mask.shape[0]} (expected {target_len})." + ) + calc_mask = rmsd_mask + if calc_mask is not None: + target_len = out["sample_atom_coords"].shape[0] + if calc_mask.shape[0] != target_len: + if calc_mask.shape[0] * diffusion_samples == target_len: + calc_mask = calc_mask.repeat_interleave(diffusion_samples, 0) + else: + raise ValueError( + "RMSD mask has unexpected first dimension " + f"{calc_mask.shape[0]} (expected {target_len})." 
+ ) + + if align_mask is not None or calc_mask is not None: + masked_rmsds, best_masked_rmsds = weighted_minimum_rmsd( + out["sample_atom_coords"], + batch, + multiplicity=diffusion_samples, + nucleotide_weight=self.nucleotide_rmsd_weight, + ligand_weight=self.ligand_rmsd_weight, + alignment_mask=align_mask, + rmsd_mask=calc_mask, + ) + else: + masked_rmsds, best_masked_rmsds = rmsds, best_rmsds + + return ( + true_coords, + rmsds, + best_rmsds, + true_coords_resolved_mask, + masked_rmsds, + best_masked_rmsds, + ) def training_step(self, batch: dict[str, Tensor], batch_idx: int) -> Tensor: # Sample recycling steps @@ -491,7 +573,7 @@ def training_step(self, batch: dict[str, Tensor], batch_idx: int) -> Tensor: if self.confidence_prediction: # confidence model symmetry correction - true_coords, _, _, true_coords_resolved_mask = self.get_true_coordinates( + true_coords, _, _, true_coords_resolved_mask, _, _ = self.get_true_coordinates( batch, out, diffusion_samples=self.training_args.diffusion_samples, @@ -615,6 +697,332 @@ def parameter_norm(self, module) -> float: norm = torch.tensor([p.norm(p=2) ** 2 for p in parameters]).sum().sqrt() return norm + def _build_alignment_masks( + self, + batch: dict[str, Tensor], + record_ids: list[str], + base_structures: list[Optional[Structure]], + batch_idx: int, + precomputed_align_masks: Optional[list[Optional[np.ndarray]]] = None, + precomputed_rmsd_masks: Optional[list[Optional[np.ndarray]]] = None, + ) -> tuple[Tensor, Tensor]: + batch_size, num_atoms, _ = batch["atom_to_token"].shape + device = batch["atom_pad_mask"].device + + heavy_calc_mask = torch.zeros((batch_size, num_atoms), dtype=torch.bool, device=device) + peptide_calc_mask = torch.zeros_like(heavy_calc_mask) + + if precomputed_align_masks is None or precomputed_rmsd_masks is None: + raise ValueError("Precomputed masks are required but were not provided.") + + for structure_idx in range(batch_size): + record_id = ( + record_ids[structure_idx] + if structure_idx < 
len(record_ids) + else f"batch_{batch_idx}_{structure_idx}" + ) + structure = ( + base_structures[structure_idx] + if structure_idx < len(base_structures) + else None + ) + pre_align = precomputed_align_masks[structure_idx] + pre_rmsd = precomputed_rmsd_masks[structure_idx] + + atom_to_token = batch["atom_to_token"][structure_idx].bool() + entity_ids = batch["entity_id"][structure_idx] + present_atom = batch["atom_pad_mask"][structure_idx].bool() + + unique_entities = torch.unique(entity_ids) + res_counts = {int(e.item()): int((entity_ids == e).sum().item()) for e in unique_entities} + if not res_counts: + print(f"Warning: no entities found for structure {structure_idx}.") + continue + + peptide_entity = min(res_counts, key=res_counts.get) + heavy_candidates = {k: v for k, v in res_counts.items() if k != peptide_entity} + heavy_entity = max(heavy_candidates, key=heavy_candidates.get) if heavy_candidates else peptide_entity + + tok_is_pep = (entity_ids == peptide_entity) + tok_is_heavy = (entity_ids == heavy_entity) + + atom_pep = atom_to_token[:, tok_is_pep].any(dim=1) + atom_heavy = atom_to_token[:, tok_is_heavy].any(dim=1) + + atom_heavy = atom_heavy & ~atom_pep + atom_pep = atom_pep & present_atom + atom_heavy = atom_heavy & present_atom + + if atom_heavy.sum() < 3 or atom_pep.sum() < 3: + print( + f"Insufficient atoms for alignment in record {record_id}" + ) + continue + + if pre_align is None or pre_rmsd is None: + raise ValueError( + f"Missing precomputed masks for record '{record_id}'. " + "Please regenerate mask files." + ) + + valid_mask = present_atom + valid_count = int(valid_mask.sum().item()) + if structure is None or pre_align.shape[0] != valid_count or pre_rmsd.shape[0] != valid_count: + raise ValueError( + f"Precomputed mask length mismatch for record '{record_id}'. " + "Ensure masks were generated after removing invalid chains." 
+ ) + + heavy_mask = torch.zeros_like(valid_mask, dtype=torch.bool, device=device) + peptide_mask = torch.zeros_like(valid_mask, dtype=torch.bool, device=device) + + heavy_mask[valid_mask] = torch.from_numpy(pre_align.astype(bool, copy=False)).to(device=device) + peptide_mask[valid_mask] = torch.from_numpy(pre_rmsd.astype(bool, copy=False)).to(device=device) + + if heavy_mask.sum() < 3 or peptide_mask.sum() < 3: + raise ValueError( + f"Precomputed masks for record '{record_id}' contain fewer than 3 atoms. " + "Please regenerate masks with enough atoms for alignment and RMSD." + ) + + heavy_calc_mask[structure_idx] = heavy_mask + peptide_calc_mask[structure_idx] = peptide_mask + print( + f"[debug] structure {structure_idx} entities: {res_counts} | heavy={heavy_entity} peptide={peptide_entity} | " + "align_mode=precomputed calc_mode=precomputed" + ) + print( + f"Align atoms: {int(heavy_mask.sum().item())} | Calc atoms: {int(peptide_mask.sum().item())}" + ) + + return heavy_calc_mask, peptide_calc_mask + + def _write_validation_cifs( + self, + batch: dict[str, Tensor], + out: dict[str, Tensor], + base_structures: list[Optional[Structure]], + record_ids: list[str], + sample_metrics: Iterable[SampleMetrics], + n_samples: int, + batch_idx: int, + ) -> None: + if not getattr(self.validation_args, "write_cif_for_validation", True): + return + + output_dir = Path(getattr(self.validation_args, "val_cif_out_dir", "validation_outputs")) + output_dir.mkdir(parents=True, exist_ok=True) + + if ( + self.validation_writer is None + or Path(self.validation_writer.output_dir) != output_dir + ): + self.validation_writer = BoltzWriter( + data_dir=".", + output_dir=str(output_dir), + output_format="mmcif", + ) + + records = batch.get("record", []) + structure_paths = batch.get("structure_path") + atom_pad_mask = batch["atom_pad_mask"].detach().cpu().bool() + metrics_map = {metric.sample_idx: metric for metric in sample_metrics} + total_samples = out["sample_atom_coords"].shape[0] + 
samples_per_structure = n_samples if n_samples > 0 else total_samples + + for struct_idx, record_id in enumerate(record_ids): + if not records or struct_idx >= len(records): + print(f"[warning] Missing record metadata for '{record_id}', skipping CIF write.") + continue + record = records[struct_idx] + + sample_block = out["sample_atom_coords"][ + struct_idx * samples_per_structure : (struct_idx + 1) * samples_per_structure + ].detach().cpu() + pad_mask = atom_pad_mask[struct_idx : struct_idx + 1] + + filenames = [] + for sample_offset in range(samples_per_structure): + sample_idx = struct_idx * samples_per_structure + sample_offset + metrics = metrics_map.get(sample_idx) + if metrics is not None: + name = ( + f"prediction_{record_id}_sample_{sample_idx}" + f"_whole{metrics.rmsd_whole:.2f}_mask{metrics.rmsd_masked:.2f}" + ) + else: + name = f"prediction_{record_id}_sample_{sample_idx}" + filenames.append(name) + + structure_path_entry = None + if structure_paths is not None and struct_idx < len(structure_paths): + structure_path_entry = [structure_paths[struct_idx]] + + structure_entry = None + if base_structures and struct_idx < len(base_structures): + structure_entry = [base_structures[struct_idx]] + + prediction_payload = { + "exception": False, + "coords": sample_block, + "masks": pad_mask, + "structure_paths": structure_path_entry, + "structures": structure_entry, + "filenames": [filenames], + } + + writer_batch = {"record": [record]} + + self.validation_writer.write_on_batch_end( + trainer=None, + pl_module=self, + prediction=prediction_payload, + batch_indices=[0], + batch=writer_batch, + batch_idx=batch_idx, + dataloader_idx=0, + ) + + def _compute_sample_metrics( + self, + batch: dict[str, Tensor], + out: dict[str, Tensor], + true_coords: Tensor, + heavy_calc_mask: Tensor, + peptide_calc_mask: Tensor, + n_samples: int, + ) -> list[SampleMetrics]: + sample_coords = out["sample_atom_coords"] + device = sample_coords.device + total_samples = 
sample_coords.shape[0] + samples_per_structure = max(n_samples, 1) + + atom_mask_full = batch["atom_resolved_mask"].to(device=device).float() + atom_to_token_full = batch["atom_to_token"].float().to(device=device) + mol_type_full = batch["mol_type"].to(device=device) + heavy_mask_full = heavy_calc_mask.to(device=device, dtype=atom_mask_full.dtype) + peptide_mask_full = peptide_calc_mask.to(device=device, dtype=atom_mask_full.dtype) + + sample_metrics: list[SampleMetrics] = [] + + for sample_idx in range(total_samples): + struct_idx = sample_idx // samples_per_structure + pred_sample = sample_coords[sample_idx : sample_idx + 1] + ref_sample = true_coords[sample_idx : sample_idx + 1] + + struct_atom_mask = atom_mask_full[struct_idx : struct_idx + 1] + struct_atom_to_token = atom_to_token_full[struct_idx : struct_idx + 1] + struct_mol_type = mol_type_full[struct_idx : struct_idx + 1] + + whole_rmsd = float("nan") + masked_rmsd = float("nan") + + try: + whole_rmsd_tensor, _, _ = weighted_minimum_rmsd_single( + pred_sample, + ref_sample, + struct_atom_mask, + struct_atom_to_token, + struct_mol_type, + nucleotide_weight=self.nucleotide_rmsd_weight, + ligand_weight=self.ligand_rmsd_weight, + ) + whole_rmsd = whole_rmsd_tensor.item() + except Exception as exc: # noqa: BLE001 + print(f"Weighted RMSD (MHC Chain) failed for sample {sample_idx}: {exc}") + + align_mask_row = heavy_mask_full[struct_idx : struct_idx + 1] + calc_mask_row = peptide_mask_full[struct_idx : struct_idx + 1] + try: + if ( + calc_mask_row.float().sum() >= 3 + and align_mask_row.float().sum() >= 3 + ): + masked_rmsd_tensor, _, _ = weighted_minimum_rmsd_single( + pred_sample, + ref_sample, + struct_atom_mask, + struct_atom_to_token, + struct_mol_type, + nucleotide_weight=self.nucleotide_rmsd_weight, + ligand_weight=self.ligand_rmsd_weight, + alignment_mask=align_mask_row, + rmsd_mask=calc_mask_row, + ) + masked_rmsd = masked_rmsd_tensor.item() + except Exception as exc: # noqa: BLE001 + print(f"Weighted 
RMSD (masked) failed for sample {sample_idx}: {exc}") + + print(f"Sample {sample_idx} weighted RMSD (MHC Chain): {whole_rmsd:.3f}Å") + if math.isnan(masked_rmsd): + print(f"Sample {sample_idx} weighted RMSD (masked region): nan") + else: + print( + f"Sample {sample_idx} weighted RMSD (masked region): {masked_rmsd:.3f}Å" + ) + + sample_metrics.append( + SampleMetrics( + sample_idx=sample_idx, + rmsd_whole=whole_rmsd, + rmsd_masked=masked_rmsd, + ) + ) + + return sample_metrics + + def _extract_record_ids(self, batch: dict[str, Tensor], batch_idx: int) -> list[str]: + record_ids = batch.get("record_id", None) + if record_ids is None: + return [f"batch_{batch_idx}_{idx}" for idx in range(batch["atom_to_token"].shape[0])] + if isinstance(record_ids, torch.Tensor): + return [str(r.item()) for r in record_ids] + return [str(r) for r in record_ids] + + def debug_peptide_mask_residues( + self, + base_structures: list[Optional[Structure]], + record_ids: list[str], + peptide_calc_mask: torch.Tensor, + precomputed_rmsd_masks: list[Optional[np.ndarray]], + ): + def _labels(structure: Structure, mask) -> list[str]: + if structure is None or not hasattr(structure, "residues"): + return [] + if isinstance(mask, torch.Tensor): + mask = mask.detach().cpu().numpy() + if mask is None or mask.size == 0: + return [] + + res_arr = structure.residues + names = set(getattr(res_arr, "dtype", ()).names or []) + if not {"atom_idx", "atom_num", "name"}.issubset(names): + return [] + + resnum_field = "res_num" if "res_num" in names else ("res_idx" if "res_idx" in names else None) + out = [] + L = mask.shape[0] + for i, r in enumerate(res_arr): + s = int(r["atom_idx"]); e = s + int(r["atom_num"]) + if s < L and mask[s:e].any(): + rname = r["name"].decode("utf-8", "ignore") if isinstance(r["name"], (bytes, bytearray)) else str(r["name"]) + rnum = int(r[resnum_field]) if resnum_field else (i + 1) + out.append(f"{rname}{rnum}") + return out + + for i, (structure, rec_id) in 
enumerate(zip(base_structures, record_ids)): + if structure is None: + continue + pred_mask = peptide_calc_mask[i].detach().cpu().numpy() + ref_mask = precomputed_rmsd_masks[i] if i < len(precomputed_rmsd_masks) else None + + pred_list = _labels(structure, pred_mask) + ref_list = _labels(structure, ref_mask) + + pred_str = ", ".join(pred_list) if pred_list else "None" + ref_str = ", ".join(ref_list) if ref_list else "None" + print(f"[debug-peptide-residues] record {rec_id}\n REF: {ref_str}\n PRD: {pred_str}") + def validation_step(self, batch: dict[str, Tensor], batch_idx: int): # Compute the forward pass n_samples = self.validation_args.diffusion_samples @@ -626,8 +1034,7 @@ def validation_step(self, batch: dict[str, Tensor], batch_idx: int): diffusion_samples=n_samples, run_confidence_sequentially=self.validation_args.run_confidence_sequentially, ) - - except RuntimeError as e: # catch out of memory exceptions + except RuntimeError as e: if "out of memory" in str(e): print("| WARNING: ran out of memory, skipping batch") torch.cuda.empty_cache() @@ -637,42 +1044,77 @@ def validation_step(self, batch: dict[str, Tensor], batch_idx: int): raise e try: - # Compute distogram LDDT - boundaries = torch.linspace(2, 22.0, 63) - lower = torch.tensor([1.0]) - upper = torch.tensor([22.0 + 5.0]) - exp_boundaries = torch.cat((lower, boundaries, upper)) - mid_points = ((exp_boundaries[:-1] + exp_boundaries[1:]) / 2).to( - out["pdistogram"] + record_ids = self._extract_record_ids(batch, batch_idx) + + batch_size = batch["atom_to_token"].shape[0] + + base_structures = batch.get("base_structure", None) + if base_structures is None: + raise ValueError("Validation batch missing 'base_structure'. 
Ensure dataloader supplies preloaded structures.") + precomputed_align_masks: list[Optional[np.ndarray]] + precomputed_rmsd_masks: list[Optional[np.ndarray]] + + if base_structures is None: + raise ValueError("Validation requires 'base_structure' in batch to avoid disk reloads.") + if not isinstance(base_structures, list): + base_structures = [base_structures] + precomputed_align_masks = [None] * len(base_structures) + precomputed_rmsd_masks = [None] * len(base_structures) + + align_tensor = batch.get("alignment_mask") + if isinstance(align_tensor, torch.Tensor): + align_tensor = align_tensor.bool() + for idx in range(min(align_tensor.shape[0], len(base_structures))): + precomputed_align_masks[idx] = align_tensor[idx].detach().cpu().numpy() + + rmsd_tensor = batch.get("rmsd_mask") + if isinstance(rmsd_tensor, torch.Tensor): + rmsd_tensor = rmsd_tensor.bool() + for idx in range(min(rmsd_tensor.shape[0], len(base_structures))): + precomputed_rmsd_masks[idx] = rmsd_tensor[idx].detach().cpu().numpy() + + heavy_calc_mask, peptide_calc_mask = self._build_alignment_masks( + batch=batch, + record_ids=record_ids, + base_structures=base_structures, + batch_idx=batch_idx, + precomputed_align_masks=precomputed_align_masks, + precomputed_rmsd_masks=precomputed_rmsd_masks, ) - # Compute predicted dists + if getattr(self.validation_args, "debug_peptide_mask_info", False): + self.debug_peptide_mask_residues(base_structures, record_ids, peptide_calc_mask, precomputed_rmsd_masks) + + true_coords, rmsds, best_rmsds, true_coords_resolved_mask, masked_rmsds, best_masked_rmsds = self.get_true_coordinates( + batch=batch, + out=out, + diffusion_samples=n_samples, + symmetry_correction=self.validation_args.symmetry_correction, + alignment_mask=heavy_calc_mask, + rmsd_mask=peptide_calc_mask, + ) + + + boundaries = torch.linspace(2, 22.0, 63, device=out["pdistogram"].device) + lower = torch.tensor([1.0], device=boundaries.device) + upper = torch.tensor([27.0], device=boundaries.device) + 
exp_boundaries = torch.cat((lower, boundaries, upper)) + mid_points = ((exp_boundaries[:-1] + exp_boundaries[1:]) / 2).to(out["pdistogram"]) + preds = out["pdistogram"] pred_softmax = torch.softmax(preds, dim=-1) - pred_softmax = pred_softmax.argmax(dim=-1) - pred_softmax = torch.nn.functional.one_hot( - pred_softmax, num_classes=preds.shape[-1] - ) - pred_dist = (pred_softmax * mid_points).sum(dim=-1) + pred_max = pred_softmax.argmax(dim=-1) + pred_one_hot = torch.nn.functional.one_hot(pred_max, num_classes=preds.shape[-1]) + pred_dist = (pred_one_hot * mid_points).sum(dim=-1) true_center = batch["disto_center"] true_dists = torch.cdist(true_center, true_center) - # Compute lddt's - batch["token_disto_mask"] = batch["token_disto_mask"] disto_lddt_dict, disto_total_dict = factored_token_lddt_dist_loss( feats=batch, true_d=true_dists, pred_d=pred_dist, ) - - true_coords, rmsds, best_rmsds, true_coords_resolved_mask = ( - self.get_true_coordinates( - batch=batch, - out=out, - diffusion_samples=n_samples, - symmetry_correction=self.validation_args.symmetry_correction, - ) - ) + batch["token_disto_mask"] = batch["token_disto_mask"] all_lddt_dict, all_total_dict = factored_lddt_loss( feats=batch, @@ -681,201 +1123,186 @@ def validation_step(self, batch: dict[str, Tensor], batch_idx: int): pred_atom_coords=out["sample_atom_coords"], multiplicity=n_samples, ) - except RuntimeError as e: # catch out of memory exceptions - if "out of memory" in str(e): - print("| WARNING: ran out of memory, skipping batch") - torch.cuda.empty_cache() - gc.collect() - return - else: - raise e - # if the multiplicity used is > 1 then we take the best lddt of the different samples - # AF3 combines this with the confidence based filtering - best_lddt_dict, best_total_dict = {}, {} - best_complex_lddt_dict, best_complex_total_dict = {}, {} - B = true_coords.shape[0] // n_samples - if n_samples > 1: - # NOTE: we can change the way we aggregate the lddt - complex_total = 0 - complex_lddt = 0 - for 
key in all_lddt_dict.keys(): - complex_lddt += all_lddt_dict[key] * all_total_dict[key] - complex_total += all_total_dict[key] - complex_lddt /= complex_total + 1e-7 - best_complex_idx = complex_lddt.reshape(-1, n_samples).argmax(dim=1) - for key in all_lddt_dict: - best_idx = all_lddt_dict[key].reshape(-1, n_samples).argmax(dim=1) - best_lddt_dict[key] = all_lddt_dict[key].reshape(-1, n_samples)[ - torch.arange(B), best_idx - ] - best_total_dict[key] = all_total_dict[key].reshape(-1, n_samples)[ - torch.arange(B), best_idx - ] - best_complex_lddt_dict[key] = all_lddt_dict[key].reshape(-1, n_samples)[ - torch.arange(B), best_complex_idx - ] - best_complex_total_dict[key] = all_total_dict[key].reshape( - -1, n_samples - )[torch.arange(B), best_complex_idx] - else: - best_lddt_dict = all_lddt_dict - best_total_dict = all_total_dict - best_complex_lddt_dict = all_lddt_dict - best_complex_total_dict = all_total_dict - - # Filtering based on confidence - if self.confidence_prediction and n_samples > 1: - # note: for now we don't have pae predictions so have to use pLDDT instead of pTM - # also, while AF3 differentiates the best prediction per confidence type we are currently not doing it - # consider this in the future as well as weighing the different pLLDT types before aggregation - mae_plddt_dict, total_mae_plddt_dict = compute_plddt_mae( - pred_atom_coords=out["sample_atom_coords"], - feats=batch, - true_atom_coords=true_coords, - pred_lddt=out["plddt"], - true_coords_resolved_mask=true_coords_resolved_mask, - multiplicity=n_samples, - ) - mae_pde_dict, total_mae_pde_dict = compute_pde_mae( - pred_atom_coords=out["sample_atom_coords"], - feats=batch, - true_atom_coords=true_coords, - pred_pde=out["pde"], - true_coords_resolved_mask=true_coords_resolved_mask, - multiplicity=n_samples, - ) - mae_pae_dict, total_mae_pae_dict = compute_pae_mae( - pred_atom_coords=out["sample_atom_coords"], - feats=batch, - true_atom_coords=true_coords, - pred_pae=out["pae"], - 
true_coords_resolved_mask=true_coords_resolved_mask, - multiplicity=n_samples, - ) - plddt = out["complex_plddt"].reshape(-1, n_samples) - top1_idx = plddt.argmax(dim=1) - iplddt = out["complex_iplddt"].reshape(-1, n_samples) - iplddt_top1_idx = iplddt.argmax(dim=1) - pde = out["complex_pde"].reshape(-1, n_samples) - pde_top1_idx = pde.argmin(dim=1) - ipde = out["complex_ipde"].reshape(-1, n_samples) - ipde_top1_idx = ipde.argmin(dim=1) - ptm = out["ptm"].reshape(-1, n_samples) - ptm_top1_idx = ptm.argmax(dim=1) - iptm = out["iptm"].reshape(-1, n_samples) - iptm_top1_idx = iptm.argmax(dim=1) - ligand_iptm = out["ligand_iptm"].reshape(-1, n_samples) - ligand_iptm_top1_idx = ligand_iptm.argmax(dim=1) - protein_iptm = out["protein_iptm"].reshape(-1, n_samples) - protein_iptm_top1_idx = protein_iptm.argmax(dim=1) - - for key in all_lddt_dict: - top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[ - torch.arange(B), top1_idx - ] - top1_total = all_total_dict[key].reshape(-1, n_samples)[ - torch.arange(B), top1_idx - ] - iplddt_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[ - torch.arange(B), iplddt_top1_idx - ] - iplddt_top1_total = all_total_dict[key].reshape(-1, n_samples)[ - torch.arange(B), iplddt_top1_idx - ] - pde_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[ - torch.arange(B), pde_top1_idx - ] - pde_top1_total = all_total_dict[key].reshape(-1, n_samples)[ - torch.arange(B), pde_top1_idx - ] - ipde_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[ - torch.arange(B), ipde_top1_idx - ] - ipde_top1_total = all_total_dict[key].reshape(-1, n_samples)[ - torch.arange(B), ipde_top1_idx - ] - ptm_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[ - torch.arange(B), ptm_top1_idx - ] - ptm_top1_total = all_total_dict[key].reshape(-1, n_samples)[ - torch.arange(B), ptm_top1_idx - ] - iptm_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[ - torch.arange(B), iptm_top1_idx - ] - iptm_top1_total = all_total_dict[key].reshape(-1, n_samples)[ - 
torch.arange(B), iptm_top1_idx - ] - ligand_iptm_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[ - torch.arange(B), ligand_iptm_top1_idx - ] - ligand_iptm_top1_total = all_total_dict[key].reshape(-1, n_samples)[ - torch.arange(B), ligand_iptm_top1_idx - ] - protein_iptm_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[ - torch.arange(B), protein_iptm_top1_idx - ] - protein_iptm_top1_total = all_total_dict[key].reshape(-1, n_samples)[ - torch.arange(B), protein_iptm_top1_idx - ] - - self.top1_lddt[key].update(top1_lddt, top1_total) - self.iplddt_top1_lddt[key].update(iplddt_top1_lddt, iplddt_top1_total) - self.pde_top1_lddt[key].update(pde_top1_lddt, pde_top1_total) - self.ipde_top1_lddt[key].update(ipde_top1_lddt, ipde_top1_total) - self.ptm_top1_lddt[key].update(ptm_top1_lddt, ptm_top1_total) - self.iptm_top1_lddt[key].update(iptm_top1_lddt, iptm_top1_total) - self.ligand_iptm_top1_lddt[key].update( - ligand_iptm_top1_lddt, ligand_iptm_top1_total + best_lddt_dict, best_total_dict = {}, {} + best_complex_lddt_dict, best_complex_total_dict = {}, {} + B = true_coords.shape[0] // n_samples + if n_samples > 1: + complex_total = 0 + complex_lddt = 0 + for key in all_lddt_dict.keys(): + complex_lddt += all_lddt_dict[key] * all_total_dict[key] + complex_total += all_total_dict[key] + complex_lddt /= complex_total + 1e-7 + best_complex_idx = complex_lddt.reshape(-1, n_samples).argmax(dim=1) + for key in all_lddt_dict: + reshaped_lddt = all_lddt_dict[key].reshape(-1, n_samples) + reshaped_total = all_total_dict[key].reshape(-1, n_samples) + best_idx = reshaped_lddt.argmax(dim=1) + best_lddt_dict[key] = reshaped_lddt[torch.arange(B), best_idx] + best_total_dict[key] = reshaped_total[torch.arange(B), best_idx] + best_complex_lddt_dict[key] = reshaped_lddt[ + torch.arange(B), best_complex_idx + ] + best_complex_total_dict[key] = reshaped_total[ + torch.arange(B), best_complex_idx + ] + else: + best_lddt_dict = all_lddt_dict + best_total_dict = all_total_dict + 
best_complex_lddt_dict = all_lddt_dict + best_complex_total_dict = all_total_dict + + def _allowed_metric(key: str) -> bool: + low = key.lower() + return ("dna" not in low) and ("rna" not in low) and ("ligand" not in low) + + if self.confidence_prediction and n_samples > 1: + mae_plddt_dict, total_mae_plddt_dict = compute_plddt_mae( + pred_atom_coords=out["sample_atom_coords"], + feats=batch, + true_atom_coords=true_coords, + pred_lddt=out["plddt"], + true_coords_resolved_mask=true_coords_resolved_mask, + multiplicity=n_samples, ) - self.protein_iptm_top1_lddt[key].update( - protein_iptm_top1_lddt, protein_iptm_top1_total + mae_pde_dict, total_mae_pde_dict = compute_pde_mae( + pred_atom_coords=out["sample_atom_coords"], + feats=batch, + true_atom_coords=true_coords, + pred_pde=out["pde"], + true_coords_resolved_mask=true_coords_resolved_mask, + multiplicity=n_samples, ) - - self.avg_lddt[key].update(all_lddt_dict[key], all_total_dict[key]) - self.pde_mae[key].update(mae_pde_dict[key], total_mae_pde_dict[key]) - self.pae_mae[key].update(mae_pae_dict[key], total_mae_pae_dict[key]) - - for key in mae_plddt_dict: - self.plddt_mae[key].update( - mae_plddt_dict[key], total_mae_plddt_dict[key] + mae_pae_dict, total_mae_pae_dict = compute_pae_mae( + pred_atom_coords=out["sample_atom_coords"], + feats=batch, + true_atom_coords=true_coords, + pred_pae=out["pae"], + true_coords_resolved_mask=true_coords_resolved_mask, + multiplicity=n_samples, ) - for m in const.out_types: - if m == "ligand_protein": - if torch.any( - batch["pocket_feature"][ - :, :, const.pocket_contact_info["POCKET"] - ].bool() - ): - self.lddt["pocket_ligand_protein"].update( - best_lddt_dict[m], best_total_dict[m] + plddt = out["complex_plddt"].reshape(-1, n_samples) + top1_idx = plddt.argmax(dim=1) + iplddt = out["complex_iplddt"].reshape(-1, n_samples) + iplddt_top1_idx = iplddt.argmax(dim=1) + pde = out["complex_pde"].reshape(-1, n_samples) + pde_top1_idx = pde.argmin(dim=1) + ipde = 
out["complex_ipde"].reshape(-1, n_samples) + ipde_top1_idx = ipde.argmin(dim=1) + ptm = out["ptm"].reshape(-1, n_samples) + ptm_top1_idx = ptm.argmax(dim=1) + iptm = out["iptm"].reshape(-1, n_samples) + iptm_top1_idx = iptm.argmax(dim=1) + ligand_iptm = out["ligand_iptm"].reshape(-1, n_samples) + ligand_iptm_top1_idx = ligand_iptm.argmax(dim=1) + protein_iptm = out["protein_iptm"].reshape(-1, n_samples) + protein_iptm_top1_idx = protein_iptm.argmax(dim=1) + + for key in all_lddt_dict: + if not _allowed_metric(key): + continue + reshaped_lddt = all_lddt_dict[key].reshape(-1, n_samples) + reshaped_total = all_total_dict[key].reshape(-1, n_samples) + + self.top1_lddt[key].update( + reshaped_lddt[torch.arange(B), top1_idx], + reshaped_total[torch.arange(B), top1_idx], ) - self.disto_lddt["pocket_ligand_protein"].update( - disto_lddt_dict[m], disto_total_dict[m] + self.iplddt_top1_lddt[key].update( + reshaped_lddt[torch.arange(B), iplddt_top1_idx], + reshaped_total[torch.arange(B), iplddt_top1_idx], ) - self.complex_lddt["pocket_ligand_protein"].update( - best_complex_lddt_dict[m], best_complex_total_dict[m] + self.pde_top1_lddt[key].update( + reshaped_lddt[torch.arange(B), pde_top1_idx], + reshaped_total[torch.arange(B), pde_top1_idx], ) - else: - self.lddt["ligand_protein"].update( - best_lddt_dict[m], best_total_dict[m] + self.ipde_top1_lddt[key].update( + reshaped_lddt[torch.arange(B), ipde_top1_idx], + reshaped_total[torch.arange(B), ipde_top1_idx], + ) + self.ptm_top1_lddt[key].update( + reshaped_lddt[torch.arange(B), ptm_top1_idx], + reshaped_total[torch.arange(B), ptm_top1_idx], ) - self.disto_lddt["ligand_protein"].update( - disto_lddt_dict[m], disto_total_dict[m] + self.iptm_top1_lddt[key].update( + reshaped_lddt[torch.arange(B), iptm_top1_idx], + reshaped_total[torch.arange(B), iptm_top1_idx], ) - self.complex_lddt["ligand_protein"].update( - best_complex_lddt_dict[m], best_complex_total_dict[m] + self.ligand_iptm_top1_lddt[key].update( + 
reshaped_lddt[torch.arange(B), ligand_iptm_top1_idx], + reshaped_total[torch.arange(B), ligand_iptm_top1_idx], ) + self.protein_iptm_top1_lddt[key].update( + reshaped_lddt[torch.arange(B), protein_iptm_top1_idx], + reshaped_total[torch.arange(B), protein_iptm_top1_idx], + ) + + self.avg_lddt[key].update(all_lddt_dict[key], all_total_dict[key]) + self.pde_mae[key].update(mae_pde_dict[key], total_mae_pde_dict[key]) + self.pae_mae[key].update(mae_pae_dict[key], total_mae_pae_dict[key]) + + for key in mae_plddt_dict: + if not _allowed_metric(key): + continue + self.plddt_mae[key].update( + mae_plddt_dict[key], total_mae_plddt_dict[key] + ) + + def _update_metric_collection(collection, values_dict, totals_dict): + for key in values_dict: + if not _allowed_metric(key): + continue + collection[key].update(values_dict[key], totals_dict[key]) + + _update_metric_collection(self.lddt, best_lddt_dict, best_total_dict) + _update_metric_collection(self.disto_lddt, disto_lddt_dict, disto_total_dict) + _update_metric_collection( + self.complex_lddt, best_complex_lddt_dict, best_complex_total_dict + ) + self.rmsd.update(rmsds) + self.best_rmsd.update(best_rmsds) + + sample_metrics = self._compute_sample_metrics( + batch=batch, + out=out, + true_coords=true_coords, + heavy_calc_mask=heavy_calc_mask, + peptide_calc_mask=peptide_calc_mask, + n_samples=n_samples, + ) + + whole_values = [m.rmsd_whole for m in sample_metrics if not math.isnan(m.rmsd_whole)] + masked_values = [m.rmsd_masked for m in sample_metrics if not math.isnan(m.rmsd_masked)] + + if whole_values: + whole_mean = torch.tensor(whole_values, device=out["sample_atom_coords"].device, dtype=torch.float32).mean() + self.log("val/weighted_rmsd_whole", whole_mean, prog_bar=False, sync_dist=True, batch_size=batch_size) + if masked_values: + masked_mean = torch.tensor(masked_values, device=out["sample_atom_coords"].device, dtype=torch.float32).mean() + self.log("val/weighted_rmsd_masked", masked_mean, prog_bar=False, 
sync_dist=True, batch_size=batch_size) + + + self._write_validation_cifs( + batch=batch, + out=out, + base_structures=base_structures, + record_ids=record_ids, + sample_metrics=sample_metrics, + n_samples=n_samples, + batch_idx=batch_idx, + ) + + + except RuntimeError as e: + if "out of memory" in str(e): + print("| WARNING: ran out of memory, skipping batch") + torch.cuda.empty_cache() + gc.collect() + return else: - self.lddt[m].update(best_lddt_dict[m], best_total_dict[m]) - self.disto_lddt[m].update(disto_lddt_dict[m], disto_total_dict[m]) - self.complex_lddt[m].update( - best_complex_lddt_dict[m], best_complex_total_dict[m] - ) - self.rmsd.update(rmsds) - self.best_rmsd.update(best_rmsds) + raise e def on_validation_epoch_end(self): avg_lddt = {}