Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ boltz predict input_path --use_msa_server
### Binding Affinity Prediction
There are two main predictions in the affinity output: `affinity_pred_value` and `affinity_probability_binary`. They are trained on largely different datasets, with different supervisions, and should be used in different contexts. The `affinity_probability_binary` field should be used to distinguish binders from decoys, for example in a hit-discovery stage. Its value ranges from 0 to 1 and represents the predicted probability that the ligand is a binder. The `affinity_pred_value` aims to measure the specific affinity of different binders and how this changes with small modifications of the molecule. This should be used in ligand optimization stages such as hit-to-lead and lead-optimization. It reports a binding affinity value as `log(IC50)`, derived from an `IC50` measured in `μM`. More details on how to run affinity predictions and parse the output can be found in our [prediction instructions](docs/prediction.md).

## Authentication to MSA Server

When using the `--use_msa_server` option with a server that requires authentication, you can provide credentials in one of two ways: basic authentication (username and password) or an API key sent in a request header. Details and examples are available in our [prediction instructions](docs/prediction.md).

## Evaluation

⚠️ **Coming soon: updated evaluation code for Boltz-2!**
Expand Down
45 changes: 45 additions & 0 deletions docs/prediction.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,51 @@ Before diving into more details about the input formats, here are the key differ
| Pocket conditioning | :x: | :white_check_mark: |
| Affinity | :x: | :white_check_mark: |

## Authentication to MSA Server

When using the `--use_msa_server` option with a server that requires authentication, you can provide credentials in one of two ways:

### 1. Basic Authentication

- Use the CLI options `--msa_server_username` and `--msa_server_password`.
- Or, set the environment variables:
  - `BOLTZ_MSA_USERNAME` (for the username)
  - `BOLTZ_MSA_PASSWORD` (for the password, recommended for security)

**Example:**
```bash
export BOLTZ_MSA_USERNAME=myuser
export BOLTZ_MSA_PASSWORD=mypassword
boltz predict ... --use_msa_server
```
Or:
```bash
boltz predict ... --use_msa_server --msa_server_username myuser --msa_server_password mypassword
```

### 2. API Key Authentication

- Use the CLI options `--api_key_header` (default: `X-API-Key`) and `--api_key_value` to specify the header and value for API key authentication.
- Or, set the API key value via the environment variable `MSA_API_KEY_VALUE` (recommended for security).

**Example using CLI:**
```bash
boltz predict ... --use_msa_server --api_key_header X-API-Key --api_key_value <your-api-key>
```

**Example using environment variable:**
```bash
export MSA_API_KEY_VALUE=<your-api-key>
boltz predict ... --use_msa_server --api_key_header X-API-Key
```
If both the CLI option and environment variable are set, the CLI option takes precedence.

> If your server expects a different header, set `--api_key_header` accordingly (e.g., `--api_key_header X-Gravitee-Api-Key`).

---

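**Note:**
Only one authentication method (basic or API key) can be used at a time. If both a username/password pair and an API key value are provided, whether through CLI options or environment variables, the program will raise an error.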


## YAML format
Expand Down
45 changes: 42 additions & 3 deletions src/boltz/data/msa/mmseqs2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import random
import tarfile
import time
from typing import Union
from typing import Optional, Union, Dict

import requests
from requests.auth import HTTPBasicAuth
from tqdm import tqdm

logger = logging.getLogger(__name__)
Expand All @@ -25,13 +26,42 @@ def run_mmseqs2( # noqa: PLR0912, D103, C901, PLR0915
use_pairing: bool = False,
pairing_strategy: str = "greedy",
host_url: str = "https://api.colabfold.com",
msa_server_username: Optional[str] = None,
msa_server_password: Optional[str] = None,
auth_headers: Optional[Dict[str, str]] = None,
) -> tuple[list[str], list[str]]:
submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa"

# Validate mutually exclusive authentication methods
has_basic_auth = msa_server_username and msa_server_password
has_header_auth = auth_headers is not None
if has_basic_auth and (has_header_auth or auth_headers):
raise ValueError(
"Cannot use both basic authentication (username/password) and header/API key authentication. "
"Please use only one authentication method."
)

# Set header agent as boltz
headers = {}
headers["User-Agent"] = "boltz"

# Set up authentication
auth = None
if has_basic_auth:
auth = HTTPBasicAuth(msa_server_username, msa_server_password)
logger.debug(f"MMSeqs2 server authentication: using basic auth for user '{msa_server_username}'")
elif has_header_auth:
headers.update(auth_headers)
logger.debug("MMSeqs2 server authentication: using header-based authentication")
else:
logger.debug("MMSeqs2 server authentication: no credentials provided")

logger.debug(f"Connecting to MMSeqs2 server at: {host_url}")
logger.debug(f"Using endpoint: {submission_endpoint}")
logger.debug(f"Pairing strategy: {pairing_strategy}")
logger.debug(f"Use environment databases: {use_env}")
logger.debug(f"Use filtering: {use_filter}")

def submit(seqs, mode, N=101):
n, query = N, ""
for seq in seqs:
Expand All @@ -43,12 +73,15 @@ def submit(seqs, mode, N=101):
try:
# https://requests.readthedocs.io/en/latest/user/advanced/#advanced
# "good practice to set connect timeouts to slightly larger than a multiple of 3"
logger.debug(f"Submitting MSA request to {host_url}/{submission_endpoint}")
res = requests.post(
f"{host_url}/{submission_endpoint}",
data={"q": query, "mode": mode},
timeout=6.02,
headers=headers,
auth=auth,
)
logger.debug(f"MSA submission response status: {res.status_code}")
except Exception as e:
error_count += 1
logger.warning(
Expand All @@ -74,9 +107,11 @@ def status(ID):
error_count = 0
while True:
try:
logger.debug(f"Checking MSA job status for ID: {ID}")
res = requests.get(
f"{host_url}/ticket/{ID}", timeout=6.02, headers=headers
f"{host_url}/ticket/{ID}", timeout=6.02, headers=headers, auth=auth
)
logger.debug(f"MSA status check response status: {res.status_code}")
except Exception as e:
error_count += 1
logger.warning(
Expand All @@ -101,9 +136,11 @@ def download(ID, path):
error_count = 0
while True:
try:
logger.debug(f"Downloading MSA results for ID: {ID}")
res = requests.get(
f"{host_url}/result/download/{ID}", timeout=6.02, headers=headers
f"{host_url}/result/download/{ID}", timeout=6.02, headers=headers, auth=auth
)
logger.debug(f"MSA download response status: {res.status_code}")
except Exception as e:
error_count += 1
logger.warning(
Expand Down Expand Up @@ -186,6 +223,7 @@ def download(ID, path):

# wait for job to finish
ID, TIME = out["id"], 0
logger.debug(f"MSA job submitted successfully with ID: {ID}")
pbar.set_description(out["status"])
while out["status"] in ["UNKNOWN", "RUNNING", "PENDING"]:
t = 5 + random.randint(0, 5)
Expand All @@ -198,6 +236,7 @@ def download(ID, path):
pbar.update(n=t)

if out["status"] == "COMPLETE":
logger.debug(f"MSA job completed successfully for ID: {ID}")
if TIME < TIME_ESTIMATE:
pbar.update(n=(TIME_ESTIMATE - TIME))
REDO = False
Expand Down
120 changes: 120 additions & 0 deletions src/boltz/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,10 @@ def compute_msa(
msa_dir: Path,
msa_server_url: str,
msa_pairing_strategy: str,
msa_server_username: Optional[str] = None,
msa_server_password: Optional[str] = None,
api_key_header: Optional[str] = None,
api_key_value: Optional[str] = None,
) -> None:
"""Compute the MSA for the input data.

Expand All @@ -432,8 +436,35 @@ def compute_msa(
The MSA server URL.
msa_pairing_strategy : str
The MSA pairing strategy.
msa_server_username : str, optional
Username for basic authentication with MSA server.
msa_server_password : str, optional
Password for basic authentication with MSA server.
api_key_header : str, optional
Custom header key for API key authentication (default: X-API-Key).
api_key_value : str, optional
Custom header value for API key authentication (overrides --api_key if set).

"""
click.echo(f"Calling MSA server for target {target_id} with {len(data)} sequences")
click.echo(f"MSA server URL: {msa_server_url}")
click.echo(f"MSA pairing strategy: {msa_pairing_strategy}")

# Construct auth headers if API key header/value is provided
auth_headers = None
if api_key_value:
key = api_key_header if api_key_header else "X-API-Key"
value = api_key_value
auth_headers = {
"Content-Type": "application/json",
key: value
}
click.echo(f"Using API key authentication for MSA server (header: {key})")
elif msa_server_username and msa_server_password:
click.echo("Using basic authentication for MSA server")
else:
click.echo("No authentication provided for MSA server")

if len(data) > 1:
paired_msas = run_mmseqs2(
list(data.values()),
Expand All @@ -442,6 +473,9 @@ def compute_msa(
use_pairing=True,
host_url=msa_server_url,
pairing_strategy=msa_pairing_strategy,
msa_server_username=msa_server_username,
msa_server_password=msa_server_password,
auth_headers=auth_headers,
)
else:
paired_msas = [""] * len(data)
Expand All @@ -453,6 +487,9 @@ def compute_msa(
use_pairing=False,
host_url=msa_server_url,
pairing_strategy=msa_pairing_strategy,
msa_server_username=msa_server_username,
msa_server_password=msa_server_password,
auth_headers=auth_headers,
)

for idx, name in enumerate(data):
Expand Down Expand Up @@ -493,6 +530,10 @@ def process_input( # noqa: C901, PLR0912, PLR0915, D103
use_msa_server: bool,
msa_server_url: str,
msa_pairing_strategy: str,
msa_server_username: Optional[str],
msa_server_password: Optional[str],
api_key_header: Optional[str],
api_key_value: Optional[str],
max_msa_seqs: int,
processed_msa_dir: Path,
processed_constraints_dir: Path,
Expand Down Expand Up @@ -549,6 +590,10 @@ def process_input( # noqa: C901, PLR0912, PLR0915, D103
msa_dir=msa_dir,
msa_server_url=msa_server_url,
msa_pairing_strategy=msa_pairing_strategy,
msa_server_username=msa_server_username,
msa_server_password=msa_server_password,
api_key_header=api_key_header,
api_key_value=api_key_value,
)

# Parse MSA data
Expand Down Expand Up @@ -625,6 +670,10 @@ def process_inputs(
msa_pairing_strategy: str,
max_msa_seqs: int = 8192,
use_msa_server: bool = False,
msa_server_username: Optional[str] = None,
msa_server_password: Optional[str] = None,
api_key_header: Optional[str] = None,
api_key_value: Optional[str] = None,
boltz2: bool = False,
preprocessing_threads: int = 1,
) -> Manifest:
Expand All @@ -642,6 +691,14 @@ def process_inputs(
Max number of MSA sequences, by default 4096.
use_msa_server : bool, optional
Whether to use the MMSeqs2 server for MSA generation, by default False.
msa_server_username : str, optional
Username for basic authentication with MSA server, by default None.
msa_server_password : str, optional
Password for basic authentication with MSA server, by default None.
api_key_header : str, optional
Custom header key for API key authentication (default: X-API-Key).
api_key_value : str, optional
Custom header value for API key authentication (overrides --api_key if set).
boltz2: bool, optional
Whether to use Boltz2, by default False.
preprocessing_threads: int, optional
Expand All @@ -653,6 +710,16 @@ def process_inputs(
The manifest of the processed input data.

"""
# Validate mutually exclusive authentication methods
has_basic_auth = msa_server_username and msa_server_password
has_api_key = api_key_value is not None

if has_basic_auth and has_api_key:
raise ValueError(
"Cannot use both basic authentication (--msa_server_username/--msa_server_password) "
"and API key authentication (--api_key_header/--api_key_value). Please use only one authentication method."
)

# Check if records exist at output path
records_dir = out_dir / "processed" / "records"
if records_dir.exists():
Expand Down Expand Up @@ -710,6 +777,10 @@ def process_inputs(
use_msa_server=use_msa_server,
msa_server_url=msa_server_url,
msa_pairing_strategy=msa_pairing_strategy,
msa_server_username=msa_server_username,
msa_server_password=msa_server_password,
api_key_header=api_key_header,
api_key_value=api_key_value,
max_msa_seqs=max_msa_seqs,
processed_msa_dir=processed_msa_dir,
processed_constraints_dir=processed_constraints_dir,
Expand Down Expand Up @@ -869,6 +940,30 @@ def cli() -> None:
),
default="greedy",
)
@click.option(
"--msa_server_username",
type=str,
help="MSA server username for basic auth. Used only if --use_msa_server is set. Can also be set via BOLTZ_MSA_USERNAME environment variable.",
default=None,
)
@click.option(
"--msa_server_password",
type=str,
help="MSA server password for basic auth. Used only if --use_msa_server is set. Can also be set via BOLTZ_MSA_PASSWORD environment variable.",
default=None,
)
@click.option(
"--api_key_header",
type=str,
help="Custom header key for API key authentication (default: X-API-Key).",
default=None,
)
@click.option(
"--api_key_value",
type=str,
help="Custom header value for API key authentication.",
default=None,
)
@click.option(
"--use_potentials",
is_flag=True,
Expand Down Expand Up @@ -967,6 +1062,10 @@ def predict( # noqa: C901, PLR0915, PLR0912
use_msa_server: bool = False,
msa_server_url: str = "https://api.colabfold.com",
msa_pairing_strategy: str = "greedy",
msa_server_username: Optional[str] = None,
msa_server_password: Optional[str] = None,
api_key_header: Optional[str] = None,
api_key_value: Optional[str] = None,
use_potentials: bool = False,
model: Literal["boltz1", "boltz2"] = "boltz2",
method: Optional[str] = None,
Expand Down Expand Up @@ -1011,6 +1110,23 @@ def predict( # noqa: C901, PLR0915, PLR0912
cache = Path(cache).expanduser()
cache.mkdir(parents=True, exist_ok=True)

# Get MSA server credentials from environment variables if not provided
if use_msa_server:
if msa_server_username is None:
msa_server_username = os.environ.get("BOLTZ_MSA_USERNAME")
if msa_server_password is None:
msa_server_password = os.environ.get("BOLTZ_MSA_PASSWORD")
if api_key_value is None:
api_key_value = os.environ.get("MSA_API_KEY_VALUE")

click.echo(f"MSA server enabled: {msa_server_url}")
if api_key_value:
click.echo("MSA server authentication: using API key header")
elif msa_server_username and msa_server_password:
click.echo("MSA server authentication: using basic auth")
else:
click.echo("MSA server authentication: no credentials provided")

# Create output directories
data = Path(data).expanduser()
out_dir = Path(out_dir).expanduser()
Expand Down Expand Up @@ -1050,6 +1166,10 @@ def predict( # noqa: C901, PLR0915, PLR0912
use_msa_server=use_msa_server,
msa_server_url=msa_server_url,
msa_pairing_strategy=msa_pairing_strategy,
msa_server_username=msa_server_username,
msa_server_password=msa_server_password,
api_key_header=api_key_header,
api_key_value=api_key_value,
boltz2=model == "boltz2",
preprocessing_threads=preprocessing_threads,
max_msa_seqs=max_msa_seqs,
Expand Down