From 1de4a74adedc18001b04043142816909435eed59 Mon Sep 17 00:00:00 2001 From: AlvandVahedi Date: Wed, 5 Nov 2025 04:54:16 +0000 Subject: [PATCH 1/5] Refactor validation to add weighted RMSD metrics and shared CIF writer --- .gitignore | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 3d20fc11a..6f79226b3 100644 --- a/.gitignore +++ b/.gitignore @@ -163,4 +163,10 @@ cython_debug/ # Boltz prediction outputs # All result files generated from a boltz prediction call -boltz_results_*/ \ No newline at end of file +boltz_results_*/ + +# Local datasets and caches +data/ +natives/ +cache/ +mhc_one_sample/ From b015baa21a2e2e6f25fa4abdecbae71fe8f7dfd7 Mon Sep 17 00:00:00 2001 From: AlvandVahedi Date: Wed, 5 Nov 2025 04:54:40 +0000 Subject: [PATCH 2/5] Refactor validation to add weighted RMSD metrics and shared CIF writer --- scripts/train/configs/structure.yaml | 32 +- scripts/train/train.py | 3 +- src/boltz/data/module/training.py | 21 +- src/boltz/data/write/writer.py | 116 ++++++- src/boltz/model/loss/validation.py | 122 +++++++ src/boltz/model/models/boltz1.py | 497 +++++++++++++++------------ 6 files changed, 541 insertions(+), 250 deletions(-) diff --git a/scripts/train/configs/structure.yaml b/scripts/train/configs/structure.yaml index 6591f386a..81e9ef55a 100644 --- a/scripts/train/configs/structure.yaml +++ b/scripts/train/configs/structure.yaml @@ -3,7 +3,7 @@ trainer: devices: 1 precision: 32 gradient_clip_val: 10.0 - max_epochs: -1 + max_epochs: 0 accumulate_grad_batches: 128 # to adjust depending on the number of devices # Optional set wandb here @@ -12,9 +12,9 @@ trainer: # project: boltz # entity: boltz -output: SET_PATH_HERE -pretrained: PATH_TO_CHECKPOINT_FILE -resume: null +output: /storage/alvand/boltz/output +pretrained: /storage/alvand/boltz/data/boltz1.ckpt +resume: /storage/alvand/boltz/data/boltz1.ckpt disable_checkpoint: false matmul_precision: null save_top_k: -1 @@ -22,16 +22,17 @@ save_top_k: -1 
data: datasets: - _target_: boltz.data.module.training.DatasetConfig - target_dir: PATH_TO_TARGETS_DIR - msa_dir: PATH_TO_MSA_DIR + target_dir: /storage/alvand/boltz/mhc_one_sample/processed_structures/ + msa_dir: /storage/alvand/boltz/mhc_one_sample/processed_msa/ prob: 1.0 + manifest_path: /storage/alvand/boltz/mhc_one_sample/manifest.json sampler: _target_: boltz.data.sample.cluster.ClusterSampler cropper: _target_: boltz.data.crop.boltz.BoltzCropper min_neighborhood: 0 max_neighborhood: 40 - split: ./scripts/train/assets/validation_ids.txt + split: /storage/alvand/boltz/mhc_one_sample/validation_ids.txt filters: - _target_: boltz.data.filter.dynamic.size.SizeFilter @@ -48,16 +49,16 @@ data: featurizer: _target_: boltz.data.feature.featurizer.BoltzFeaturizer - symmetries: PATH_TO_SYMMETRY_FILE + symmetries: /storage/alvand/boltz/data/symmetry.pkl max_tokens: 512 max_atoms: 4608 max_seqs: 2048 - pad_to_max_tokens: true - pad_to_max_atoms: true - pad_to_max_seqs: true - samples_per_epoch: 100000 + pad_to_max_tokens: false + pad_to_max_atoms: false + pad_to_max_seqs: false + samples_per_epoch: 1 batch_size: 1 - num_workers: 4 + num_workers: 1 random_seed: 42 pin_memory: true overfit: null @@ -75,7 +76,7 @@ data: compute_constraint_features: false model: - _target_: boltz.model.model.Boltz1 + _target_: boltz.model.models.boltz1.Boltz1 atom_s: 128 atom_z: 16 token_s: 384 @@ -161,8 +162,9 @@ model: recycling_steps: 3 sampling_steps: 200 diffusion_samples: 5 - symmetry_correction: true + symmetry_correction: false run_confidence_sequentially: false + val_cif_out_dir: /storage/alvand/boltz/val_cif_output diffusion_process_args: sigma_min: 0.0004 diff --git a/scripts/train/train.py b/scripts/train/train.py index f83966bdd..1cbe70dbc 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -71,7 +71,7 @@ class TrainConfig: matmul_precision: Optional[str] = None find_unused_parameters: Optional[bool] = False save_top_k: Optional[int] = 1 - validation_only: bool 
= False + validation_only: bool = True debug: bool = False strict_loading: bool = True load_confidence_from_trunk: Optional[bool] = False @@ -222,6 +222,7 @@ def save_config_to_wandb() -> None: model_module.strict_loading = False if cfg.validation_only: + print(f'====== Running validation on {cfg.resume}===========') trainer.validate( model_module, datamodule=data_module, diff --git a/src/boltz/data/module/training.py b/src/boltz/data/module/training.py index 36583b6cf..a450d249b 100644 --- a/src/boltz/data/module/training.py +++ b/src/boltz/data/module/training.py @@ -171,12 +171,18 @@ def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]: "amino_acids_symmetries", "ligand_symmetries", ]: - # Check if all have the same shape - shape = values[0].shape - if not all(v.shape == shape for v in values): - values, _ = pad_to_max(values, 0) + first_value = values[0] + if isinstance(first_value, torch.Tensor): + # Check if all have the same shape + shape = first_value.shape + if not all(v.shape == shape for v in values): + values, _ = pad_to_max(values, 0) + else: + values = torch.stack(values, dim=0) else: - values = torch.stack(values, dim=0) + # Keep list for non-tensor entries (e.g. record ids) + collated[key] = values + continue # Stack the values collated[key] = values @@ -470,6 +476,11 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: print(f"Featurizer failed on {record.id} with error {e}. 
Skipping.") return self.__getitem__(0) + features["record_id"] = record.id + features["structure_path"] = str( + (dataset.target_dir / "structures" / f"{record.id}.npz").resolve() + ) + return features def __len__(self) -> int: diff --git a/src/boltz/data/write/writer.py b/src/boltz/data/write/writer.py index 984be2ae5..2c9a09caa 100644 --- a/src/boltz/data/write/writer.py +++ b/src/boltz/data/write/writer.py @@ -1,7 +1,9 @@ import json +import os +import gc from dataclasses import asdict, replace from pathlib import Path -from typing import Literal +from typing import Iterable, Literal, Optional import numpy as np import torch @@ -12,6 +14,7 @@ from boltz.data.types import Coords, Interface, Record, Structure, StructureV2 from boltz.data.write.mmcif import to_mmcif from boltz.data.write.pdb import to_pdb +from boltz.model.loss.validation import SampleMetrics class BoltzWriter(BasePredictionWriter): @@ -267,6 +270,117 @@ def on_predict_epoch_end( print(f"Number of failed examples: {self.failed}") # noqa: T201 +def atomic_save_cif(filepath: Path, content: str) -> bool: + """ + Safely write CIF content to disk (atomic rename + fsync). + """ + tmp_path = filepath.parent / f".{filepath.name}.tmp" + try: + filepath.parent.mkdir(parents=True, exist_ok=True) + with tmp_path.open("w") as handle: + handle.write(content) + handle.flush() + os.fsync(handle.fileno()) + tmp_path.replace(filepath) + return True + except Exception as exc: # noqa: BLE001 + print(f"Error writing file {filepath}: {exc}") + if tmp_path.exists(): + tmp_path.unlink() + return False + + +def write_validation_predictions( + out: dict[str, Tensor], + batch: dict[str, Tensor], + base_structures: list[Optional[Structure]], + record_ids: list[str], + sample_metrics: Iterable[SampleMetrics], + n_samples: int, + output_dir: Path, +) -> None: + """ + Write validation predictions to disk. 
+ """ + output_dir.mkdir(parents=True, exist_ok=True) + + sample_metrics = list(sample_metrics) + metrics_map = {metric.sample_idx: metric for metric in sample_metrics} + total_samples = out["sample_atom_coords"].shape[0] + samples_per_structure = n_samples if n_samples > 0 else total_samples + atom_pad_mask_cpu = batch["atom_pad_mask"].detach().cpu() + + for struct_idx, record_id in enumerate(record_ids): + base_structure = base_structures[struct_idx] if struct_idx < len(base_structures) else None + if base_structure is None: + print(f"Skipping CIF write for record '{record_id}' (missing base structure).") + continue + + valid_mask = atom_pad_mask_cpu[struct_idx].to(dtype=torch.bool).numpy() + if valid_mask.sum() != base_structure.atoms.shape[0]: + print( + f"Warning: atom count mismatch for record '{record_id}': " + f"mask has {int(valid_mask.sum())} atoms, structure has {base_structure.atoms.shape[0]}." + ) + continue + + sample_block = out["sample_atom_coords"][ + struct_idx * samples_per_structure : (struct_idx + 1) * samples_per_structure + ] + + for sample_offset, coords_tensor in enumerate(sample_block): + sample_idx = struct_idx * samples_per_structure + sample_offset + metrics = metrics_map.get(sample_idx) + + try: + coords_np = coords_tensor.detach().cpu().numpy()[valid_mask] + atoms = base_structure.atoms.copy() + atoms["coords"] = coords_np.astype(np.float32) + atoms["conformer"] = coords_np.astype(np.float32) + atoms["is_present"] = True + + residues = base_structure.residues.copy() + residues["is_present"] = True + + new_structure = Structure( + atoms=atoms, + bonds=base_structure.bonds, + residues=residues, + chains=base_structure.chains, + connections=base_structure.connections, + interfaces=base_structure.interfaces, + mask=base_structure.mask, + ) + + plddts = None + if "plddt" in out: + start = struct_idx * samples_per_structure + sample_offset + plddts = out["plddt"][start : start + 1].detach().cpu() + + if metrics is not None: + filename = ( + 
f"prediction_{record_id}_sample_{sample_idx}" + f"_whole{metrics.rmsd_whole:.2f}_pep{metrics.rmsd_peptide:.2f}.cif" + ) + else: + filename = f"prediction_{record_id}_sample_{sample_idx}.cif" + + output_path = output_dir / filename + print(f"\nSaving prediction to {output_path}") + + cif_content = to_mmcif(new_structure, plddts=plddts) + if atomic_save_cif(output_path, cif_content): + print(f"Successfully saved prediction to {output_path}") + + except Exception as exc: # noqa: BLE001 + print(f"Error processing record '{record_id}' sample {sample_idx}: {exc}") + continue + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + class BoltzAffinityWriter(BasePredictionWriter): """Custom writer for predictions.""" diff --git a/src/boltz/model/loss/validation.py b/src/boltz/model/loss/validation.py index 00d1aa7c3..20c9f0ba7 100644 --- a/src/boltz/model/loss/validation.py +++ b/src/boltz/model/loss/validation.py @@ -8,6 +8,11 @@ ) from boltz.model.loss.diffusion import weighted_rigid_align +import math +from dataclasses import dataclass +from typing import List +from torch import Tensor + def factored_lddt_loss( true_atom_coords, @@ -1023,3 +1028,120 @@ def weighted_minimum_rmsd_single( / torch.sum(align_weights * atom_mask, dim=-1) ) return rmsd, atom_coords_aligned_ground_truth, align_weights + +@dataclass +class SampleMetrics: + sample_idx: int + rmsd_whole: float + rmsd_peptide: float + + +def compute_weighted_mhc_rmsds( + out: dict, + true_coords: Tensor, + batch: dict, + peptide_mask: Tensor, + n_samples: int, + nucleotide_weight: float, + ligand_weight: float, +) -> List[SampleMetrics]: + """Compute weighted RMSDs for whole MHC chain and peptide subset. + + Parameters + ---------- + out : dict + Model outputs containing ``sample_atom_coords``. + true_coords : Tensor + Reference coordinates matching each diffusion sample. + batch : dict + Original batch features (masks, mapping, etc.). 
+ peptide_mask : Tensor + Boolean mask selecting peptide atoms per structure. + n_samples : int + Number of diffusion samples per structure. + nucleotide_weight : float + Weight multiplier for nucleic acid atoms. + ligand_weight : float + Weight multiplier for ligand atoms. + + Returns + ------- + list[SampleMetrics] + Metrics per diffusion sample. + """ + device = out["sample_atom_coords"].device + total_samples = out["sample_atom_coords"].shape[0] + denom = max(n_samples, 1) + + metrics: List[SampleMetrics] = [] + + for sample_idx in range(total_samples): + struct_idx = sample_idx // denom + pred_sample = out["sample_atom_coords"][sample_idx : sample_idx + 1] + ref_sample = true_coords[sample_idx : sample_idx + 1] + + atom_mask_full = ( + batch["atom_resolved_mask"][struct_idx : struct_idx + 1] + .to(device=device) + .float() + ) + atom_to_token_full = ( + batch["atom_to_token"][struct_idx : struct_idx + 1] + .float() + .to(device=device) + ) + mol_type_full = batch["mol_type"][struct_idx : struct_idx + 1].to(device=device) + + try: + whole_rmsd_tensor, _, _ = weighted_minimum_rmsd_single( + pred_sample, + ref_sample, + atom_mask_full, + atom_to_token_full, + mol_type_full, + nucleotide_weight=nucleotide_weight, + ligand_weight=ligand_weight, + ) + whole_rmsd = whole_rmsd_tensor.item() + except Exception as e: # noqa: BLE001 + print(f"Weighted RMSD (MHC Chain) failed for sample {sample_idx}: {e}") + whole_rmsd = float("nan") + + peptide_mask_row = ( + peptide_mask[struct_idx : struct_idx + 1] + .to(device=device) + .float() + ) + try: + if peptide_mask_row.sum() >= 3: + peptide_rmsd_tensor, _, _ = weighted_minimum_rmsd_single( + pred_sample, + ref_sample, + peptide_mask_row, + atom_to_token_full, + mol_type_full, + nucleotide_weight=nucleotide_weight, + ligand_weight=ligand_weight, + ) + peptide_rmsd = peptide_rmsd_tensor.item() + else: + peptide_rmsd = float("nan") + except Exception as e: # noqa: BLE001 + print(f"Weighted RMSD (peptide) failed for sample 
{sample_idx}: {e}") + peptide_rmsd = float("nan") + + print(f"Sample {sample_idx} weighted RMSD (MHC Chain): {whole_rmsd:.3f}Å") + if math.isnan(peptide_rmsd): + print(f"Sample {sample_idx} weighted RMSD (peptide): nan") + else: + print(f"Sample {sample_idx} weighted RMSD (peptide): {peptide_rmsd:.3f}Å") + + metrics.append( + SampleMetrics( + sample_idx=sample_idx, + rmsd_whole=whole_rmsd, + rmsd_peptide=peptide_rmsd, + ) + ) + + return metrics diff --git a/src/boltz/model/models/boltz1.py b/src/boltz/model/models/boltz1.py index 51889b882..7e859e119 100644 --- a/src/boltz/model/models/boltz1.py +++ b/src/boltz/model/models/boltz1.py @@ -1,7 +1,11 @@ import gc +import math import random +from pathlib import Path from typing import Any, Optional +import numpy as np + import torch import torch._dynamo from pytorch_lightning import LightningModule @@ -9,6 +13,7 @@ from torchmetrics import MeanMetric import boltz.model.layers.initialize as init +from boltz.data.types import Connection, Structure from boltz.data import const from boltz.data.feature.symmetry import ( minimum_lddt_symmetry_coords, @@ -22,9 +27,11 @@ compute_plddt_mae, factored_lddt_loss, factored_token_lddt_dist_loss, + compute_weighted_mhc_rmsds, weighted_minimum_rmsd, ) from boltz.model.modules.confidence import ConfidenceModule +from boltz.data.write.writer import write_validation_predictions from boltz.model.modules.diffusion import AtomDiffusion from boltz.model.modules.encoders import RelativePositionEncoder from boltz.model.modules.trunk import ( @@ -276,7 +283,7 @@ def forward( num_sampling_steps: Optional[int] = None, multiplicity_diffusion_train: int = 1, diffusion_samples: int = 1, - max_parallel_samples: Optional[int] = None, + max_parallel_samples: Optional[int] = 1, run_confidence_sequentially: bool = False, ) -> dict[str, Tensor]: dict_out = {} @@ -615,6 +622,217 @@ def parameter_norm(self, module) -> float: norm = torch.tensor([p.norm(p=2) ** 2 for p in parameters]).sum().sqrt() return 
norm + def _build_alignment_masks( + self, + batch: dict[str, Tensor], + record_ids: list[str], + base_structures: list[Optional[Structure]], + batch_idx: int, + ) -> tuple[Tensor, Tensor, Tensor]: + batch_size, num_atoms, _ = batch["atom_to_token"].shape + device = batch["atom_pad_mask"].device + + align_base_mask = torch.zeros((batch_size, num_atoms), dtype=torch.bool, device=device) + heavy_calc_mask = torch.zeros_like(align_base_mask) + peptide_calc_mask = torch.zeros_like(align_base_mask) + backbone_names = set(const.protein_backbone_atom_names) + + for structure_idx in range(batch_size): + record_id = ( + record_ids[structure_idx] + if structure_idx < len(record_ids) + else f"batch_{batch_idx}_{structure_idx}" + ) + structure = ( + base_structures[structure_idx] + if structure_idx < len(base_structures) + else None + ) + + atom_to_token = batch["atom_to_token"][structure_idx].bool() + entity_ids = batch["entity_id"][structure_idx] + present_atom = batch["atom_pad_mask"][structure_idx].bool() + + unique_entities = torch.unique(entity_ids) + res_counts = {int(e.item()): int((entity_ids == e).sum().item()) for e in unique_entities} + if not res_counts: + print(f"Warning: no entities found for structure {structure_idx}.") + continue + + peptide_entity = min(res_counts, key=res_counts.get) + heavy_candidates = {k: v for k, v in res_counts.items() if k != peptide_entity} + heavy_entity = max(heavy_candidates, key=heavy_candidates.get) if heavy_candidates else peptide_entity + + tok_is_pep = (entity_ids == peptide_entity) + tok_is_heavy = (entity_ids == heavy_entity) + + atom_pep = atom_to_token[:, tok_is_pep].any(dim=1) + atom_heavy = atom_to_token[:, tok_is_heavy].any(dim=1) + + atom_heavy = atom_heavy & ~atom_pep + atom_pep = atom_pep & present_atom + atom_heavy = atom_heavy & present_atom + + if atom_heavy.sum() < 3 or atom_pep.sum() < 3: + print( + f"Insufficient atoms for alignment in record {record_id}" + ) + continue + + align_mask = atom_heavy | atom_pep + 
heavy_mask = atom_heavy.clone() + peptide_mask = atom_pep.clone() + align_mode = "all atoms" + calc_mode = "all atoms" + + if structure is not None: + valid_mask = present_atom + valid_count = int(valid_mask.sum().item()) + if valid_count == structure.atoms.shape[0]: + atoms = structure.atoms + chains = structure.chains + + def decode_name(code_row: np.ndarray) -> str: + chars = [chr(int(c) + 32) for c in code_row if c > 0] + return "".join(chars) + + atom_names = np.array([decode_name(row) for row in atoms["name"]]) + + heavy_all_base = np.zeros(structure.atoms.shape[0], dtype=bool) + peptide_all_base = np.zeros_like(heavy_all_base) + heavy_ca_base = np.zeros_like(heavy_all_base) + peptide_ca_base = np.zeros_like(heavy_all_base) + + for chain in chains: + start = int(chain["atom_idx"]) + end = start + int(chain["atom_num"]) + entity = int(chain["entity_id"]) + + if entity == heavy_entity: + heavy_all_base[start:end] = True + ca_indices = np.where(atom_names[start:end] == "CA")[0] + start + if ca_indices.size > 0: + heavy_ca_base[ca_indices] = True + if entity == peptide_entity: + peptide_all_base[start:end] = True + ca_indices = np.where(atom_names[start:end] == "CA")[0] + start + if ca_indices.size > 0: + peptide_ca_base[ca_indices] = True + + heavy_backbone_base = heavy_all_base & np.isin(atom_names, list(backbone_names)) + peptide_backbone_base = peptide_all_base & np.isin(atom_names, list(backbone_names)) + align_ca_base = heavy_ca_base | peptide_ca_base + + def pad_mask(base_mask: np.ndarray) -> torch.Tensor: + mask_full = torch.from_numpy(base_mask).to(device=device, dtype=torch.bool) + mask_padded = torch.zeros_like(align_mask) + mask_padded[valid_mask] = mask_full + return mask_padded + + if align_ca_base.sum() >= 3: + align_mask = pad_mask(align_ca_base) + align_mode = "CA" + + if heavy_ca_base.sum() >= 3: + heavy_mask = pad_mask(heavy_ca_base) + elif heavy_backbone_base.sum() >= 3: + heavy_mask = pad_mask(heavy_backbone_base) + else: + heavy_mask = 
pad_mask(heavy_all_base) + + if peptide_ca_base.sum() >= 3: + peptide_mask = pad_mask(peptide_ca_base) + calc_mode = "CA" + elif peptide_backbone_base.sum() >= 3: + peptide_mask = pad_mask(peptide_backbone_base) + calc_mode = "backbone" + else: + peptide_mask = pad_mask(peptide_all_base) + else: + print( + f"Warning: atom count mismatch for record '{record_id}': " + f"mask has {valid_count} atoms, structure has {structure.atoms.shape[0]}." + ) + + align_base_mask[structure_idx] = align_mask + heavy_calc_mask[structure_idx] = heavy_mask + peptide_calc_mask[structure_idx] = peptide_mask + print( + f"[debug] structure {structure_idx} entities: {res_counts} | heavy={heavy_entity} peptide={peptide_entity} | " + f"align_mode={align_mode} calc_mode={calc_mode}" + ) + print( + f"Align atoms: {int(align_mask.sum().item())} | Calc atoms: {int(peptide_mask.sum().item())}" + ) + + return align_base_mask, heavy_calc_mask, peptide_calc_mask + + def _extract_record_ids(self, batch: dict[str, Tensor], batch_idx: int) -> list[str]: + record_ids = batch.get("record_id", None) + if record_ids is None: + return [f"batch_{batch_idx}_{idx}" for idx in range(batch["atom_to_token"].shape[0])] + if isinstance(record_ids, torch.Tensor): + return [str(r.item()) for r in record_ids] + return [str(r) for r in record_ids] + + @staticmethod + def _extract_structure_paths(batch: dict[str, Tensor], count: int) -> list[Optional[str]]: + structure_paths = batch.get("structure_path", None) + if structure_paths is None: + return [None] * count + return [str(path) if path is not None else None for path in structure_paths] + + def _load_structure_from_npz(self, path: Path) -> Optional[Structure]: + try: + data = np.load(path) + except FileNotFoundError: + print(f"Warning: structure file not found at '{path}'.") + return None + except Exception as e: + print(f"Warning: failed to read structure file '{path}': {e}") + return None + + chains = data["chains"] + if "cyclic_period" not in chains.dtype.names: + 
new_dtype = chains.dtype.descr + [("cyclic_period", "i4")] + new_chains = np.empty(chains.shape, dtype=new_dtype) + for name in chains.dtype.names: + new_chains[name] = chains[name] + new_chains["cyclic_period"] = 0 + chains = new_chains + + try: + structure = Structure( + atoms=data["atoms"], + bonds=data["bonds"], + residues=data["residues"], + chains=chains, + connections=data["connections"].astype(Connection), + interfaces=data["interfaces"], + mask=data["mask"], + ) + except Exception as e: + print(f"Warning: failed to construct Structure from '{path}': {e}") + return None + + try: + return structure.remove_invalid_chains() + except Exception as e: + print(f"Warning: failed to clean structure from '{path}': {e}") + return None + + def _load_structures(self, structure_paths: list[Optional[str]]) -> list[Optional[Structure]]: + structures: list[Optional[Structure]] = [] + for path_str in structure_paths: + if path_str is None: + structures.append(None) + continue + structure = self._load_structure_from_npz(Path(path_str)) + if structure is None: + print(f"Skipping CIF write for record path '{path_str}' (structure load failed).") + structures.append(structure) + return structures + def validation_step(self, batch: dict[str, Tensor], batch_idx: int): # Compute the forward pass n_samples = self.validation_args.diffusion_samples @@ -626,8 +844,7 @@ def validation_step(self, batch: dict[str, Tensor], batch_idx: int): diffusion_samples=n_samples, run_confidence_sequentially=self.validation_args.run_confidence_sequentially, ) - - except RuntimeError as e: # catch out of memory exceptions + except RuntimeError as e: if "out of memory" in str(e): print("| WARNING: ran out of memory, skipping batch") torch.cuda.empty_cache() @@ -637,51 +854,62 @@ def validation_step(self, batch: dict[str, Tensor], batch_idx: int): raise e try: - # Compute distogram LDDT - boundaries = torch.linspace(2, 22.0, 63) - lower = torch.tensor([1.0]) - upper = torch.tensor([22.0 + 5.0]) - 
exp_boundaries = torch.cat((lower, boundaries, upper)) - mid_points = ((exp_boundaries[:-1] + exp_boundaries[1:]) / 2).to( - out["pdistogram"] + true_coords, _, _, _ = self.get_true_coordinates( + batch=batch, + out=out, + diffusion_samples=n_samples, + symmetry_correction=self.validation_args.symmetry_correction, ) - # Compute predicted dists - preds = out["pdistogram"] - pred_softmax = torch.softmax(preds, dim=-1) - pred_softmax = pred_softmax.argmax(dim=-1) - pred_softmax = torch.nn.functional.one_hot( - pred_softmax, num_classes=preds.shape[-1] - ) - pred_dist = (pred_softmax * mid_points).sum(dim=-1) - true_center = batch["disto_center"] - true_dists = torch.cdist(true_center, true_center) - - # Compute lddt's - batch["token_disto_mask"] = batch["token_disto_mask"] - disto_lddt_dict, disto_total_dict = factored_token_lddt_dist_loss( - feats=batch, - true_d=true_dists, - pred_d=pred_dist, + + record_ids = self._extract_record_ids(batch, batch_idx) + + structure_paths = self._extract_structure_paths(batch, len(record_ids)) + + base_structures = self._load_structures(structure_paths) + + batch_size = batch['atom_to_token'].shape[0] + _, _, peptide_calc_mask = self._build_alignment_masks( + batch=batch, + record_ids=record_ids, + base_structures=base_structures, + batch_idx=batch_idx, ) - true_coords, rmsds, best_rmsds, true_coords_resolved_mask = ( - self.get_true_coordinates( - batch=batch, - out=out, - diffusion_samples=n_samples, - symmetry_correction=self.validation_args.symmetry_correction, - ) + sample_metrics = compute_weighted_mhc_rmsds( + out=out, + true_coords=true_coords, + batch=batch, + peptide_mask=peptide_calc_mask, + n_samples=n_samples, + nucleotide_weight=self.nucleotide_rmsd_weight, + ligand_weight=self.ligand_rmsd_weight, ) - all_lddt_dict, all_total_dict = factored_lddt_loss( - feats=batch, - atom_mask=true_coords_resolved_mask, - true_atom_coords=true_coords, - pred_atom_coords=out["sample_atom_coords"], - multiplicity=n_samples, + 
whole_values = [m.rmsd_whole for m in sample_metrics if not math.isnan(m.rmsd_whole)] + peptide_values = [m.rmsd_peptide for m in sample_metrics if not math.isnan(m.rmsd_peptide)] + + if whole_values: + whole_mean = torch.tensor(whole_values, device=out["sample_atom_coords"].device, dtype=torch.float32).mean() + self.log("val/weighted_rmsd_whole", whole_mean, prog_bar=False, sync_dist=True, batch_size=batch_size) + if peptide_values: + peptide_mean = torch.tensor(peptide_values, device=out["sample_atom_coords"].device, dtype=torch.float32).mean() + self.log("val/weighted_rmsd_peptide", peptide_mean, prog_bar=False, sync_dist=True, batch_size=batch_size) + + + output_dir = Path(getattr(self.validation_args, "val_cif_out_dir", "validation_outputs")) + write_validation_predictions( + out=out, + batch=batch, + base_structures=base_structures, + record_ids=record_ids, + sample_metrics=sample_metrics, + n_samples=n_samples, + output_dir=output_dir, ) - except RuntimeError as e: # catch out of memory exceptions + + + except RuntimeError as e: if "out of memory" in str(e): print("| WARNING: ran out of memory, skipping batch") torch.cuda.empty_cache() @@ -689,193 +917,6 @@ def validation_step(self, batch: dict[str, Tensor], batch_idx: int): return else: raise e - # if the multiplicity used is > 1 then we take the best lddt of the different samples - # AF3 combines this with the confidence based filtering - best_lddt_dict, best_total_dict = {}, {} - best_complex_lddt_dict, best_complex_total_dict = {}, {} - B = true_coords.shape[0] // n_samples - if n_samples > 1: - # NOTE: we can change the way we aggregate the lddt - complex_total = 0 - complex_lddt = 0 - for key in all_lddt_dict.keys(): - complex_lddt += all_lddt_dict[key] * all_total_dict[key] - complex_total += all_total_dict[key] - complex_lddt /= complex_total + 1e-7 - best_complex_idx = complex_lddt.reshape(-1, n_samples).argmax(dim=1) - for key in all_lddt_dict: - best_idx = all_lddt_dict[key].reshape(-1, 
n_samples).argmax(dim=1) - best_lddt_dict[key] = all_lddt_dict[key].reshape(-1, n_samples)[ - torch.arange(B), best_idx - ] - best_total_dict[key] = all_total_dict[key].reshape(-1, n_samples)[ - torch.arange(B), best_idx - ] - best_complex_lddt_dict[key] = all_lddt_dict[key].reshape(-1, n_samples)[ - torch.arange(B), best_complex_idx - ] - best_complex_total_dict[key] = all_total_dict[key].reshape( - -1, n_samples - )[torch.arange(B), best_complex_idx] - else: - best_lddt_dict = all_lddt_dict - best_total_dict = all_total_dict - best_complex_lddt_dict = all_lddt_dict - best_complex_total_dict = all_total_dict - - # Filtering based on confidence - if self.confidence_prediction and n_samples > 1: - # note: for now we don't have pae predictions so have to use pLDDT instead of pTM - # also, while AF3 differentiates the best prediction per confidence type we are currently not doing it - # consider this in the future as well as weighing the different pLLDT types before aggregation - mae_plddt_dict, total_mae_plddt_dict = compute_plddt_mae( - pred_atom_coords=out["sample_atom_coords"], - feats=batch, - true_atom_coords=true_coords, - pred_lddt=out["plddt"], - true_coords_resolved_mask=true_coords_resolved_mask, - multiplicity=n_samples, - ) - mae_pde_dict, total_mae_pde_dict = compute_pde_mae( - pred_atom_coords=out["sample_atom_coords"], - feats=batch, - true_atom_coords=true_coords, - pred_pde=out["pde"], - true_coords_resolved_mask=true_coords_resolved_mask, - multiplicity=n_samples, - ) - mae_pae_dict, total_mae_pae_dict = compute_pae_mae( - pred_atom_coords=out["sample_atom_coords"], - feats=batch, - true_atom_coords=true_coords, - pred_pae=out["pae"], - true_coords_resolved_mask=true_coords_resolved_mask, - multiplicity=n_samples, - ) - - plddt = out["complex_plddt"].reshape(-1, n_samples) - top1_idx = plddt.argmax(dim=1) - iplddt = out["complex_iplddt"].reshape(-1, n_samples) - iplddt_top1_idx = iplddt.argmax(dim=1) - pde = out["complex_pde"].reshape(-1, n_samples) 
- pde_top1_idx = pde.argmin(dim=1) - ipde = out["complex_ipde"].reshape(-1, n_samples) - ipde_top1_idx = ipde.argmin(dim=1) - ptm = out["ptm"].reshape(-1, n_samples) - ptm_top1_idx = ptm.argmax(dim=1) - iptm = out["iptm"].reshape(-1, n_samples) - iptm_top1_idx = iptm.argmax(dim=1) - ligand_iptm = out["ligand_iptm"].reshape(-1, n_samples) - ligand_iptm_top1_idx = ligand_iptm.argmax(dim=1) - protein_iptm = out["protein_iptm"].reshape(-1, n_samples) - protein_iptm_top1_idx = protein_iptm.argmax(dim=1) - - for key in all_lddt_dict: - top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[ - torch.arange(B), top1_idx - ] - top1_total = all_total_dict[key].reshape(-1, n_samples)[ - torch.arange(B), top1_idx - ] - iplddt_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[ - torch.arange(B), iplddt_top1_idx - ] - iplddt_top1_total = all_total_dict[key].reshape(-1, n_samples)[ - torch.arange(B), iplddt_top1_idx - ] - pde_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[ - torch.arange(B), pde_top1_idx - ] - pde_top1_total = all_total_dict[key].reshape(-1, n_samples)[ - torch.arange(B), pde_top1_idx - ] - ipde_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[ - torch.arange(B), ipde_top1_idx - ] - ipde_top1_total = all_total_dict[key].reshape(-1, n_samples)[ - torch.arange(B), ipde_top1_idx - ] - ptm_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[ - torch.arange(B), ptm_top1_idx - ] - ptm_top1_total = all_total_dict[key].reshape(-1, n_samples)[ - torch.arange(B), ptm_top1_idx - ] - iptm_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[ - torch.arange(B), iptm_top1_idx - ] - iptm_top1_total = all_total_dict[key].reshape(-1, n_samples)[ - torch.arange(B), iptm_top1_idx - ] - ligand_iptm_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[ - torch.arange(B), ligand_iptm_top1_idx - ] - ligand_iptm_top1_total = all_total_dict[key].reshape(-1, n_samples)[ - torch.arange(B), ligand_iptm_top1_idx - ] - protein_iptm_top1_lddt = all_lddt_dict[key].reshape(-1, 
n_samples)[ - torch.arange(B), protein_iptm_top1_idx - ] - protein_iptm_top1_total = all_total_dict[key].reshape(-1, n_samples)[ - torch.arange(B), protein_iptm_top1_idx - ] - - self.top1_lddt[key].update(top1_lddt, top1_total) - self.iplddt_top1_lddt[key].update(iplddt_top1_lddt, iplddt_top1_total) - self.pde_top1_lddt[key].update(pde_top1_lddt, pde_top1_total) - self.ipde_top1_lddt[key].update(ipde_top1_lddt, ipde_top1_total) - self.ptm_top1_lddt[key].update(ptm_top1_lddt, ptm_top1_total) - self.iptm_top1_lddt[key].update(iptm_top1_lddt, iptm_top1_total) - self.ligand_iptm_top1_lddt[key].update( - ligand_iptm_top1_lddt, ligand_iptm_top1_total - ) - self.protein_iptm_top1_lddt[key].update( - protein_iptm_top1_lddt, protein_iptm_top1_total - ) - - self.avg_lddt[key].update(all_lddt_dict[key], all_total_dict[key]) - self.pde_mae[key].update(mae_pde_dict[key], total_mae_pde_dict[key]) - self.pae_mae[key].update(mae_pae_dict[key], total_mae_pae_dict[key]) - - for key in mae_plddt_dict: - self.plddt_mae[key].update( - mae_plddt_dict[key], total_mae_plddt_dict[key] - ) - - for m in const.out_types: - if m == "ligand_protein": - if torch.any( - batch["pocket_feature"][ - :, :, const.pocket_contact_info["POCKET"] - ].bool() - ): - self.lddt["pocket_ligand_protein"].update( - best_lddt_dict[m], best_total_dict[m] - ) - self.disto_lddt["pocket_ligand_protein"].update( - disto_lddt_dict[m], disto_total_dict[m] - ) - self.complex_lddt["pocket_ligand_protein"].update( - best_complex_lddt_dict[m], best_complex_total_dict[m] - ) - else: - self.lddt["ligand_protein"].update( - best_lddt_dict[m], best_total_dict[m] - ) - self.disto_lddt["ligand_protein"].update( - disto_lddt_dict[m], disto_total_dict[m] - ) - self.complex_lddt["ligand_protein"].update( - best_complex_lddt_dict[m], best_complex_total_dict[m] - ) - else: - self.lddt[m].update(best_lddt_dict[m], best_total_dict[m]) - self.disto_lddt[m].update(disto_lddt_dict[m], disto_total_dict[m]) - self.complex_lddt[m].update( - 
best_complex_lddt_dict[m], best_complex_total_dict[m] - ) - self.rmsd.update(rmsds) - self.best_rmsd.update(best_rmsds) def on_validation_epoch_end(self): avg_lddt = {} From fdde7eb4ca5e39f513b7666e7869caa0f23611f2 Mon Sep 17 00:00:00 2001 From: AlvandVahedi Date: Thu, 6 Nov 2025 04:57:31 +0000 Subject: [PATCH 3/5] Log peptide-mask RMSD during validation using aligned true coordinates and no CIF saving at validation step --- src/boltz/data/write/writer.py | 91 -------------------------- src/boltz/model/loss/validation.py | 4 ++ src/boltz/model/models/boltz1.py | 100 +++++++++++++++++++++++++---- 3 files changed, 92 insertions(+), 103 deletions(-) diff --git a/src/boltz/data/write/writer.py b/src/boltz/data/write/writer.py index 2c9a09caa..ac4fedfca 100644 --- a/src/boltz/data/write/writer.py +++ b/src/boltz/data/write/writer.py @@ -290,97 +290,6 @@ def atomic_save_cif(filepath: Path, content: str) -> bool: return False -def write_validation_predictions( - out: dict[str, Tensor], - batch: dict[str, Tensor], - base_structures: list[Optional[Structure]], - record_ids: list[str], - sample_metrics: Iterable[SampleMetrics], - n_samples: int, - output_dir: Path, -) -> None: - """ - Write validation predictions to disk. 
- """ - output_dir.mkdir(parents=True, exist_ok=True) - - sample_metrics = list(sample_metrics) - metrics_map = {metric.sample_idx: metric for metric in sample_metrics} - total_samples = out["sample_atom_coords"].shape[0] - samples_per_structure = n_samples if n_samples > 0 else total_samples - atom_pad_mask_cpu = batch["atom_pad_mask"].detach().cpu() - - for struct_idx, record_id in enumerate(record_ids): - base_structure = base_structures[struct_idx] if struct_idx < len(base_structures) else None - if base_structure is None: - print(f"Skipping CIF write for record '{record_id}' (missing base structure).") - continue - - valid_mask = atom_pad_mask_cpu[struct_idx].to(dtype=torch.bool).numpy() - if valid_mask.sum() != base_structure.atoms.shape[0]: - print( - f"Warning: atom count mismatch for record '{record_id}': " - f"mask has {int(valid_mask.sum())} atoms, structure has {base_structure.atoms.shape[0]}." - ) - continue - - sample_block = out["sample_atom_coords"][ - struct_idx * samples_per_structure : (struct_idx + 1) * samples_per_structure - ] - - for sample_offset, coords_tensor in enumerate(sample_block): - sample_idx = struct_idx * samples_per_structure + sample_offset - metrics = metrics_map.get(sample_idx) - - try: - coords_np = coords_tensor.detach().cpu().numpy()[valid_mask] - atoms = base_structure.atoms.copy() - atoms["coords"] = coords_np.astype(np.float32) - atoms["conformer"] = coords_np.astype(np.float32) - atoms["is_present"] = True - - residues = base_structure.residues.copy() - residues["is_present"] = True - - new_structure = Structure( - atoms=atoms, - bonds=base_structure.bonds, - residues=residues, - chains=base_structure.chains, - connections=base_structure.connections, - interfaces=base_structure.interfaces, - mask=base_structure.mask, - ) - - plddts = None - if "plddt" in out: - start = struct_idx * samples_per_structure + sample_offset - plddts = out["plddt"][start : start + 1].detach().cpu() - - if metrics is not None: - filename = ( - 
f"prediction_{record_id}_sample_{sample_idx}" - f"_whole{metrics.rmsd_whole:.2f}_pep{metrics.rmsd_peptide:.2f}.cif" - ) - else: - filename = f"prediction_{record_id}_sample_{sample_idx}.cif" - - output_path = output_dir / filename - print(f"\nSaving prediction to {output_path}") - - cif_content = to_mmcif(new_structure, plddts=plddts) - if atomic_save_cif(output_path, cif_content): - print(f"Successfully saved prediction to {output_path}") - - except Exception as exc: # noqa: BLE001 - print(f"Error processing record '{record_id}' sample {sample_idx}: {exc}") - continue - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - class BoltzAffinityWriter(BasePredictionWriter): """Custom writer for predictions.""" diff --git a/src/boltz/model/loss/validation.py b/src/boltz/model/loss/validation.py index 20c9f0ba7..476d194f1 100644 --- a/src/boltz/model/loss/validation.py +++ b/src/boltz/model/loss/validation.py @@ -900,6 +900,7 @@ def weighted_minimum_rmsd( multiplicity=1, nucleotide_weight=5.0, ligand_weight=10.0, + return_aligned_coords: bool = False, ): """Compute rmsd of the aligned atom coordinates. 
@@ -961,6 +962,9 @@ def weighted_minimum_rmsd( ) best_rmsd = torch.min(rmsd.reshape(-1, multiplicity), dim=1).values + if return_aligned_coords: + return rmsd, best_rmsd, atom_coords_aligned_ground_truth + return rmsd, best_rmsd diff --git a/src/boltz/model/models/boltz1.py b/src/boltz/model/models/boltz1.py index 7e859e119..5ecd20be1 100644 --- a/src/boltz/model/models/boltz1.py +++ b/src/boltz/model/models/boltz1.py @@ -31,7 +31,6 @@ weighted_minimum_rmsd, ) from boltz.model.modules.confidence import ConfidenceModule -from boltz.data.write.writer import write_validation_predictions from boltz.model.modules.diffusion import AtomDiffusion from boltz.model.modules.encoders import RelativePositionEncoder from boltz.model.modules.trunk import ( @@ -126,6 +125,10 @@ def __init__( # noqa: PLR0915, C901, PLR0912 self.plddt_mae[m] = MeanMetric() self.rmsd = MeanMetric() self.best_rmsd = MeanMetric() + self.mask_rmsd_metric = MeanMetric() + self.best_mask_rmsd_metric = MeanMetric() + self.latest_mask_rmsd = None + self.latest_best_mask_rmsd = None self.train_confidence_loss_logger = MeanMetric() self.train_confidence_loss_dict_logger = nn.ModuleDict() @@ -413,6 +416,7 @@ def get_true_coordinates( diffusion_samples, symmetry_correction, lddt_minimization=True, + return_aligned_coords: bool = False, ): if symmetry_correction: min_coords_routine = ( @@ -423,6 +427,7 @@ def get_true_coordinates( true_coords = [] true_coords_resolved_mask = [] rmsds, best_rmsds = [], [] + aligned_coords_list = [] for idx in range(batch["token_index"].shape[0]): best_rmsd = float("inf") for rep in range(diffusion_samples): @@ -439,11 +444,18 @@ def get_true_coordinates( rmsds.append(rmsd) true_coords.append(best_true_coords) true_coords_resolved_mask.append(best_true_coords_resolved_mask) + if return_aligned_coords: + aligned_coords_list.append(best_true_coords) if rmsd < best_rmsd: best_rmsd = rmsd best_rmsds.append(best_rmsd) true_coords = torch.cat(true_coords, dim=0) 
true_coords_resolved_mask = torch.cat(true_coords_resolved_mask, dim=0) + aligned_true_coords = ( + torch.cat(aligned_coords_list, dim=0) + if return_aligned_coords + else None + ) else: true_coords = ( batch["coords"].squeeze(1).repeat_interleave(diffusion_samples, 0) @@ -452,14 +464,28 @@ def get_true_coordinates( true_coords_resolved_mask = batch["atom_resolved_mask"].repeat_interleave( diffusion_samples, 0 ) - rmsds, best_rmsds = weighted_minimum_rmsd( + result = weighted_minimum_rmsd( out["sample_atom_coords"], batch, multiplicity=diffusion_samples, nucleotide_weight=self.nucleotide_rmsd_weight, ligand_weight=self.ligand_rmsd_weight, + return_aligned_coords=return_aligned_coords, ) + if return_aligned_coords: + rmsds, best_rmsds, aligned_true_coords = result + else: + rmsds, best_rmsds = result + aligned_true_coords = None + if return_aligned_coords: + return ( + true_coords, + rmsds, + best_rmsds, + true_coords_resolved_mask, + aligned_true_coords, + ) return true_coords, rmsds, best_rmsds, true_coords_resolved_mask def training_step(self, batch: dict[str, Tensor], batch_idx: int) -> Tensor: @@ -854,11 +880,12 @@ def validation_step(self, batch: dict[str, Tensor], batch_idx: int): raise e try: - true_coords, _, _, _ = self.get_true_coordinates( + true_coords, _, _, _, aligned_true_coords = self.get_true_coordinates( batch=batch, out=out, diffusion_samples=n_samples, symmetry_correction=self.validation_args.symmetry_correction, + return_aligned_coords=True, ) @@ -897,16 +924,49 @@ def validation_step(self, batch: dict[str, Tensor], batch_idx: int): self.log("val/weighted_rmsd_peptide", peptide_mean, prog_bar=False, sync_dist=True, batch_size=batch_size) - output_dir = Path(getattr(self.validation_args, "val_cif_out_dir", "validation_outputs")) - write_validation_predictions( - out=out, - batch=batch, - base_structures=base_structures, - record_ids=record_ids, - sample_metrics=sample_metrics, - n_samples=n_samples, - output_dir=output_dir, + device = 
out["sample_atom_coords"].device + total_samples = out["sample_atom_coords"].shape[0] + samples_per_structure = n_samples if n_samples > 0 else total_samples + if total_samples != len(record_ids) * samples_per_structure: + raise ValueError("Mismatch between diffusion samples and batch structures.") + + peptide_mask = peptide_calc_mask.to(device=device) + peptide_mask = peptide_mask.repeat_interleave(samples_per_structure, dim=0) + aligned_true_coords = aligned_true_coords.to(device) + + diff = out["sample_atom_coords"] - aligned_true_coords + mask_float = peptide_mask.float() + mask_counts = mask_float.sum(dim=1) + valid_samples = mask_counts > 0 + + mask_rmsd = torch.full( + (total_samples,), float("nan"), device=device, dtype=diff.dtype ) + if valid_samples.any(): + sq = (diff[valid_samples] ** 2).sum(dim=-1) + mask_rmsd[valid_samples] = torch.sqrt( + (sq * mask_float[valid_samples]).sum(dim=1) / mask_counts[valid_samples] + ) + + mask_rmsd_valid = mask_rmsd[valid_samples] + if mask_rmsd_valid.numel() > 0: + self.mask_rmsd_metric.update(mask_rmsd_valid.detach().cpu()) + + best_mask_rmsd = torch.full( + (len(record_ids),), float("nan"), device=device, dtype=diff.dtype + ) + mask_rmsd_matrix = mask_rmsd.view(len(record_ids), samples_per_structure) + valid_matrix = valid_samples.view(len(record_ids), samples_per_structure) + for idx in range(len(record_ids)): + if valid_matrix[idx].any(): + best_mask_rmsd[idx] = mask_rmsd_matrix[idx][valid_matrix[idx]].min() + + best_mask_valid = best_mask_rmsd[torch.isfinite(best_mask_rmsd)] + if best_mask_valid.numel() > 0: + self.best_mask_rmsd_metric.update(best_mask_valid.detach().cpu()) + + self.latest_mask_rmsd = mask_rmsd.detach().cpu() + self.latest_best_mask_rmsd = best_mask_rmsd.detach().cpu() except RuntimeError as e: @@ -922,6 +982,22 @@ def on_validation_epoch_end(self): avg_lddt = {} avg_disto_lddt = {} avg_complex_lddt = {} + + mask_rmsd_val = self.mask_rmsd_metric.compute() + if torch.isfinite(mask_rmsd_val).item(): 
+ self.log("val/mask_rmsd", mask_rmsd_val, prog_bar=False, sync_dist=True) + self.mask_rmsd_metric.reset() + + best_mask_rmsd_val = self.best_mask_rmsd_metric.compute() + if torch.isfinite(best_mask_rmsd_val).item(): + self.log( + "val/best_mask_rmsd", + best_mask_rmsd_val, + prog_bar=False, + sync_dist=True, + ) + self.best_mask_rmsd_metric.reset() + if self.confidence_prediction: avg_top1_lddt = {} avg_iplddt_top1_lddt = {} From 462d9d3d33e23935fb6b855c2fcc3e0bc679b8e0 Mon Sep 17 00:00:00 2001 From: AlvandVahedi Date: Thu, 6 Nov 2025 05:30:50 +0000 Subject: [PATCH 4/5] Revert to b015baa baseline for validation writer --- src/boltz/data/write/writer.py | 91 ++++++++++++++++++++++++++ src/boltz/model/loss/validation.py | 4 -- src/boltz/model/models/boltz1.py | 100 ++++------------------------- 3 files changed, 103 insertions(+), 92 deletions(-) diff --git a/src/boltz/data/write/writer.py b/src/boltz/data/write/writer.py index ac4fedfca..2c9a09caa 100644 --- a/src/boltz/data/write/writer.py +++ b/src/boltz/data/write/writer.py @@ -290,6 +290,97 @@ def atomic_save_cif(filepath: Path, content: str) -> bool: return False +def write_validation_predictions( + out: dict[str, Tensor], + batch: dict[str, Tensor], + base_structures: list[Optional[Structure]], + record_ids: list[str], + sample_metrics: Iterable[SampleMetrics], + n_samples: int, + output_dir: Path, +) -> None: + """ + Write validation predictions to disk. 
+ """ + output_dir.mkdir(parents=True, exist_ok=True) + + sample_metrics = list(sample_metrics) + metrics_map = {metric.sample_idx: metric for metric in sample_metrics} + total_samples = out["sample_atom_coords"].shape[0] + samples_per_structure = n_samples if n_samples > 0 else total_samples + atom_pad_mask_cpu = batch["atom_pad_mask"].detach().cpu() + + for struct_idx, record_id in enumerate(record_ids): + base_structure = base_structures[struct_idx] if struct_idx < len(base_structures) else None + if base_structure is None: + print(f"Skipping CIF write for record '{record_id}' (missing base structure).") + continue + + valid_mask = atom_pad_mask_cpu[struct_idx].to(dtype=torch.bool).numpy() + if valid_mask.sum() != base_structure.atoms.shape[0]: + print( + f"Warning: atom count mismatch for record '{record_id}': " + f"mask has {int(valid_mask.sum())} atoms, structure has {base_structure.atoms.shape[0]}." + ) + continue + + sample_block = out["sample_atom_coords"][ + struct_idx * samples_per_structure : (struct_idx + 1) * samples_per_structure + ] + + for sample_offset, coords_tensor in enumerate(sample_block): + sample_idx = struct_idx * samples_per_structure + sample_offset + metrics = metrics_map.get(sample_idx) + + try: + coords_np = coords_tensor.detach().cpu().numpy()[valid_mask] + atoms = base_structure.atoms.copy() + atoms["coords"] = coords_np.astype(np.float32) + atoms["conformer"] = coords_np.astype(np.float32) + atoms["is_present"] = True + + residues = base_structure.residues.copy() + residues["is_present"] = True + + new_structure = Structure( + atoms=atoms, + bonds=base_structure.bonds, + residues=residues, + chains=base_structure.chains, + connections=base_structure.connections, + interfaces=base_structure.interfaces, + mask=base_structure.mask, + ) + + plddts = None + if "plddt" in out: + start = struct_idx * samples_per_structure + sample_offset + plddts = out["plddt"][start : start + 1].detach().cpu() + + if metrics is not None: + filename = ( + 
f"prediction_{record_id}_sample_{sample_idx}" + f"_whole{metrics.rmsd_whole:.2f}_pep{metrics.rmsd_peptide:.2f}.cif" + ) + else: + filename = f"prediction_{record_id}_sample_{sample_idx}.cif" + + output_path = output_dir / filename + print(f"\nSaving prediction to {output_path}") + + cif_content = to_mmcif(new_structure, plddts=plddts) + if atomic_save_cif(output_path, cif_content): + print(f"Successfully saved prediction to {output_path}") + + except Exception as exc: # noqa: BLE001 + print(f"Error processing record '{record_id}' sample {sample_idx}: {exc}") + continue + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + class BoltzAffinityWriter(BasePredictionWriter): """Custom writer for predictions.""" diff --git a/src/boltz/model/loss/validation.py b/src/boltz/model/loss/validation.py index 476d194f1..20c9f0ba7 100644 --- a/src/boltz/model/loss/validation.py +++ b/src/boltz/model/loss/validation.py @@ -900,7 +900,6 @@ def weighted_minimum_rmsd( multiplicity=1, nucleotide_weight=5.0, ligand_weight=10.0, - return_aligned_coords: bool = False, ): """Compute rmsd of the aligned atom coordinates. 
@@ -962,9 +961,6 @@ def weighted_minimum_rmsd( ) best_rmsd = torch.min(rmsd.reshape(-1, multiplicity), dim=1).values - if return_aligned_coords: - return rmsd, best_rmsd, atom_coords_aligned_ground_truth - return rmsd, best_rmsd diff --git a/src/boltz/model/models/boltz1.py b/src/boltz/model/models/boltz1.py index 5ecd20be1..7e859e119 100644 --- a/src/boltz/model/models/boltz1.py +++ b/src/boltz/model/models/boltz1.py @@ -31,6 +31,7 @@ weighted_minimum_rmsd, ) from boltz.model.modules.confidence import ConfidenceModule +from boltz.data.write.writer import write_validation_predictions from boltz.model.modules.diffusion import AtomDiffusion from boltz.model.modules.encoders import RelativePositionEncoder from boltz.model.modules.trunk import ( @@ -125,10 +126,6 @@ def __init__( # noqa: PLR0915, C901, PLR0912 self.plddt_mae[m] = MeanMetric() self.rmsd = MeanMetric() self.best_rmsd = MeanMetric() - self.mask_rmsd_metric = MeanMetric() - self.best_mask_rmsd_metric = MeanMetric() - self.latest_mask_rmsd = None - self.latest_best_mask_rmsd = None self.train_confidence_loss_logger = MeanMetric() self.train_confidence_loss_dict_logger = nn.ModuleDict() @@ -416,7 +413,6 @@ def get_true_coordinates( diffusion_samples, symmetry_correction, lddt_minimization=True, - return_aligned_coords: bool = False, ): if symmetry_correction: min_coords_routine = ( @@ -427,7 +423,6 @@ def get_true_coordinates( true_coords = [] true_coords_resolved_mask = [] rmsds, best_rmsds = [], [] - aligned_coords_list = [] for idx in range(batch["token_index"].shape[0]): best_rmsd = float("inf") for rep in range(diffusion_samples): @@ -444,18 +439,11 @@ def get_true_coordinates( rmsds.append(rmsd) true_coords.append(best_true_coords) true_coords_resolved_mask.append(best_true_coords_resolved_mask) - if return_aligned_coords: - aligned_coords_list.append(best_true_coords) if rmsd < best_rmsd: best_rmsd = rmsd best_rmsds.append(best_rmsd) true_coords = torch.cat(true_coords, dim=0) 
true_coords_resolved_mask = torch.cat(true_coords_resolved_mask, dim=0) - aligned_true_coords = ( - torch.cat(aligned_coords_list, dim=0) - if return_aligned_coords - else None - ) else: true_coords = ( batch["coords"].squeeze(1).repeat_interleave(diffusion_samples, 0) @@ -464,28 +452,14 @@ def get_true_coordinates( true_coords_resolved_mask = batch["atom_resolved_mask"].repeat_interleave( diffusion_samples, 0 ) - result = weighted_minimum_rmsd( + rmsds, best_rmsds = weighted_minimum_rmsd( out["sample_atom_coords"], batch, multiplicity=diffusion_samples, nucleotide_weight=self.nucleotide_rmsd_weight, ligand_weight=self.ligand_rmsd_weight, - return_aligned_coords=return_aligned_coords, ) - if return_aligned_coords: - rmsds, best_rmsds, aligned_true_coords = result - else: - rmsds, best_rmsds = result - aligned_true_coords = None - if return_aligned_coords: - return ( - true_coords, - rmsds, - best_rmsds, - true_coords_resolved_mask, - aligned_true_coords, - ) return true_coords, rmsds, best_rmsds, true_coords_resolved_mask def training_step(self, batch: dict[str, Tensor], batch_idx: int) -> Tensor: @@ -880,12 +854,11 @@ def validation_step(self, batch: dict[str, Tensor], batch_idx: int): raise e try: - true_coords, _, _, _, aligned_true_coords = self.get_true_coordinates( + true_coords, _, _, _ = self.get_true_coordinates( batch=batch, out=out, diffusion_samples=n_samples, symmetry_correction=self.validation_args.symmetry_correction, - return_aligned_coords=True, ) @@ -924,49 +897,16 @@ def validation_step(self, batch: dict[str, Tensor], batch_idx: int): self.log("val/weighted_rmsd_peptide", peptide_mean, prog_bar=False, sync_dist=True, batch_size=batch_size) - device = out["sample_atom_coords"].device - total_samples = out["sample_atom_coords"].shape[0] - samples_per_structure = n_samples if n_samples > 0 else total_samples - if total_samples != len(record_ids) * samples_per_structure: - raise ValueError("Mismatch between diffusion samples and batch structures.") - 
- peptide_mask = peptide_calc_mask.to(device=device) - peptide_mask = peptide_mask.repeat_interleave(samples_per_structure, dim=0) - aligned_true_coords = aligned_true_coords.to(device) - - diff = out["sample_atom_coords"] - aligned_true_coords - mask_float = peptide_mask.float() - mask_counts = mask_float.sum(dim=1) - valid_samples = mask_counts > 0 - - mask_rmsd = torch.full( - (total_samples,), float("nan"), device=device, dtype=diff.dtype - ) - if valid_samples.any(): - sq = (diff[valid_samples] ** 2).sum(dim=-1) - mask_rmsd[valid_samples] = torch.sqrt( - (sq * mask_float[valid_samples]).sum(dim=1) / mask_counts[valid_samples] - ) - - mask_rmsd_valid = mask_rmsd[valid_samples] - if mask_rmsd_valid.numel() > 0: - self.mask_rmsd_metric.update(mask_rmsd_valid.detach().cpu()) - - best_mask_rmsd = torch.full( - (len(record_ids),), float("nan"), device=device, dtype=diff.dtype + output_dir = Path(getattr(self.validation_args, "val_cif_out_dir", "validation_outputs")) + write_validation_predictions( + out=out, + batch=batch, + base_structures=base_structures, + record_ids=record_ids, + sample_metrics=sample_metrics, + n_samples=n_samples, + output_dir=output_dir, ) - mask_rmsd_matrix = mask_rmsd.view(len(record_ids), samples_per_structure) - valid_matrix = valid_samples.view(len(record_ids), samples_per_structure) - for idx in range(len(record_ids)): - if valid_matrix[idx].any(): - best_mask_rmsd[idx] = mask_rmsd_matrix[idx][valid_matrix[idx]].min() - - best_mask_valid = best_mask_rmsd[torch.isfinite(best_mask_rmsd)] - if best_mask_valid.numel() > 0: - self.best_mask_rmsd_metric.update(best_mask_valid.detach().cpu()) - - self.latest_mask_rmsd = mask_rmsd.detach().cpu() - self.latest_best_mask_rmsd = best_mask_rmsd.detach().cpu() except RuntimeError as e: @@ -982,22 +922,6 @@ def on_validation_epoch_end(self): avg_lddt = {} avg_disto_lddt = {} avg_complex_lddt = {} - - mask_rmsd_val = self.mask_rmsd_metric.compute() - if torch.isfinite(mask_rmsd_val).item(): - 
self.log("val/mask_rmsd", mask_rmsd_val, prog_bar=False, sync_dist=True) - self.mask_rmsd_metric.reset() - - best_mask_rmsd_val = self.best_mask_rmsd_metric.compute() - if torch.isfinite(best_mask_rmsd_val).item(): - self.log( - "val/best_mask_rmsd", - best_mask_rmsd_val, - prog_bar=False, - sync_dist=True, - ) - self.best_mask_rmsd_metric.reset() - if self.confidence_prediction: avg_top1_lddt = {} avg_iplddt_top1_lddt = {} From 246fa30b36b1a7e053ff67d14f2154d2ba84f87f Mon Sep 17 00:00:00 2001 From: AlvandVahedi Date: Sat, 8 Nov 2025 04:55:03 +0000 Subject: [PATCH 5/5] Refactor validation loop: fix RMSD computation, replace record loading with precomputed MHC/peptide masks, add residue debug print, and make CIF output optional --- .gitignore | 1 + scripts/precompute_masks.py | 113 +++++ scripts/train/configs/structure.yaml | 2 + src/boltz/data/module/training.py | 118 +++-- src/boltz/data/types.py | 2 + src/boltz/data/write/writer.py | 162 ++---- src/boltz/model/loss/validation.py | 156 ++---- src/boltz/model/models/boltz1.py | 716 +++++++++++++++++++++------ 8 files changed, 841 insertions(+), 429 deletions(-) create mode 100644 scripts/precompute_masks.py diff --git a/.gitignore b/.gitignore index 6f79226b3..a1f5a545c 100644 --- a/.gitignore +++ b/.gitignore @@ -170,3 +170,4 @@ data/ natives/ cache/ mhc_one_sample/ +val_cif_output/ diff --git a/scripts/precompute_masks.py b/scripts/precompute_masks.py new file mode 100644 index 000000000..0fdf91bc3 --- /dev/null +++ b/scripts/precompute_masks.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import argparse +from pathlib import Path + +import numpy as np + + +def compute_masks(data: np.lib.npyio.NpzFile) -> tuple[np.ndarray, np.ndarray]: + """Return (alignment_mask, rmsd_mask) inferred from raw NPZ contents.""" + chains = data["chains"] + chain_mask = data["mask"].astype(bool) + + alignment_segments: list[np.ndarray] = [] + rmsd_segments: list[np.ndarray] = [] + + entity_counts: dict[int, int] = {} + 
for idx, chain in enumerate(chains): + if not chain_mask[idx]: + continue + entity_id = int(chain["entity_id"]) + entity_counts[entity_id] = entity_counts.get(entity_id, 0) + int(chain["res_num"]) + + if not entity_counts: + return np.zeros(0, dtype=bool), np.zeros(0, dtype=bool) + + # For now, assume the peptide entity is the smallest one by residue count, and MHC/heavy is the largest. + # Will need to revisit. + peptide_entity = min(entity_counts, key=entity_counts.get) + heavy_candidates = {k: v for k, v in entity_counts.items() if k != peptide_entity} + heavy_entity = ( + max(heavy_candidates, key=heavy_candidates.get) + if heavy_candidates + else peptide_entity + ) + + for idx, chain in enumerate(chains): + if not chain_mask[idx]: + continue + entity_id = int(chain["entity_id"]) + atom_num = int(chain["atom_num"]) + + align_segment = np.zeros(atom_num, dtype=bool) + rmsd_segment = np.zeros(atom_num, dtype=bool) + + if entity_id == heavy_entity: + align_segment[:] = True + if entity_id == peptide_entity: + rmsd_segment[:] = True + + alignment_segments.append(align_segment) + rmsd_segments.append(rmsd_segment) + + alignment_mask = ( + np.concatenate(alignment_segments) if alignment_segments else np.zeros(0, dtype=bool) + ) + rmsd_mask = ( + np.concatenate(rmsd_segments) if rmsd_segments else np.zeros(0, dtype=bool) + ) + + return alignment_mask, rmsd_mask + + +def process_structure(npz_path: Path, dry_run: bool = False) -> None: + with np.load(npz_path) as data: + alignment_mask, rmsd_mask = compute_masks(data) + + if dry_run: + print( + f"{npz_path.name}: align atoms={alignment_mask.sum()} " + f"(len={alignment_mask.shape[0]}), " + f"rmsd atoms={rmsd_mask.sum()} (len={rmsd_mask.shape[0]})" + ) + return + + updated = {key: data[key] for key in data.files} + updated["alignment_mask"] = alignment_mask + updated["rmsd_mask"] = rmsd_mask + + tmp_path = npz_path.with_suffix(".tmp.npz") + np.savez_compressed(tmp_path, **updated) + tmp_path.replace(npz_path) + + +def 
main() -> None: + parser = argparse.ArgumentParser(description="Inject alignment/rmsd masks into structure NPZ files.") + parser.add_argument( + "structures_dir", + type=Path, + help="Directory containing processed structure NPZ files.", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Report mask statistics without writing files.", + ) + args = parser.parse_args() + + npz_paths = sorted(args.structures_dir.glob("*.npz")) + if not npz_paths: + raise SystemExit(f"No .npz files found under {args.structures_dir}") + + for npz_path in npz_paths: + process_structure(npz_path, dry_run=args.dry_run) + + if args.dry_run: + print("Dry run complete; no files modified.") + else: + print(f"Processed {len(npz_paths)} structures.") + + +if __name__ == "__main__": + main() diff --git a/scripts/train/configs/structure.yaml b/scripts/train/configs/structure.yaml index 81e9ef55a..44484e69f 100644 --- a/scripts/train/configs/structure.yaml +++ b/scripts/train/configs/structure.yaml @@ -165,6 +165,8 @@ model: symmetry_correction: false run_confidence_sequentially: false val_cif_out_dir: /storage/alvand/boltz/val_cif_output + write_cif_for_validation: false # whether to write cif files during validation + debug_peptide_mask_info: false # Keep this false unless debugging diffusion_process_args: sigma_min: 0.0004 diff --git a/src/boltz/data/module/training.py b/src/boltz/data/module/training.py index a450d249b..6e3e6c92c 100644 --- a/src/boltz/data/module/training.py +++ b/src/boltz/data/module/training.py @@ -82,6 +82,75 @@ class Dataset: featurizer: BoltzFeaturizer +def _prepare_structure(data: np.lib.npyio.NpzFile) -> tuple[Structure, Optional[np.ndarray], Optional[np.ndarray]]: + chains = data["chains"] + if "cyclic_period" not in chains.dtype.names: + new_dtype = chains.dtype.descr + [("cyclic_period", "i4")] + new_chains = np.empty(chains.shape, dtype=new_dtype) + for name in chains.dtype.names: + new_chains[name] = chains[name] + new_chains["cyclic_period"] 
= 0 + chains = new_chains + + structure = Structure( + atoms=data["atoms"], + bonds=data["bonds"], + residues=data["residues"], + chains=chains, + connections=data["connections"].astype(Connection), + interfaces=data["interfaces"], + mask=data["mask"], + ) + + alignment_mask = data.get("alignment_mask") + rmsd_mask = data.get("rmsd_mask") + + if alignment_mask is not None: + alignment_mask = alignment_mask.astype(bool, copy=False) + if rmsd_mask is not None: + rmsd_mask = rmsd_mask.astype(bool, copy=False) + + chain_mask = data["mask"].astype(bool) + kept_atom_total = sum( + int(chain["atom_num"]) for idx, chain in enumerate(data["chains"]) if chain_mask[idx] + ) + + def _reshape_mask(mask: Optional[np.ndarray]) -> Optional[np.ndarray]: + if mask is None: + return None + if mask.shape[0] == kept_atom_total: + return mask + + segments = [] + for idx, chain in enumerate(data["chains"]): + if not chain_mask[idx]: + continue + start = int(chain["atom_idx"]) + end = start + int(chain["atom_num"]) + segments.append(mask[start:end]) + if any(seg.size == 0 for seg in segments): + return None + return np.concatenate(segments) + + alignment_mask = _reshape_mask(alignment_mask) + rmsd_mask = _reshape_mask(rmsd_mask) + + structure_clean = structure.remove_invalid_chains() + + if alignment_mask is not None and alignment_mask.shape[0] != structure_clean.atoms.shape[0]: + print( + f"Alignment mask length mismatch for structure ({alignment_mask.shape[0]} vs {structure_clean.atoms.shape[0]}). Discarding mask." + ) + alignment_mask = None + if rmsd_mask is not None and rmsd_mask.shape[0] != structure_clean.atoms.shape[0]: + print( + f"rmsd mask length mismatch for structure ({rmsd_mask.shape[0]} vs {structure_clean.atoms.shape[0]}). Discarding mask." + ) + rmsd_mask = None + + return structure_clean, alignment_mask, rmsd_mask + + def load_input(record: Record, target_dir: Path, msa_dir: Path) -> Input: """Load the given input data. 
@@ -101,34 +170,9 @@ def load_input(record: Record, target_dir: Path, msa_dir: Path) -> Input: """ # Load the structure - structure = np.load(target_dir / "structures" / f"{record.id}.npz") - - # In order to add cyclic_period to chains if it does not exist - # Extract the chains array - chains = structure["chains"] - # Check if the field exists - if "cyclic_period" not in chains.dtype.names: - # Create a new dtype with the additional field - new_dtype = chains.dtype.descr + [("cyclic_period", "i4")] - # Create a new array with the new dtype - new_chains = np.empty(chains.shape, dtype=new_dtype) - # Copy over existing fields - for name in chains.dtype.names: - new_chains[name] = chains[name] - # Set the new field to 0 - new_chains["cyclic_period"] = 0 - # Replace old chains array with new one - chains = new_chains - - structure = Structure( - atoms=structure["atoms"], - bonds=structure["bonds"], - residues=structure["residues"], - chains=chains, # chains var accounting for missing cyclic_period - connections=structure["connections"].astype(Connection), - interfaces=structure["interfaces"], - mask=structure["mask"], - ) + structure_path = target_dir / "structures" / f"{record.id}.npz" + npz_data = np.load(structure_path) + structure, alignment_mask, rmsd_mask = _prepare_structure(npz_data) msas = {} for chain in record.chains: @@ -138,7 +182,13 @@ def load_input(record: Record, target_dir: Path, msa_dir: Path) -> Input: msa = np.load(msa_dir / f"{msa_id}.npz") msas[chain.chain_id] = MSA(**msa) - return Input(structure, msas) + return Input( + structure=structure, + msa=msas, + record=record, + alignment_mask=alignment_mask, + rmsd_mask=rmsd_mask, + ) def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]: @@ -477,6 +527,16 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: return self.__getitem__(0) features["record_id"] = record.id + features["record"] = record + if input_data.alignment_mask is not None: + features["alignment_mask"] = 
torch.from_numpy( + input_data.alignment_mask.astype(np.bool_) + ) + if input_data.rmsd_mask is not None: + features["rmsd_mask"] = torch.from_numpy( + input_data.rmsd_mask.astype(np.bool_) + ) + features["base_structure"] = input_data.structure features["structure_path"] = str( (dataset.target_dir / "structures" / f"{record.id}.npz").resolve() ) diff --git a/src/boltz/data/types.py b/src/boltz/data/types.py index 1ce26b558..ac3aef24c 100644 --- a/src/boltz/data/types.py +++ b/src/boltz/data/types.py @@ -705,6 +705,8 @@ class Input: residue_constraints: Optional[ResidueConstraints] = None templates: Optional[dict[str, StructureV2]] = None extra_mols: Optional[dict[str, Mol]] = None + alignment_mask: Optional[np.ndarray] = None + rmsd_mask: Optional[np.ndarray] = None #################################################################################################### diff --git a/src/boltz/data/write/writer.py b/src/boltz/data/write/writer.py index 2c9a09caa..915b98a7e 100644 --- a/src/boltz/data/write/writer.py +++ b/src/boltz/data/write/writer.py @@ -1,9 +1,8 @@ import json import os -import gc from dataclasses import asdict, replace from pathlib import Path -from typing import Iterable, Literal, Optional +from typing import Literal, Optional import numpy as np import torch @@ -14,7 +13,6 @@ from boltz.data.types import Coords, Interface, Record, Structure, StructureV2 from boltz.data.write.mmcif import to_mmcif from boltz.data.write.pdb import to_pdb -from boltz.model.loss.validation import SampleMetrics class BoltzWriter(BasePredictionWriter): @@ -69,9 +67,12 @@ def write_on_batch_end( # Get the predictions coords = prediction["coords"] - coords = coords.unsqueeze(0) - + if coords.ndim == 3: + coords = coords.unsqueeze(0) pad_masks = prediction["masks"] + structure_paths = prediction.get("structure_paths") + provided_structures = prediction.get("structures") + custom_filenames = prediction.get("filenames") # Get ranking if "confidence_score" in prediction: @@ 
-82,22 +83,39 @@ def write_on_batch_end( idx_to_rank = {i: i for i in range(len(records))} # Iterate over the records - for record, coord, pad_mask in zip(records, coords, pad_masks): + for rec_idx, (record, coord, pad_mask) in enumerate(zip(records, coords, pad_masks)): # Load the structure - path = self.data_dir / f"{record.id}.npz" - if self.boltz2: - structure: StructureV2 = StructureV2.load(path) + if provided_structures is not None and rec_idx < len(provided_structures): + structure = provided_structures[rec_idx] + chain_map = { + int(chain["asym_id"]): int(chain["asym_id"]) + for chain in structure.chains + } else: - structure: Structure = Structure.load(path) + if structure_paths is not None and rec_idx < len(structure_paths): + path = Path(structure_paths[rec_idx]) + else: + path = self.data_dir / f"{record.id}.npz" + if self.boltz2: + structure: StructureV2 = StructureV2.load(path) + else: + structure: Structure = Structure.load(path) - # Compute chain map with masked removed, to be used later - chain_map = {} - for i, mask in enumerate(structure.mask): - if mask: - chain_map[len(chain_map)] = i + # Compute chain map with masked removed, to be used later + chain_map = {} + for i, mask in enumerate(structure.mask): + if mask: + chain_map[len(chain_map)] = i - # Remove masked chains completely - structure = structure.remove_invalid_chains() + # Remove masked chains completely + structure = structure.remove_invalid_chains() + + custom_names = None + if custom_filenames is not None and rec_idx < len(custom_filenames): + custom_names = custom_filenames[rec_idx] + + ext_map = {"mmcif": "cif", "pdb": "pdb"} + out_ext = ext_map.get(self.output_format, "npz") for model_idx in range(coord.shape[0]): # Get model coord @@ -159,24 +177,29 @@ def write_on_batch_end( plddts = prediction["plddt"][model_idx] # Create path name - outname = f"{record.id}_model_{idx_to_rank[model_idx]}" + if custom_names and model_idx < len(custom_names): + name = custom_names[model_idx] + 
if Path(name).suffix: + filename = name + else: + filename = f"{name}.{out_ext}" + else: + filename = f"{record.id}_model_{idx_to_rank[model_idx]}.{out_ext}" + output_path = struct_dir / filename # Save the structure if self.output_format == "pdb": - path = struct_dir / f"{outname}.pdb" - with path.open("w") as f: + with output_path.open("w") as f: f.write( to_pdb(new_structure, plddts=plddts, boltz2=self.boltz2) ) elif self.output_format == "mmcif": - path = struct_dir / f"{outname}.cif" - with path.open("w") as f: + with output_path.open("w") as f: f.write( to_mmcif(new_structure, plddts=plddts, boltz2=self.boltz2) ) else: - path = struct_dir / f"{outname}.npz" - np.savez_compressed(path, **asdict(new_structure)) + np.savez_compressed(output_path, **asdict(new_structure)) if self.boltz2 and record.affinity and idx_to_rank[model_idx] == 0: path = struct_dir / f"pre_affinity_{record.id}.npz" @@ -290,97 +313,6 @@ def atomic_save_cif(filepath: Path, content: str) -> bool: return False -def write_validation_predictions( - out: dict[str, Tensor], - batch: dict[str, Tensor], - base_structures: list[Optional[Structure]], - record_ids: list[str], - sample_metrics: Iterable[SampleMetrics], - n_samples: int, - output_dir: Path, -) -> None: - """ - Write validation predictions to disk. 
- """ - output_dir.mkdir(parents=True, exist_ok=True) - - sample_metrics = list(sample_metrics) - metrics_map = {metric.sample_idx: metric for metric in sample_metrics} - total_samples = out["sample_atom_coords"].shape[0] - samples_per_structure = n_samples if n_samples > 0 else total_samples - atom_pad_mask_cpu = batch["atom_pad_mask"].detach().cpu() - - for struct_idx, record_id in enumerate(record_ids): - base_structure = base_structures[struct_idx] if struct_idx < len(base_structures) else None - if base_structure is None: - print(f"Skipping CIF write for record '{record_id}' (missing base structure).") - continue - - valid_mask = atom_pad_mask_cpu[struct_idx].to(dtype=torch.bool).numpy() - if valid_mask.sum() != base_structure.atoms.shape[0]: - print( - f"Warning: atom count mismatch for record '{record_id}': " - f"mask has {int(valid_mask.sum())} atoms, structure has {base_structure.atoms.shape[0]}." - ) - continue - - sample_block = out["sample_atom_coords"][ - struct_idx * samples_per_structure : (struct_idx + 1) * samples_per_structure - ] - - for sample_offset, coords_tensor in enumerate(sample_block): - sample_idx = struct_idx * samples_per_structure + sample_offset - metrics = metrics_map.get(sample_idx) - - try: - coords_np = coords_tensor.detach().cpu().numpy()[valid_mask] - atoms = base_structure.atoms.copy() - atoms["coords"] = coords_np.astype(np.float32) - atoms["conformer"] = coords_np.astype(np.float32) - atoms["is_present"] = True - - residues = base_structure.residues.copy() - residues["is_present"] = True - - new_structure = Structure( - atoms=atoms, - bonds=base_structure.bonds, - residues=residues, - chains=base_structure.chains, - connections=base_structure.connections, - interfaces=base_structure.interfaces, - mask=base_structure.mask, - ) - - plddts = None - if "plddt" in out: - start = struct_idx * samples_per_structure + sample_offset - plddts = out["plddt"][start : start + 1].detach().cpu() - - if metrics is not None: - filename = ( - 
f"prediction_{record_id}_sample_{sample_idx}" - f"_whole{metrics.rmsd_whole:.2f}_pep{metrics.rmsd_peptide:.2f}.cif" - ) - else: - filename = f"prediction_{record_id}_sample_{sample_idx}.cif" - - output_path = output_dir / filename - print(f"\nSaving prediction to {output_path}") - - cif_content = to_mmcif(new_structure, plddts=plddts) - if atomic_save_cif(output_path, cif_content): - print(f"Successfully saved prediction to {output_path}") - - except Exception as exc: # noqa: BLE001 - print(f"Error processing record '{record_id}' sample {sample_idx}: {exc}") - continue - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - class BoltzAffinityWriter(BasePredictionWriter): """Custom writer for predictions.""" diff --git a/src/boltz/model/loss/validation.py b/src/boltz/model/loss/validation.py index 20c9f0ba7..d957ee726 100644 --- a/src/boltz/model/loss/validation.py +++ b/src/boltz/model/loss/validation.py @@ -8,9 +8,7 @@ ) from boltz.model.loss.diffusion import weighted_rigid_align -import math from dataclasses import dataclass -from typing import List from torch import Tensor @@ -900,6 +898,8 @@ def weighted_minimum_rmsd( multiplicity=1, nucleotide_weight=5.0, ligand_weight=10.0, + alignment_mask=None, + rmsd_mask=None, ): """Compute rmsd of the aligned atom coordinates. @@ -927,6 +927,27 @@ def weighted_minimum_rmsd( atom_mask = feats["atom_resolved_mask"] atom_mask = atom_mask.repeat_interleave(multiplicity, 0) + target_len = pred_atom_coords.shape[0] + base_len = feats["coords"].shape[0] + + def _prepare_mask(mask, fallback): + if mask is None: + prepared = fallback + else: + if mask.shape[0] == target_len: + prepared = mask + elif mask.shape[0] == base_len and base_len * multiplicity == target_len: + prepared = mask.repeat_interleave(multiplicity, 0) + else: + raise ValueError( + "Unexpected mask shape: " + f"{mask.shape[0]} (expected {target_len} or {base_len})." 
+ ) + return prepared.to(dtype=fallback.dtype, device=pred_atom_coords.device) + + align_mask = _prepare_mask(alignment_mask, atom_mask) + calc_mask = _prepare_mask(rmsd_mask, atom_mask) + align_weights = atom_coords.new_ones(atom_coords.shape[:2]) atom_type = ( torch.bmm( @@ -950,14 +971,15 @@ def weighted_minimum_rmsd( with torch.no_grad(): atom_coords_aligned_ground_truth = weighted_rigid_align( - atom_coords, pred_atom_coords, align_weights, mask=atom_mask + atom_coords, pred_atom_coords, align_weights, mask=align_mask ) # weighted MSE loss of denoised atom positions mse_loss = ((pred_atom_coords - atom_coords_aligned_ground_truth) ** 2).sum(dim=-1) + denom = torch.sum(align_weights * calc_mask, dim=-1) rmsd = torch.sqrt( - torch.sum(mse_loss * align_weights * atom_mask, dim=-1) - / torch.sum(align_weights * atom_mask, dim=-1) + torch.sum(mse_loss * align_weights * calc_mask, dim=-1) + / torch.clamp(denom, min=1e-8) ) best_rmsd = torch.min(rmsd.reshape(-1, multiplicity), dim=1).values @@ -972,6 +994,8 @@ def weighted_minimum_rmsd_single( mol_type, nucleotide_weight=5.0, ligand_weight=10.0, + alignment_mask=None, + rmsd_mask=None, ): """Compute rmsd of the aligned atom coordinates. 
@@ -1017,15 +1041,18 @@ def weighted_minimum_rmsd_single( ) with torch.no_grad(): + mask_align = alignment_mask if alignment_mask is not None else atom_mask atom_coords_aligned_ground_truth = weighted_rigid_align( - atom_coords, pred_atom_coords, align_weights, mask=atom_mask + atom_coords, pred_atom_coords, align_weights, mask=mask_align ) # weighted MSE loss of denoised atom positions + calc_mask = rmsd_mask if rmsd_mask is not None else atom_mask mse_loss = ((pred_atom_coords - atom_coords_aligned_ground_truth) ** 2).sum(dim=-1) + denom = torch.sum(align_weights * calc_mask, dim=-1) rmsd = torch.sqrt( - torch.sum(mse_loss * align_weights * atom_mask, dim=-1) - / torch.sum(align_weights * atom_mask, dim=-1) + torch.sum(mse_loss * align_weights * calc_mask, dim=-1) + / torch.clamp(denom, min=1e-8) ) return rmsd, atom_coords_aligned_ground_truth, align_weights @@ -1033,115 +1060,4 @@ def weighted_minimum_rmsd_single( class SampleMetrics: sample_idx: int rmsd_whole: float - rmsd_peptide: float - - -def compute_weighted_mhc_rmsds( - out: dict, - true_coords: Tensor, - batch: dict, - peptide_mask: Tensor, - n_samples: int, - nucleotide_weight: float, - ligand_weight: float, -) -> List[SampleMetrics]: - """Compute weighted RMSDs for whole MHC chain and peptide subset. - - Parameters - ---------- - out : dict - Model outputs containing ``sample_atom_coords``. - true_coords : Tensor - Reference coordinates matching each diffusion sample. - batch : dict - Original batch features (masks, mapping, etc.). - peptide_mask : Tensor - Boolean mask selecting peptide atoms per structure. - n_samples : int - Number of diffusion samples per structure. - nucleotide_weight : float - Weight multiplier for nucleic acid atoms. - ligand_weight : float - Weight multiplier for ligand atoms. - - Returns - ------- - list[SampleMetrics] - Metrics per diffusion sample. 
- """ - device = out["sample_atom_coords"].device - total_samples = out["sample_atom_coords"].shape[0] - denom = max(n_samples, 1) - - metrics: List[SampleMetrics] = [] - - for sample_idx in range(total_samples): - struct_idx = sample_idx // denom - pred_sample = out["sample_atom_coords"][sample_idx : sample_idx + 1] - ref_sample = true_coords[sample_idx : sample_idx + 1] - - atom_mask_full = ( - batch["atom_resolved_mask"][struct_idx : struct_idx + 1] - .to(device=device) - .float() - ) - atom_to_token_full = ( - batch["atom_to_token"][struct_idx : struct_idx + 1] - .float() - .to(device=device) - ) - mol_type_full = batch["mol_type"][struct_idx : struct_idx + 1].to(device=device) - - try: - whole_rmsd_tensor, _, _ = weighted_minimum_rmsd_single( - pred_sample, - ref_sample, - atom_mask_full, - atom_to_token_full, - mol_type_full, - nucleotide_weight=nucleotide_weight, - ligand_weight=ligand_weight, - ) - whole_rmsd = whole_rmsd_tensor.item() - except Exception as e: # noqa: BLE001 - print(f"Weighted RMSD (MHC Chain) failed for sample {sample_idx}: {e}") - whole_rmsd = float("nan") - - peptide_mask_row = ( - peptide_mask[struct_idx : struct_idx + 1] - .to(device=device) - .float() - ) - try: - if peptide_mask_row.sum() >= 3: - peptide_rmsd_tensor, _, _ = weighted_minimum_rmsd_single( - pred_sample, - ref_sample, - peptide_mask_row, - atom_to_token_full, - mol_type_full, - nucleotide_weight=nucleotide_weight, - ligand_weight=ligand_weight, - ) - peptide_rmsd = peptide_rmsd_tensor.item() - else: - peptide_rmsd = float("nan") - except Exception as e: # noqa: BLE001 - print(f"Weighted RMSD (peptide) failed for sample {sample_idx}: {e}") - peptide_rmsd = float("nan") - - print(f"Sample {sample_idx} weighted RMSD (MHC Chain): {whole_rmsd:.3f}Å") - if math.isnan(peptide_rmsd): - print(f"Sample {sample_idx} weighted RMSD (peptide): nan") - else: - print(f"Sample {sample_idx} weighted RMSD (peptide): {peptide_rmsd:.3f}Å") - - metrics.append( - SampleMetrics( - 
sample_idx=sample_idx, - rmsd_whole=whole_rmsd, - rmsd_peptide=peptide_rmsd, - ) - ) - - return metrics + rmsd_masked: float diff --git a/src/boltz/model/models/boltz1.py b/src/boltz/model/models/boltz1.py index 7e859e119..b5966b08d 100644 --- a/src/boltz/model/models/boltz1.py +++ b/src/boltz/model/models/boltz1.py @@ -3,6 +3,7 @@ import random from pathlib import Path from typing import Any, Optional +from typing import Iterable import numpy as np @@ -19,19 +20,21 @@ minimum_lddt_symmetry_coords, minimum_symmetry_coords, ) +from lightning_fabric.utilities.apply_func import move_data_to_device from boltz.model.loss.confidence import confidence_loss from boltz.model.loss.distogram import distogram_loss from boltz.model.loss.validation import ( + SampleMetrics, compute_pae_mae, compute_pde_mae, compute_plddt_mae, factored_lddt_loss, factored_token_lddt_dist_loss, - compute_weighted_mhc_rmsds, weighted_minimum_rmsd, + weighted_minimum_rmsd_single, ) from boltz.model.modules.confidence import ConfidenceModule -from boltz.data.write.writer import write_validation_predictions +from boltz.data.write.writer import BoltzWriter from boltz.model.modules.diffusion import AtomDiffusion from boltz.model.modules.encoders import RelativePositionEncoder from boltz.model.modules.trunk import ( @@ -148,6 +151,7 @@ def __init__( # noqa: PLR0915, C901, PLR0912 self.steering_args = steering_args self.use_kernels = use_kernels + self.validation_writer: Optional[BoltzWriter] = None self.nucleotide_rmsd_weight = nucleotide_rmsd_weight self.ligand_rmsd_weight = ligand_rmsd_weight @@ -276,6 +280,27 @@ def setup(self, stage: str) -> None: ): self.use_kernels = False + def transfer_batch_to_device( + self, batch: dict[str, Any], device: torch.device, dataloader_idx: int + ) -> dict[str, Any]: + """Move tensor entries to device while keeping CPU-only metadata intact.""" + if batch is None: + return batch + + cpu_only_keys = {"record", "base_structure"} + cpu_entries = {} + tensor_entries = {} 
+ + for key, value in batch.items(): + if key in cpu_only_keys: + cpu_entries[key] = value + else: + tensor_entries[key] = value + + moved = move_data_to_device(tensor_entries, device) + moved.update(cpu_entries) + return moved + def forward( self, feats: dict[str, Tensor], @@ -413,7 +438,12 @@ def get_true_coordinates( diffusion_samples, symmetry_correction, lddt_minimization=True, + alignment_mask: Optional[Tensor] = None, + rmsd_mask: Optional[Tensor] = None, ): + masked_rmsds = None + best_masked_rmsds = None + if symmetry_correction: min_coords_routine = ( minimum_lddt_symmetry_coords @@ -444,6 +474,8 @@ def get_true_coordinates( best_rmsds.append(best_rmsd) true_coords = torch.cat(true_coords, dim=0) true_coords_resolved_mask = torch.cat(true_coords_resolved_mask, dim=0) + masked_rmsds = None + best_masked_rmsds = None else: true_coords = ( batch["coords"].squeeze(1).repeat_interleave(diffusion_samples, 0) @@ -460,7 +492,50 @@ def get_true_coordinates( ligand_weight=self.ligand_rmsd_weight, ) - return true_coords, rmsds, best_rmsds, true_coords_resolved_mask + align_mask = alignment_mask + if align_mask is not None: + target_len = out["sample_atom_coords"].shape[0] + if align_mask.shape[0] != target_len: + if align_mask.shape[0] * diffusion_samples == target_len: + align_mask = align_mask.repeat_interleave(diffusion_samples, 0) + else: + raise ValueError( + "Alignment mask has unexpected first dimension " + f"{align_mask.shape[0]} (expected {target_len})." + ) + calc_mask = rmsd_mask + if calc_mask is not None: + target_len = out["sample_atom_coords"].shape[0] + if calc_mask.shape[0] != target_len: + if calc_mask.shape[0] * diffusion_samples == target_len: + calc_mask = calc_mask.repeat_interleave(diffusion_samples, 0) + else: + raise ValueError( + "RMSD mask has unexpected first dimension " + f"{calc_mask.shape[0]} (expected {target_len})." 
+ ) + + if align_mask is not None or calc_mask is not None: + masked_rmsds, best_masked_rmsds = weighted_minimum_rmsd( + out["sample_atom_coords"], + batch, + multiplicity=diffusion_samples, + nucleotide_weight=self.nucleotide_rmsd_weight, + ligand_weight=self.ligand_rmsd_weight, + alignment_mask=align_mask, + rmsd_mask=calc_mask, + ) + else: + masked_rmsds, best_masked_rmsds = rmsds, best_rmsds + + return ( + true_coords, + rmsds, + best_rmsds, + true_coords_resolved_mask, + masked_rmsds, + best_masked_rmsds, + ) def training_step(self, batch: dict[str, Tensor], batch_idx: int) -> Tensor: # Sample recycling steps @@ -498,7 +573,7 @@ def training_step(self, batch: dict[str, Tensor], batch_idx: int) -> Tensor: if self.confidence_prediction: # confidence model symmetry correction - true_coords, _, _, true_coords_resolved_mask = self.get_true_coordinates( + true_coords, _, _, true_coords_resolved_mask, _, _ = self.get_true_coordinates( batch, out, diffusion_samples=self.training_args.diffusion_samples, @@ -628,14 +703,17 @@ def _build_alignment_masks( record_ids: list[str], base_structures: list[Optional[Structure]], batch_idx: int, - ) -> tuple[Tensor, Tensor, Tensor]: + precomputed_align_masks: Optional[list[Optional[np.ndarray]]] = None, + precomputed_rmsd_masks: Optional[list[Optional[np.ndarray]]] = None, + ) -> tuple[Tensor, Tensor]: batch_size, num_atoms, _ = batch["atom_to_token"].shape device = batch["atom_pad_mask"].device - align_base_mask = torch.zeros((batch_size, num_atoms), dtype=torch.bool, device=device) - heavy_calc_mask = torch.zeros_like(align_base_mask) - peptide_calc_mask = torch.zeros_like(align_base_mask) - backbone_names = set(const.protein_backbone_atom_names) + heavy_calc_mask = torch.zeros((batch_size, num_atoms), dtype=torch.bool, device=device) + peptide_calc_mask = torch.zeros_like(heavy_calc_mask) + + if precomputed_align_masks is None or precomputed_rmsd_masks is None: + raise ValueError("Precomputed masks are required but were not 
provided.") for structure_idx in range(batch_size): record_id = ( @@ -648,6 +726,8 @@ def _build_alignment_masks( if structure_idx < len(base_structures) else None ) + pre_align = precomputed_align_masks[structure_idx] + pre_rmsd = precomputed_rmsd_masks[structure_idx] atom_to_token = batch["atom_to_token"][structure_idx].bool() entity_ids = batch["entity_id"][structure_idx] @@ -679,93 +759,217 @@ def _build_alignment_masks( ) continue - align_mask = atom_heavy | atom_pep - heavy_mask = atom_heavy.clone() - peptide_mask = atom_pep.clone() - align_mode = "all atoms" - calc_mode = "all atoms" - - if structure is not None: - valid_mask = present_atom - valid_count = int(valid_mask.sum().item()) - if valid_count == structure.atoms.shape[0]: - atoms = structure.atoms - chains = structure.chains - - def decode_name(code_row: np.ndarray) -> str: - chars = [chr(int(c) + 32) for c in code_row if c > 0] - return "".join(chars) - - atom_names = np.array([decode_name(row) for row in atoms["name"]]) - - heavy_all_base = np.zeros(structure.atoms.shape[0], dtype=bool) - peptide_all_base = np.zeros_like(heavy_all_base) - heavy_ca_base = np.zeros_like(heavy_all_base) - peptide_ca_base = np.zeros_like(heavy_all_base) - - for chain in chains: - start = int(chain["atom_idx"]) - end = start + int(chain["atom_num"]) - entity = int(chain["entity_id"]) - - if entity == heavy_entity: - heavy_all_base[start:end] = True - ca_indices = np.where(atom_names[start:end] == "CA")[0] + start - if ca_indices.size > 0: - heavy_ca_base[ca_indices] = True - if entity == peptide_entity: - peptide_all_base[start:end] = True - ca_indices = np.where(atom_names[start:end] == "CA")[0] + start - if ca_indices.size > 0: - peptide_ca_base[ca_indices] = True - - heavy_backbone_base = heavy_all_base & np.isin(atom_names, list(backbone_names)) - peptide_backbone_base = peptide_all_base & np.isin(atom_names, list(backbone_names)) - align_ca_base = heavy_ca_base | peptide_ca_base - - def pad_mask(base_mask: 
np.ndarray) -> torch.Tensor: - mask_full = torch.from_numpy(base_mask).to(device=device, dtype=torch.bool) - mask_padded = torch.zeros_like(align_mask) - mask_padded[valid_mask] = mask_full - return mask_padded - - if align_ca_base.sum() >= 3: - align_mask = pad_mask(align_ca_base) - align_mode = "CA" - - if heavy_ca_base.sum() >= 3: - heavy_mask = pad_mask(heavy_ca_base) - elif heavy_backbone_base.sum() >= 3: - heavy_mask = pad_mask(heavy_backbone_base) - else: - heavy_mask = pad_mask(heavy_all_base) - - if peptide_ca_base.sum() >= 3: - peptide_mask = pad_mask(peptide_ca_base) - calc_mode = "CA" - elif peptide_backbone_base.sum() >= 3: - peptide_mask = pad_mask(peptide_backbone_base) - calc_mode = "backbone" - else: - peptide_mask = pad_mask(peptide_all_base) - else: - print( - f"Warning: atom count mismatch for record '{record_id}': " - f"mask has {valid_count} atoms, structure has {structure.atoms.shape[0]}." - ) + if pre_align is None or pre_rmsd is None: + raise ValueError( + f"Missing precomputed masks for record '{record_id}'. " + "Please regenerate mask files." + ) + + valid_mask = present_atom + valid_count = int(valid_mask.sum().item()) + if structure is None or pre_align.shape[0] != valid_count or pre_rmsd.shape[0] != valid_count: + raise ValueError( + f"Precomputed mask length mismatch for record '{record_id}'. " + "Ensure masks were generated after removing invalid chains." + ) + + heavy_mask = torch.zeros_like(valid_mask, dtype=torch.bool, device=device) + peptide_mask = torch.zeros_like(valid_mask, dtype=torch.bool, device=device) + + heavy_mask[valid_mask] = torch.from_numpy(pre_align.astype(bool, copy=False)).to(device=device) + peptide_mask[valid_mask] = torch.from_numpy(pre_rmsd.astype(bool, copy=False)).to(device=device) + + if heavy_mask.sum() < 3 or peptide_mask.sum() < 3: + raise ValueError( + f"Precomputed masks for record '{record_id}' contain fewer than 3 atoms. " + "Please regenerate masks with enough atoms for alignment and RMSD." 
+ ) - align_base_mask[structure_idx] = align_mask heavy_calc_mask[structure_idx] = heavy_mask peptide_calc_mask[structure_idx] = peptide_mask print( f"[debug] structure {structure_idx} entities: {res_counts} | heavy={heavy_entity} peptide={peptide_entity} | " - f"align_mode={align_mode} calc_mode={calc_mode}" + "align_mode=precomputed calc_mode=precomputed" ) print( - f"Align atoms: {int(align_mask.sum().item())} | Calc atoms: {int(peptide_mask.sum().item())}" + f"Align atoms: {int(heavy_mask.sum().item())} | Calc atoms: {int(peptide_mask.sum().item())}" ) - return align_base_mask, heavy_calc_mask, peptide_calc_mask + return heavy_calc_mask, peptide_calc_mask + + def _write_validation_cifs( + self, + batch: dict[str, Tensor], + out: dict[str, Tensor], + base_structures: list[Optional[Structure]], + record_ids: list[str], + sample_metrics: Iterable[SampleMetrics], + n_samples: int, + batch_idx: int, + ) -> None: + if not getattr(self.validation_args, "write_cif_for_validation", True): + return + + output_dir = Path(getattr(self.validation_args, "val_cif_out_dir", "validation_outputs")) + output_dir.mkdir(parents=True, exist_ok=True) + + if ( + self.validation_writer is None + or Path(self.validation_writer.output_dir) != output_dir + ): + self.validation_writer = BoltzWriter( + data_dir=".", + output_dir=str(output_dir), + output_format="mmcif", + ) + + records = batch.get("record", []) + structure_paths = batch.get("structure_path") + atom_pad_mask = batch["atom_pad_mask"].detach().cpu().bool() + metrics_map = {metric.sample_idx: metric for metric in sample_metrics} + total_samples = out["sample_atom_coords"].shape[0] + samples_per_structure = n_samples if n_samples > 0 else total_samples + + for struct_idx, record_id in enumerate(record_ids): + if not records or struct_idx >= len(records): + print(f"[warning] Missing record metadata for '{record_id}', skipping CIF write.") + continue + record = records[struct_idx] + + sample_block = out["sample_atom_coords"][ + 
struct_idx * samples_per_structure : (struct_idx + 1) * samples_per_structure + ].detach().cpu() + pad_mask = atom_pad_mask[struct_idx : struct_idx + 1] + + filenames = [] + for sample_offset in range(samples_per_structure): + sample_idx = struct_idx * samples_per_structure + sample_offset + metrics = metrics_map.get(sample_idx) + if metrics is not None: + name = ( + f"prediction_{record_id}_sample_{sample_idx}" + f"_whole{metrics.rmsd_whole:.2f}_mask{metrics.rmsd_masked:.2f}" + ) + else: + name = f"prediction_{record_id}_sample_{sample_idx}" + filenames.append(name) + + structure_path_entry = None + if structure_paths is not None and struct_idx < len(structure_paths): + structure_path_entry = [structure_paths[struct_idx]] + + structure_entry = None + if base_structures and struct_idx < len(base_structures): + structure_entry = [base_structures[struct_idx]] + + prediction_payload = { + "exception": False, + "coords": sample_block, + "masks": pad_mask, + "structure_paths": structure_path_entry, + "structures": structure_entry, + "filenames": [filenames], + } + + writer_batch = {"record": [record]} + + self.validation_writer.write_on_batch_end( + trainer=None, + pl_module=self, + prediction=prediction_payload, + batch_indices=[0], + batch=writer_batch, + batch_idx=batch_idx, + dataloader_idx=0, + ) + + def _compute_sample_metrics( + self, + batch: dict[str, Tensor], + out: dict[str, Tensor], + true_coords: Tensor, + heavy_calc_mask: Tensor, + peptide_calc_mask: Tensor, + n_samples: int, + ) -> list[SampleMetrics]: + sample_coords = out["sample_atom_coords"] + device = sample_coords.device + total_samples = sample_coords.shape[0] + samples_per_structure = max(n_samples, 1) + + atom_mask_full = batch["atom_resolved_mask"].to(device=device).float() + atom_to_token_full = batch["atom_to_token"].float().to(device=device) + mol_type_full = batch["mol_type"].to(device=device) + heavy_mask_full = heavy_calc_mask.to(device=device, dtype=atom_mask_full.dtype) + 
peptide_mask_full = peptide_calc_mask.to(device=device, dtype=atom_mask_full.dtype) + + sample_metrics: list[SampleMetrics] = [] + + for sample_idx in range(total_samples): + struct_idx = sample_idx // samples_per_structure + pred_sample = sample_coords[sample_idx : sample_idx + 1] + ref_sample = true_coords[sample_idx : sample_idx + 1] + + struct_atom_mask = atom_mask_full[struct_idx : struct_idx + 1] + struct_atom_to_token = atom_to_token_full[struct_idx : struct_idx + 1] + struct_mol_type = mol_type_full[struct_idx : struct_idx + 1] + + whole_rmsd = float("nan") + masked_rmsd = float("nan") + + try: + whole_rmsd_tensor, _, _ = weighted_minimum_rmsd_single( + pred_sample, + ref_sample, + struct_atom_mask, + struct_atom_to_token, + struct_mol_type, + nucleotide_weight=self.nucleotide_rmsd_weight, + ligand_weight=self.ligand_rmsd_weight, + ) + whole_rmsd = whole_rmsd_tensor.item() + except Exception as exc: # noqa: BLE001 + print(f"Weighted RMSD (MHC Chain) failed for sample {sample_idx}: {exc}") + + align_mask_row = heavy_mask_full[struct_idx : struct_idx + 1] + calc_mask_row = peptide_mask_full[struct_idx : struct_idx + 1] + try: + if ( + calc_mask_row.float().sum() >= 3 + and align_mask_row.float().sum() >= 3 + ): + masked_rmsd_tensor, _, _ = weighted_minimum_rmsd_single( + pred_sample, + ref_sample, + struct_atom_mask, + struct_atom_to_token, + struct_mol_type, + nucleotide_weight=self.nucleotide_rmsd_weight, + ligand_weight=self.ligand_rmsd_weight, + alignment_mask=align_mask_row, + rmsd_mask=calc_mask_row, + ) + masked_rmsd = masked_rmsd_tensor.item() + except Exception as exc: # noqa: BLE001 + print(f"Weighted RMSD (masked) failed for sample {sample_idx}: {exc}") + + print(f"Sample {sample_idx} weighted RMSD (MHC Chain): {whole_rmsd:.3f}Å") + if math.isnan(masked_rmsd): + print(f"Sample {sample_idx} weighted RMSD (masked region): nan") + else: + print( + f"Sample {sample_idx} weighted RMSD (masked region): {masked_rmsd:.3f}Å" + ) + + sample_metrics.append( + 
def debug_peptide_mask_residues(
    self,
    base_structures: list[Optional[Structure]],
    record_ids: list[str],
    peptide_calc_mask: torch.Tensor,
    precomputed_rmsd_masks: list[Optional[np.ndarray]],
):
    """Print, per record, the residue labels selected by the predicted and
    reference peptide atom masks so the two selections can be eyeballed.

    Parameters
    ----------
    base_structures : list[Optional[Structure]]
        Preloaded structures aligned with ``record_ids``; ``None`` entries
        are skipped silently.
    record_ids : list[str]
        Identifier printed alongside each record's residue lists.
    peptide_calc_mask : torch.Tensor
        Per-atom boolean mask; row ``i`` selects the predicted peptide
        atoms for record ``i``.  # assumes shape (batch, num_atoms) — TODO confirm
    precomputed_rmsd_masks : list[Optional[np.ndarray]]
        Reference per-atom masks; the list may be shorter than the batch
        or contain ``None``.
    """

    def _labels(structure: Structure, mask) -> list[str]:
        """Map a per-atom mask to ``NAME<resnum>`` labels of touched residues."""
        if structure is None or not hasattr(structure, "residues"):
            return []
        if isinstance(mask, torch.Tensor):
            mask = mask.detach().cpu().numpy()
        if mask is None or mask.size == 0:
            return []

        res_arr = structure.residues
        # Read the structured dtype defensively: if `residues` is not a
        # numpy structured array there is nothing to label.  (The previous
        # `getattr(res_arr, "dtype", ()).names` raised AttributeError on
        # the fallback tuple instead of degrading gracefully.)
        dtype = getattr(res_arr, "dtype", None)
        names = set(getattr(dtype, "names", None) or [])
        if not {"atom_idx", "atom_num", "name"}.issubset(names):
            return []

        # Prefer an explicit residue number field; fall back to the index.
        if "res_num" in names:
            resnum_field = "res_num"
        elif "res_idx" in names:
            resnum_field = "res_idx"
        else:
            resnum_field = None

        out = []
        mask_len = mask.shape[0]
        for i, r in enumerate(res_arr):
            start = int(r["atom_idx"])
            end = start + int(r["atom_num"])
            # A residue is "touched" when any of its atoms is masked;
            # slicing past the mask end is safe (numpy truncates slices).
            if start < mask_len and mask[start:end].any():
                raw_name = r["name"]
                rname = (
                    raw_name.decode("utf-8", "ignore")
                    if isinstance(raw_name, (bytes, bytearray))
                    else str(raw_name)
                )
                rnum = int(r[resnum_field]) if resnum_field else (i + 1)
                out.append(f"{rname}{rnum}")
        return out

    for i, (structure, rec_id) in enumerate(zip(base_structures, record_ids)):
        if structure is None:
            continue
        pred_mask = peptide_calc_mask[i].detach().cpu().numpy()
        # The reference mask list may be shorter than the batch.
        ref_mask = precomputed_rmsd_masks[i] if i < len(precomputed_rmsd_masks) else None

        pred_list = _labels(structure, pred_mask)
        ref_list = _labels(structure, ref_mask)

        pred_str = ", ".join(pred_list) if pred_list else "None"
        ref_str = ", ".join(ref_list) if ref_list else "None"
        print(f"[debug-peptide-residues] record {rec_id}\n REF: {ref_str}\n PRD: {pred_str}")
Tensor], batch_idx: int): # Compute the forward pass @@ -854,58 +1044,254 @@ def validation_step(self, batch: dict[str, Tensor], batch_idx: int): raise e try: - true_coords, _, _, _ = self.get_true_coordinates( + record_ids = self._extract_record_ids(batch, batch_idx) + + batch_size = batch["atom_to_token"].shape[0] + + base_structures = batch.get("base_structure", None) + if base_structures is None: + raise ValueError("Validation batch missing 'base_structure'. Ensure dataloader supplies preloaded structures.") + precomputed_align_masks: list[Optional[np.ndarray]] + precomputed_rmsd_masks: list[Optional[np.ndarray]] + + if base_structures is None: + raise ValueError("Validation requires 'base_structure' in batch to avoid disk reloads.") + if not isinstance(base_structures, list): + base_structures = [base_structures] + precomputed_align_masks = [None] * len(base_structures) + precomputed_rmsd_masks = [None] * len(base_structures) + + align_tensor = batch.get("alignment_mask") + if isinstance(align_tensor, torch.Tensor): + align_tensor = align_tensor.bool() + for idx in range(min(align_tensor.shape[0], len(base_structures))): + precomputed_align_masks[idx] = align_tensor[idx].detach().cpu().numpy() + + rmsd_tensor = batch.get("rmsd_mask") + if isinstance(rmsd_tensor, torch.Tensor): + rmsd_tensor = rmsd_tensor.bool() + for idx in range(min(rmsd_tensor.shape[0], len(base_structures))): + precomputed_rmsd_masks[idx] = rmsd_tensor[idx].detach().cpu().numpy() + + heavy_calc_mask, peptide_calc_mask = self._build_alignment_masks( + batch=batch, + record_ids=record_ids, + base_structures=base_structures, + batch_idx=batch_idx, + precomputed_align_masks=precomputed_align_masks, + precomputed_rmsd_masks=precomputed_rmsd_masks, + ) + + if getattr(self.validation_args, "debug_peptide_mask_info", False): + self.debug_peptide_mask_residues(base_structures, record_ids, peptide_calc_mask, precomputed_rmsd_masks) + + true_coords, rmsds, best_rmsds, true_coords_resolved_mask, 
masked_rmsds, best_masked_rmsds = self.get_true_coordinates( batch=batch, out=out, diffusion_samples=n_samples, symmetry_correction=self.validation_args.symmetry_correction, + alignment_mask=heavy_calc_mask, + rmsd_mask=peptide_calc_mask, ) - record_ids = self._extract_record_ids(batch, batch_idx) + boundaries = torch.linspace(2, 22.0, 63, device=out["pdistogram"].device) + lower = torch.tensor([1.0], device=boundaries.device) + upper = torch.tensor([27.0], device=boundaries.device) + exp_boundaries = torch.cat((lower, boundaries, upper)) + mid_points = ((exp_boundaries[:-1] + exp_boundaries[1:]) / 2).to(out["pdistogram"]) + + preds = out["pdistogram"] + pred_softmax = torch.softmax(preds, dim=-1) + pred_max = pred_softmax.argmax(dim=-1) + pred_one_hot = torch.nn.functional.one_hot(pred_max, num_classes=preds.shape[-1]) + pred_dist = (pred_one_hot * mid_points).sum(dim=-1) + true_center = batch["disto_center"] + true_dists = torch.cdist(true_center, true_center) + + disto_lddt_dict, disto_total_dict = factored_token_lddt_dist_loss( + feats=batch, + true_d=true_dists, + pred_d=pred_dist, + ) + batch["token_disto_mask"] = batch["token_disto_mask"] + + all_lddt_dict, all_total_dict = factored_lddt_loss( + feats=batch, + atom_mask=true_coords_resolved_mask, + true_atom_coords=true_coords, + pred_atom_coords=out["sample_atom_coords"], + multiplicity=n_samples, + ) - structure_paths = self._extract_structure_paths(batch, len(record_ids)) + best_lddt_dict, best_total_dict = {}, {} + best_complex_lddt_dict, best_complex_total_dict = {}, {} + B = true_coords.shape[0] // n_samples + if n_samples > 1: + complex_total = 0 + complex_lddt = 0 + for key in all_lddt_dict.keys(): + complex_lddt += all_lddt_dict[key] * all_total_dict[key] + complex_total += all_total_dict[key] + complex_lddt /= complex_total + 1e-7 + best_complex_idx = complex_lddt.reshape(-1, n_samples).argmax(dim=1) + for key in all_lddt_dict: + reshaped_lddt = all_lddt_dict[key].reshape(-1, n_samples) + 
reshaped_total = all_total_dict[key].reshape(-1, n_samples) + best_idx = reshaped_lddt.argmax(dim=1) + best_lddt_dict[key] = reshaped_lddt[torch.arange(B), best_idx] + best_total_dict[key] = reshaped_total[torch.arange(B), best_idx] + best_complex_lddt_dict[key] = reshaped_lddt[ + torch.arange(B), best_complex_idx + ] + best_complex_total_dict[key] = reshaped_total[ + torch.arange(B), best_complex_idx + ] + else: + best_lddt_dict = all_lddt_dict + best_total_dict = all_total_dict + best_complex_lddt_dict = all_lddt_dict + best_complex_total_dict = all_total_dict + + def _allowed_metric(key: str) -> bool: + low = key.lower() + return ("dna" not in low) and ("rna" not in low) and ("ligand" not in low) + + if self.confidence_prediction and n_samples > 1: + mae_plddt_dict, total_mae_plddt_dict = compute_plddt_mae( + pred_atom_coords=out["sample_atom_coords"], + feats=batch, + true_atom_coords=true_coords, + pred_lddt=out["plddt"], + true_coords_resolved_mask=true_coords_resolved_mask, + multiplicity=n_samples, + ) + mae_pde_dict, total_mae_pde_dict = compute_pde_mae( + pred_atom_coords=out["sample_atom_coords"], + feats=batch, + true_atom_coords=true_coords, + pred_pde=out["pde"], + true_coords_resolved_mask=true_coords_resolved_mask, + multiplicity=n_samples, + ) + mae_pae_dict, total_mae_pae_dict = compute_pae_mae( + pred_atom_coords=out["sample_atom_coords"], + feats=batch, + true_atom_coords=true_coords, + pred_pae=out["pae"], + true_coords_resolved_mask=true_coords_resolved_mask, + multiplicity=n_samples, + ) - base_structures = self._load_structures(structure_paths) + plddt = out["complex_plddt"].reshape(-1, n_samples) + top1_idx = plddt.argmax(dim=1) + iplddt = out["complex_iplddt"].reshape(-1, n_samples) + iplddt_top1_idx = iplddt.argmax(dim=1) + pde = out["complex_pde"].reshape(-1, n_samples) + pde_top1_idx = pde.argmin(dim=1) + ipde = out["complex_ipde"].reshape(-1, n_samples) + ipde_top1_idx = ipde.argmin(dim=1) + ptm = out["ptm"].reshape(-1, n_samples) + 
ptm_top1_idx = ptm.argmax(dim=1) + iptm = out["iptm"].reshape(-1, n_samples) + iptm_top1_idx = iptm.argmax(dim=1) + ligand_iptm = out["ligand_iptm"].reshape(-1, n_samples) + ligand_iptm_top1_idx = ligand_iptm.argmax(dim=1) + protein_iptm = out["protein_iptm"].reshape(-1, n_samples) + protein_iptm_top1_idx = protein_iptm.argmax(dim=1) + + for key in all_lddt_dict: + if not _allowed_metric(key): + continue + reshaped_lddt = all_lddt_dict[key].reshape(-1, n_samples) + reshaped_total = all_total_dict[key].reshape(-1, n_samples) + + self.top1_lddt[key].update( + reshaped_lddt[torch.arange(B), top1_idx], + reshaped_total[torch.arange(B), top1_idx], + ) + self.iplddt_top1_lddt[key].update( + reshaped_lddt[torch.arange(B), iplddt_top1_idx], + reshaped_total[torch.arange(B), iplddt_top1_idx], + ) + self.pde_top1_lddt[key].update( + reshaped_lddt[torch.arange(B), pde_top1_idx], + reshaped_total[torch.arange(B), pde_top1_idx], + ) + self.ipde_top1_lddt[key].update( + reshaped_lddt[torch.arange(B), ipde_top1_idx], + reshaped_total[torch.arange(B), ipde_top1_idx], + ) + self.ptm_top1_lddt[key].update( + reshaped_lddt[torch.arange(B), ptm_top1_idx], + reshaped_total[torch.arange(B), ptm_top1_idx], + ) + self.iptm_top1_lddt[key].update( + reshaped_lddt[torch.arange(B), iptm_top1_idx], + reshaped_total[torch.arange(B), iptm_top1_idx], + ) + self.ligand_iptm_top1_lddt[key].update( + reshaped_lddt[torch.arange(B), ligand_iptm_top1_idx], + reshaped_total[torch.arange(B), ligand_iptm_top1_idx], + ) + self.protein_iptm_top1_lddt[key].update( + reshaped_lddt[torch.arange(B), protein_iptm_top1_idx], + reshaped_total[torch.arange(B), protein_iptm_top1_idx], + ) - batch_size = batch['atom_to_token'].shape[0] - _, _, peptide_calc_mask = self._build_alignment_masks( - batch=batch, - record_ids=record_ids, - base_structures=base_structures, - batch_idx=batch_idx, + self.avg_lddt[key].update(all_lddt_dict[key], all_total_dict[key]) + self.pde_mae[key].update(mae_pde_dict[key], 
total_mae_pde_dict[key]) + self.pae_mae[key].update(mae_pae_dict[key], total_mae_pae_dict[key]) + + for key in mae_plddt_dict: + if not _allowed_metric(key): + continue + self.plddt_mae[key].update( + mae_plddt_dict[key], total_mae_plddt_dict[key] + ) + + def _update_metric_collection(collection, values_dict, totals_dict): + for key in values_dict: + if not _allowed_metric(key): + continue + collection[key].update(values_dict[key], totals_dict[key]) + + _update_metric_collection(self.lddt, best_lddt_dict, best_total_dict) + _update_metric_collection(self.disto_lddt, disto_lddt_dict, disto_total_dict) + _update_metric_collection( + self.complex_lddt, best_complex_lddt_dict, best_complex_total_dict ) + self.rmsd.update(rmsds) + self.best_rmsd.update(best_rmsds) - sample_metrics = compute_weighted_mhc_rmsds( + sample_metrics = self._compute_sample_metrics( + batch=batch, out=out, true_coords=true_coords, - batch=batch, - peptide_mask=peptide_calc_mask, + heavy_calc_mask=heavy_calc_mask, + peptide_calc_mask=peptide_calc_mask, n_samples=n_samples, - nucleotide_weight=self.nucleotide_rmsd_weight, - ligand_weight=self.ligand_rmsd_weight, ) whole_values = [m.rmsd_whole for m in sample_metrics if not math.isnan(m.rmsd_whole)] - peptide_values = [m.rmsd_peptide for m in sample_metrics if not math.isnan(m.rmsd_peptide)] + masked_values = [m.rmsd_masked for m in sample_metrics if not math.isnan(m.rmsd_masked)] if whole_values: whole_mean = torch.tensor(whole_values, device=out["sample_atom_coords"].device, dtype=torch.float32).mean() self.log("val/weighted_rmsd_whole", whole_mean, prog_bar=False, sync_dist=True, batch_size=batch_size) - if peptide_values: - peptide_mean = torch.tensor(peptide_values, device=out["sample_atom_coords"].device, dtype=torch.float32).mean() - self.log("val/weighted_rmsd_peptide", peptide_mean, prog_bar=False, sync_dist=True, batch_size=batch_size) + if masked_values: + masked_mean = torch.tensor(masked_values, 
device=out["sample_atom_coords"].device, dtype=torch.float32).mean() + self.log("val/weighted_rmsd_masked", masked_mean, prog_bar=False, sync_dist=True, batch_size=batch_size) - output_dir = Path(getattr(self.validation_args, "val_cif_out_dir", "validation_outputs")) - write_validation_predictions( - out=out, + self._write_validation_cifs( batch=batch, + out=out, base_structures=base_structures, record_ids=record_ids, sample_metrics=sample_metrics, n_samples=n_samples, - output_dir=output_dir, + batch_idx=batch_idx, )