diff --git a/README.md b/README.md index 95b1a35d1..83b242d0f 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,10 @@ boltz predict input_path --use_msa_server ### Binding Affinity Prediction There are two main predictions in the affinity output: `affinity_pred_value` and `affinity_probability_binary`. They are trained on largely different datasets, with different supervisions, and should be used in different contexts. The `affinity_probability_binary` field should be used to detect binders from decoys, for example in a hit-discovery stage. It's value ranges from 0 to 1 and represents the predicted probability that the ligand is a binder. The `affinity_pred_value` aims to measure the specific affinity of different binders and how this changes with small modifications of the molecule. This should be used in ligand optimization stages such as hit-to-lead and lead-optimization. It reports a binding affinity value as `log(IC50)`, derived from an `IC50` measured in `μM`. More details on how to run affinity predictions and parse the output can be found in our [prediction instructions](docs/prediction.md). +## Authentication to MSA Server + +When using the `--use_msa_server` option with a server that requires authentication, you can provide credentials in one of two ways. More information is available in our [prediction instructions](docs/prediction.md). + ## Evaluation ⚠️ **Coming soon: updated evaluation code for Boltz-2!** diff --git a/docs/prediction.md b/docs/prediction.md index 15d31829c..2024d63b3 100644 --- a/docs/prediction.md +++ b/docs/prediction.md @@ -21,6 +21,51 @@ Before diving into more details about the input formats, here are the key differ | Pocket conditioning | :x: | :white_check_mark: | | Affinity | :x: | :white_check_mark: | +## Authentication to MSA Server + +When using the `--use_msa_server` option with a server that requires authentication, you can provide credentials in one of two ways: + +### 1. Basic Authentication + +- Use the CLI options `--msa_server_username` and `--msa_server_password`. +- Or, set the environment variables: + - `BOLTZ_MSA_USERNAME` (for the username) + - `BOLTZ_MSA_PASSWORD` (for the password, recommended for security) + +**Example:** +```bash +export BOLTZ_MSA_USERNAME=myuser +export BOLTZ_MSA_PASSWORD=mypassword +boltz predict ... --use_msa_server +``` +Or: +```bash +boltz predict ... --use_msa_server --msa_server_username myuser --msa_server_password mypassword +``` + +### 2. API Key Authentication + +- Use the CLI options `--api_key_header` (default: `X-API-Key`) and `--api_key_value` to specify the header and value for API key authentication. +- Or, set the API key value via the environment variable `MSA_API_KEY_VALUE` (recommended for security). + +**Example using CLI:** +```bash +boltz predict ... --use_msa_server --api_key_header X-API-Key --api_key_value +``` + +**Example using environment variable:** +```bash +export MSA_API_KEY_VALUE= +boltz predict ... --use_msa_server --api_key_header X-API-Key +``` +If both the CLI option and environment variable are set, the CLI option takes precedence. + +> If your server expects a different header, set `--api_key_header` accordingly (e.g., `--api_key_header X-Gravitee-Api-Key`). + +--- + +**Note:** +Only one authentication method (basic or API key) can be used at a time. If both are provided, the program will raise an error. ## YAML format diff --git a/src/boltz/data/msa/mmseqs2.py b/src/boltz/data/msa/mmseqs2.py index f92f02f37..f2aab6e8e 100644 --- a/src/boltz/data/msa/mmseqs2.py +++ b/src/boltz/data/msa/mmseqs2.py @@ -5,9 +5,10 @@ import random import tarfile import time -from typing import Union +from typing import Optional, Union, Dict import requests +from requests.auth import HTTPBasicAuth from tqdm import tqdm logger = logging.getLogger(__name__) @@ -25,13 +26,42 @@ def run_mmseqs2( # noqa: PLR0912, D103, C901, PLR0915 use_pairing: bool = False, pairing_strategy: str = "greedy", host_url: str = "https://api.colabfold.com", + msa_server_username: Optional[str] = None, + msa_server_password: Optional[str] = None, + auth_headers: Optional[Dict[str, str]] = None, ) -> tuple[list[str], list[str]]: submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa" + # Validate mutually exclusive authentication methods + has_basic_auth = msa_server_username and msa_server_password + has_header_auth = auth_headers is not None + if has_basic_auth and (has_header_auth or auth_headers): + raise ValueError( + "Cannot use both basic authentication (username/password) and header/API key authentication. " + "Please use only one authentication method." + ) + # Set header agent as boltz headers = {} headers["User-Agent"] = "boltz" + # Set up authentication + auth = None + if has_basic_auth: + auth = HTTPBasicAuth(msa_server_username, msa_server_password) + logger.debug(f"MMSeqs2 server authentication: using basic auth for user '{msa_server_username}'") + elif has_header_auth: + headers.update(auth_headers) + logger.debug("MMSeqs2 server authentication: using header-based authentication") + else: + logger.debug("MMSeqs2 server authentication: no credentials provided") + + logger.debug(f"Connecting to MMSeqs2 server at: {host_url}") + logger.debug(f"Using endpoint: {submission_endpoint}") + logger.debug(f"Pairing strategy: {pairing_strategy}") + logger.debug(f"Use environment databases: {use_env}") + logger.debug(f"Use filtering: {use_filter}") + def submit(seqs, mode, N=101): n, query = N, "" for seq in seqs: @@ -43,12 +73,15 @@ def submit(seqs, mode, N=101): try: # https://requests.readthedocs.io/en/latest/user/advanced/#advanced # "good practice to set connect timeouts to slightly larger than a multiple of 3" + logger.debug(f"Submitting MSA request to {host_url}/{submission_endpoint}") res = requests.post( f"{host_url}/{submission_endpoint}", data={"q": query, "mode": mode}, timeout=6.02, headers=headers, + auth=auth, ) + logger.debug(f"MSA submission response status: {res.status_code}") except Exception as e: error_count += 1 logger.warning( @@ -74,9 +107,11 @@ def status(ID): error_count = 0 while True: try: + logger.debug(f"Checking MSA job status for ID: {ID}") res = requests.get( - f"{host_url}/ticket/{ID}", timeout=6.02, headers=headers + f"{host_url}/ticket/{ID}", timeout=6.02, headers=headers, auth=auth ) + logger.debug(f"MSA status check response status: {res.status_code}") except Exception as e: error_count += 1 logger.warning( @@ -101,9 +136,11 @@ def download(ID, path): error_count = 0 while True: try: + logger.debug(f"Downloading MSA results for ID: {ID}") res = requests.get( - f"{host_url}/result/download/{ID}", timeout=6.02, headers=headers + f"{host_url}/result/download/{ID}", timeout=6.02, headers=headers, auth=auth ) + logger.debug(f"MSA download response status: {res.status_code}") except Exception as e: error_count += 1 logger.warning( @@ -186,6 +223,7 @@ def download(ID, path): # wait for job to finish ID, TIME = out["id"], 0 + logger.debug(f"MSA job submitted successfully with ID: {ID}") pbar.set_description(out["status"]) while out["status"] in ["UNKNOWN", "RUNNING", "PENDING"]: t = 5 + random.randint(0, 5) @@ -198,6 +236,7 @@ def download(ID, path): pbar.update(n=t) if out["status"] == "COMPLETE": + logger.debug(f"MSA job completed successfully for ID: {ID}") if TIME < TIME_ESTIMATE: pbar.update(n=(TIME_ESTIMATE - TIME)) REDO = False diff --git a/src/boltz/main.py b/src/boltz/main.py index 6893bf595..123050914 100644 --- a/src/boltz/main.py +++ b/src/boltz/main.py @@ -417,6 +417,10 @@ def compute_msa( msa_dir: Path, msa_server_url: str, msa_pairing_strategy: str, + msa_server_username: Optional[str] = None, + msa_server_password: Optional[str] = None, + api_key_header: Optional[str] = None, + api_key_value: Optional[str] = None, ) -> None: """Compute the MSA for the input data. @@ -432,8 +436,35 @@ def compute_msa( The MSA server URL. msa_pairing_strategy : str The MSA pairing strategy. + msa_server_username : str, optional + Username for basic authentication with MSA server. + msa_server_password : str, optional + Password for basic authentication with MSA server. + api_key_header : str, optional + Custom header key for API key authentication (default: X-API-Key). + api_key_value : str, optional + Custom header value for API key authentication (overrides --api_key if set). """ + click.echo(f"Calling MSA server for target {target_id} with {len(data)} sequences") + click.echo(f"MSA server URL: {msa_server_url}") + click.echo(f"MSA pairing strategy: {msa_pairing_strategy}") + + # Construct auth headers if API key header/value is provided + auth_headers = None + if api_key_value: + key = api_key_header if api_key_header else "X-API-Key" + value = api_key_value + auth_headers = { + "Content-Type": "application/json", + key: value + } + click.echo(f"Using API key authentication for MSA server (header: {key})") + elif msa_server_username and msa_server_password: + click.echo("Using basic authentication for MSA server") + else: + click.echo("No authentication provided for MSA server") + if len(data) > 1: paired_msas = run_mmseqs2( list(data.values()), @@ -442,6 +473,9 @@ def compute_msa( use_pairing=True, host_url=msa_server_url, pairing_strategy=msa_pairing_strategy, + msa_server_username=msa_server_username, + msa_server_password=msa_server_password, + auth_headers=auth_headers, ) else: paired_msas = [""] * len(data) @@ -453,6 +487,9 @@ def compute_msa( use_pairing=False, host_url=msa_server_url, pairing_strategy=msa_pairing_strategy, + msa_server_username=msa_server_username, + msa_server_password=msa_server_password, + auth_headers=auth_headers, ) for idx, name in enumerate(data): @@ -493,6 +530,10 @@ def process_input( # noqa: C901, PLR0912, PLR0915, D103 use_msa_server: bool, msa_server_url: str, msa_pairing_strategy: str, + msa_server_username: Optional[str], + msa_server_password: Optional[str], + api_key_header: Optional[str], + api_key_value: Optional[str], max_msa_seqs: int, processed_msa_dir: Path, processed_constraints_dir: Path, @@ -549,6 +590,10 @@ def process_input( # noqa: C901, PLR0912, PLR0915, D103 msa_dir=msa_dir, msa_server_url=msa_server_url, msa_pairing_strategy=msa_pairing_strategy, + msa_server_username=msa_server_username, + msa_server_password=msa_server_password, + api_key_header=api_key_header, + api_key_value=api_key_value, ) # Parse MSA data @@ -625,6 +670,10 @@ def process_inputs( msa_pairing_strategy: str, max_msa_seqs: int = 8192, use_msa_server: bool = False, + msa_server_username: Optional[str] = None, + msa_server_password: Optional[str] = None, + api_key_header: Optional[str] = None, + api_key_value: Optional[str] = None, boltz2: bool = False, preprocessing_threads: int = 1, ) -> Manifest: @@ -642,6 +691,14 @@ def process_inputs( Max number of MSA sequences, by default 4096. use_msa_server : bool, optional Whether to use the MMSeqs2 server for MSA generation, by default False. + msa_server_username : str, optional + Username for basic authentication with MSA server, by default None. + msa_server_password : str, optional + Password for basic authentication with MSA server, by default None. + api_key_header : str, optional + Custom header key for API key authentication (default: X-API-Key). + api_key_value : str, optional + Custom header value for API key authentication (overrides --api_key if set). boltz2: bool, optional Whether to use Boltz2, by default False. preprocessing_threads: int, optional @@ -653,6 +710,16 @@ def process_inputs( The manifest of the processed input data. """ + # Validate mutually exclusive authentication methods + has_basic_auth = msa_server_username and msa_server_password + has_api_key = api_key_value is not None + + if has_basic_auth and has_api_key: + raise ValueError( + "Cannot use both basic authentication (--msa_server_username/--msa_server_password) " + "and API key authentication (--api_key_header/--api_key_value). Please use only one authentication method." + ) + # Check if records exist at output path records_dir = out_dir / "processed" / "records" if records_dir.exists(): @@ -710,6 +777,10 @@ def process_inputs( use_msa_server=use_msa_server, msa_server_url=msa_server_url, msa_pairing_strategy=msa_pairing_strategy, + msa_server_username=msa_server_username, + msa_server_password=msa_server_password, + api_key_header=api_key_header, + api_key_value=api_key_value, max_msa_seqs=max_msa_seqs, processed_msa_dir=processed_msa_dir, processed_constraints_dir=processed_constraints_dir, @@ -869,6 +940,30 @@ def cli() -> None: ), default="greedy", ) +@click.option( + "--msa_server_username", + type=str, + help="MSA server username for basic auth. Used only if --use_msa_server is set. Can also be set via BOLTZ_MSA_USERNAME environment variable.", + default=None, +) +@click.option( + "--msa_server_password", + type=str, + help="MSA server password for basic auth. Used only if --use_msa_server is set. Can also be set via BOLTZ_MSA_PASSWORD environment variable.", + default=None, +) +@click.option( + "--api_key_header", + type=str, + help="Custom header key for API key authentication (default: X-API-Key).", + default=None, +) +@click.option( + "--api_key_value", + type=str, + help="Custom header value for API key authentication.", + default=None, +) @click.option( "--use_potentials", is_flag=True, @@ -967,6 +1062,10 @@ def predict( # noqa: C901, PLR0915, PLR0912 use_msa_server: bool = False, msa_server_url: str = "https://api.colabfold.com", msa_pairing_strategy: str = "greedy", + msa_server_username: Optional[str] = None, + msa_server_password: Optional[str] = None, + api_key_header: Optional[str] = None, + api_key_value: Optional[str] = None, use_potentials: bool = False, model: Literal["boltz1", "boltz2"] = "boltz2", method: Optional[str] = None, @@ -1011,6 +1110,23 @@ def predict( # noqa: C901, PLR0915, PLR0912 cache = Path(cache).expanduser() cache.mkdir(parents=True, exist_ok=True) + # Get MSA server credentials from environment variables if not provided + if use_msa_server: + if msa_server_username is None: + msa_server_username = os.environ.get("BOLTZ_MSA_USERNAME") + if msa_server_password is None: + msa_server_password = os.environ.get("BOLTZ_MSA_PASSWORD") + if api_key_value is None: + api_key_value = os.environ.get("MSA_API_KEY_VALUE") + + click.echo(f"MSA server enabled: {msa_server_url}") + if api_key_value: + click.echo("MSA server authentication: using API key header") + elif msa_server_username and msa_server_password: + click.echo("MSA server authentication: using basic auth") + else: + click.echo("MSA server authentication: no credentials provided") + # Create output directories data = Path(data).expanduser() out_dir = Path(out_dir).expanduser() @@ -1050,6 +1166,10 @@ def predict( # noqa: C901, PLR0915, PLR0912 use_msa_server=use_msa_server, msa_server_url=msa_server_url, msa_pairing_strategy=msa_pairing_strategy, + msa_server_username=msa_server_username, + msa_server_password=msa_server_password, + api_key_header=api_key_header, + api_key_value=api_key_value, boltz2=model == "boltz2", preprocessing_threads=preprocessing_threads, max_msa_seqs=max_msa_seqs,