Skip to content

Commit 93cdd93

Browse files
authored
Merge pull request jwohlwend#466 from papagala/main
Adding MSA server security
2 parents 8bef88b + a5043ad commit 93cdd93

4 files changed

Lines changed: 211 additions & 3 deletions

File tree

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ boltz predict input_path --use_msa_server
4949
### Binding Affinity Prediction
5050
There are two main predictions in the affinity output: `affinity_pred_value` and `affinity_probability_binary`. They are trained on largely different datasets, with different supervision, and should be used in different contexts. The `affinity_probability_binary` field should be used to distinguish binders from decoys, for example in a hit-discovery stage. Its value ranges from 0 to 1 and represents the predicted probability that the ligand is a binder. The `affinity_pred_value` field aims to measure the specific affinity of different binders and how this changes with small modifications of the molecule. This should be used in ligand optimization stages such as hit-to-lead and lead optimization. It reports a binding affinity value as `log(IC50)`, derived from an `IC50` measured in `μM`. More details on how to run affinity predictions and parse the output can be found in our [prediction instructions](docs/prediction.md).
5151

52+
## Authentication to MSA Server
53+
54+
When using the `--use_msa_server` option with a server that requires authentication, you can provide credentials in one of two ways. More information is available in our [prediction instructions](docs/prediction.md).
55+
5256
## Evaluation
5357

5458
⚠️ **Coming soon: updated evaluation code for Boltz-2!**

docs/prediction.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,51 @@ Before diving into more details about the input formats, here are the key differ
2121
| Pocket conditioning | :x: | :white_check_mark: |
2222
| Affinity | :x: | :white_check_mark: |
2323

24+
## Authentication to MSA Server
25+
26+
When using the `--use_msa_server` option with a server that requires authentication, you can provide credentials in one of two ways:
27+
28+
### 1. Basic Authentication
29+
30+
- Use the CLI options `--msa_server_username` and `--msa_server_password`.
31+
- Or, set the environment variables:
32+
- `BOLTZ_MSA_USERNAME` (for the username)
33+
- `BOLTZ_MSA_PASSWORD` (for the password, recommended for security)
34+
35+
**Example:**
36+
```bash
37+
export BOLTZ_MSA_USERNAME=myuser
38+
export BOLTZ_MSA_PASSWORD=mypassword
39+
boltz predict ... --use_msa_server
40+
```
41+
Or:
42+
```bash
43+
boltz predict ... --use_msa_server --msa_server_username myuser --msa_server_password mypassword
44+
```
45+
46+
### 2. API Key Authentication
47+
48+
- Use the CLI options `--api_key_header` (default: `X-API-Key`) and `--api_key_value` to specify the header and value for API key authentication.
49+
- Or, set the API key value via the environment variable `MSA_API_KEY_VALUE` (recommended for security).
50+
51+
**Example using CLI:**
52+
```bash
53+
boltz predict ... --use_msa_server --api_key_header X-API-Key --api_key_value <your-api-key>
54+
```
55+
56+
**Example using environment variable:**
57+
```bash
58+
export MSA_API_KEY_VALUE=<your-api-key>
59+
boltz predict ... --use_msa_server --api_key_header X-API-Key
60+
```
61+
If both the CLI option and environment variable are set, the CLI option takes precedence.
62+
63+
> If your server expects a different header, set `--api_key_header` accordingly (e.g., `--api_key_header X-Gravitee-Api-Key`).
64+
65+
---
66+
67+
**Note:**
68+
Only one authentication method (basic or API key) can be used at a time. If both are provided, the program will raise an error.
2469

2570

2671
## YAML format

src/boltz/data/msa/mmseqs2.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
import random
66
import tarfile
77
import time
8-
from typing import Union
8+
from typing import Optional, Union, Dict
99

1010
import requests
11+
from requests.auth import HTTPBasicAuth
1112
from tqdm import tqdm
1213

1314
logger = logging.getLogger(__name__)
@@ -25,13 +26,42 @@ def run_mmseqs2( # noqa: PLR0912, D103, C901, PLR0915
2526
use_pairing: bool = False,
2627
pairing_strategy: str = "greedy",
2728
host_url: str = "https://api.colabfold.com",
29+
msa_server_username: Optional[str] = None,
30+
msa_server_password: Optional[str] = None,
31+
auth_headers: Optional[Dict[str, str]] = None,
2832
) -> tuple[list[str], list[str]]:
2933
submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa"
3034

35+
# Validate mutually exclusive authentication methods
36+
has_basic_auth = msa_server_username and msa_server_password
37+
has_header_auth = auth_headers is not None
38+
if has_basic_auth and (has_header_auth or auth_headers):
39+
raise ValueError(
40+
"Cannot use both basic authentication (username/password) and header/API key authentication. "
41+
"Please use only one authentication method."
42+
)
43+
3144
# Set header agent as boltz
3245
headers = {}
3346
headers["User-Agent"] = "boltz"
3447

48+
# Set up authentication
49+
auth = None
50+
if has_basic_auth:
51+
auth = HTTPBasicAuth(msa_server_username, msa_server_password)
52+
logger.debug(f"MMSeqs2 server authentication: using basic auth for user '{msa_server_username}'")
53+
elif has_header_auth:
54+
headers.update(auth_headers)
55+
logger.debug("MMSeqs2 server authentication: using header-based authentication")
56+
else:
57+
logger.debug("MMSeqs2 server authentication: no credentials provided")
58+
59+
logger.debug(f"Connecting to MMSeqs2 server at: {host_url}")
60+
logger.debug(f"Using endpoint: {submission_endpoint}")
61+
logger.debug(f"Pairing strategy: {pairing_strategy}")
62+
logger.debug(f"Use environment databases: {use_env}")
63+
logger.debug(f"Use filtering: {use_filter}")
64+
3565
def submit(seqs, mode, N=101):
3666
n, query = N, ""
3767
for seq in seqs:
@@ -43,12 +73,15 @@ def submit(seqs, mode, N=101):
4373
try:
4474
# https://requests.readthedocs.io/en/latest/user/advanced/#advanced
4575
# "good practice to set connect timeouts to slightly larger than a multiple of 3"
76+
logger.debug(f"Submitting MSA request to {host_url}/{submission_endpoint}")
4677
res = requests.post(
4778
f"{host_url}/{submission_endpoint}",
4879
data={"q": query, "mode": mode},
4980
timeout=6.02,
5081
headers=headers,
82+
auth=auth,
5183
)
84+
logger.debug(f"MSA submission response status: {res.status_code}")
5285
except Exception as e:
5386
error_count += 1
5487
logger.warning(
@@ -74,9 +107,11 @@ def status(ID):
74107
error_count = 0
75108
while True:
76109
try:
110+
logger.debug(f"Checking MSA job status for ID: {ID}")
77111
res = requests.get(
78-
f"{host_url}/ticket/{ID}", timeout=6.02, headers=headers
112+
f"{host_url}/ticket/{ID}", timeout=6.02, headers=headers, auth=auth
79113
)
114+
logger.debug(f"MSA status check response status: {res.status_code}")
80115
except Exception as e:
81116
error_count += 1
82117
logger.warning(
@@ -101,9 +136,11 @@ def download(ID, path):
101136
error_count = 0
102137
while True:
103138
try:
139+
logger.debug(f"Downloading MSA results for ID: {ID}")
104140
res = requests.get(
105-
f"{host_url}/result/download/{ID}", timeout=6.02, headers=headers
141+
f"{host_url}/result/download/{ID}", timeout=6.02, headers=headers, auth=auth
106142
)
143+
logger.debug(f"MSA download response status: {res.status_code}")
107144
except Exception as e:
108145
error_count += 1
109146
logger.warning(
@@ -186,6 +223,7 @@ def download(ID, path):
186223

187224
# wait for job to finish
188225
ID, TIME = out["id"], 0
226+
logger.debug(f"MSA job submitted successfully with ID: {ID}")
189227
pbar.set_description(out["status"])
190228
while out["status"] in ["UNKNOWN", "RUNNING", "PENDING"]:
191229
t = 5 + random.randint(0, 5)
@@ -198,6 +236,7 @@ def download(ID, path):
198236
pbar.update(n=t)
199237

200238
if out["status"] == "COMPLETE":
239+
logger.debug(f"MSA job completed successfully for ID: {ID}")
201240
if TIME < TIME_ESTIMATE:
202241
pbar.update(n=(TIME_ESTIMATE - TIME))
203242
REDO = False

src/boltz/main.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,10 @@ def compute_msa(
417417
msa_dir: Path,
418418
msa_server_url: str,
419419
msa_pairing_strategy: str,
420+
msa_server_username: Optional[str] = None,
421+
msa_server_password: Optional[str] = None,
422+
api_key_header: Optional[str] = None,
423+
api_key_value: Optional[str] = None,
420424
) -> None:
421425
"""Compute the MSA for the input data.
422426
@@ -432,8 +436,35 @@ def compute_msa(
432436
The MSA server URL.
433437
msa_pairing_strategy : str
434438
The MSA pairing strategy.
439+
msa_server_username : str, optional
440+
Username for basic authentication with MSA server.
441+
msa_server_password : str, optional
442+
Password for basic authentication with MSA server.
443+
api_key_header : str, optional
444+
Custom header key for API key authentication (default: X-API-Key).
445+
api_key_value : str, optional
446+
Custom header value (the API key itself) for API key authentication. When supplied via --api_key_value, it takes precedence over the MSA_API_KEY_VALUE environment variable.
435447
436448
"""
449+
click.echo(f"Calling MSA server for target {target_id} with {len(data)} sequences")
450+
click.echo(f"MSA server URL: {msa_server_url}")
451+
click.echo(f"MSA pairing strategy: {msa_pairing_strategy}")
452+
453+
# Construct auth headers if API key header/value is provided
454+
auth_headers = None
455+
if api_key_value:
456+
key = api_key_header if api_key_header else "X-API-Key"
457+
value = api_key_value
458+
auth_headers = {
459+
"Content-Type": "application/json",
460+
key: value
461+
}
462+
click.echo(f"Using API key authentication for MSA server (header: {key})")
463+
elif msa_server_username and msa_server_password:
464+
click.echo("Using basic authentication for MSA server")
465+
else:
466+
click.echo("No authentication provided for MSA server")
467+
437468
if len(data) > 1:
438469
paired_msas = run_mmseqs2(
439470
list(data.values()),
@@ -442,6 +473,9 @@ def compute_msa(
442473
use_pairing=True,
443474
host_url=msa_server_url,
444475
pairing_strategy=msa_pairing_strategy,
476+
msa_server_username=msa_server_username,
477+
msa_server_password=msa_server_password,
478+
auth_headers=auth_headers,
445479
)
446480
else:
447481
paired_msas = [""] * len(data)
@@ -453,6 +487,9 @@ def compute_msa(
453487
use_pairing=False,
454488
host_url=msa_server_url,
455489
pairing_strategy=msa_pairing_strategy,
490+
msa_server_username=msa_server_username,
491+
msa_server_password=msa_server_password,
492+
auth_headers=auth_headers,
456493
)
457494

458495
for idx, name in enumerate(data):
@@ -493,6 +530,10 @@ def process_input( # noqa: C901, PLR0912, PLR0915, D103
493530
use_msa_server: bool,
494531
msa_server_url: str,
495532
msa_pairing_strategy: str,
533+
msa_server_username: Optional[str],
534+
msa_server_password: Optional[str],
535+
api_key_header: Optional[str],
536+
api_key_value: Optional[str],
496537
max_msa_seqs: int,
497538
processed_msa_dir: Path,
498539
processed_constraints_dir: Path,
@@ -549,6 +590,10 @@ def process_input( # noqa: C901, PLR0912, PLR0915, D103
549590
msa_dir=msa_dir,
550591
msa_server_url=msa_server_url,
551592
msa_pairing_strategy=msa_pairing_strategy,
593+
msa_server_username=msa_server_username,
594+
msa_server_password=msa_server_password,
595+
api_key_header=api_key_header,
596+
api_key_value=api_key_value,
552597
)
553598

554599
# Parse MSA data
@@ -625,6 +670,10 @@ def process_inputs(
625670
msa_pairing_strategy: str,
626671
max_msa_seqs: int = 8192,
627672
use_msa_server: bool = False,
673+
msa_server_username: Optional[str] = None,
674+
msa_server_password: Optional[str] = None,
675+
api_key_header: Optional[str] = None,
676+
api_key_value: Optional[str] = None,
628677
boltz2: bool = False,
629678
preprocessing_threads: int = 1,
630679
) -> Manifest:
@@ -642,6 +691,14 @@ def process_inputs(
642691
Max number of MSA sequences, by default 8192.
643692
use_msa_server : bool, optional
644693
Whether to use the MMSeqs2 server for MSA generation, by default False.
694+
msa_server_username : str, optional
695+
Username for basic authentication with MSA server, by default None.
696+
msa_server_password : str, optional
697+
Password for basic authentication with MSA server, by default None.
698+
api_key_header : str, optional
699+
Custom header key for API key authentication (default: X-API-Key).
700+
api_key_value : str, optional
701+
Custom header value (the API key itself) for API key authentication. When supplied via --api_key_value, it takes precedence over the MSA_API_KEY_VALUE environment variable.
645702
boltz2: bool, optional
646703
Whether to use Boltz2, by default False.
647704
preprocessing_threads: int, optional
@@ -653,6 +710,16 @@ def process_inputs(
653710
The manifest of the processed input data.
654711
655712
"""
713+
# Validate mutually exclusive authentication methods
714+
has_basic_auth = msa_server_username and msa_server_password
715+
has_api_key = api_key_value is not None
716+
717+
if has_basic_auth and has_api_key:
718+
raise ValueError(
719+
"Cannot use both basic authentication (--msa_server_username/--msa_server_password) "
720+
"and API key authentication (--api_key_header/--api_key_value). Please use only one authentication method."
721+
)
722+
656723
# Check if records exist at output path
657724
records_dir = out_dir / "processed" / "records"
658725
if records_dir.exists():
@@ -710,6 +777,10 @@ def process_inputs(
710777
use_msa_server=use_msa_server,
711778
msa_server_url=msa_server_url,
712779
msa_pairing_strategy=msa_pairing_strategy,
780+
msa_server_username=msa_server_username,
781+
msa_server_password=msa_server_password,
782+
api_key_header=api_key_header,
783+
api_key_value=api_key_value,
713784
max_msa_seqs=max_msa_seqs,
714785
processed_msa_dir=processed_msa_dir,
715786
processed_constraints_dir=processed_constraints_dir,
@@ -869,6 +940,30 @@ def cli() -> None:
869940
),
870941
default="greedy",
871942
)
943+
@click.option(
944+
"--msa_server_username",
945+
type=str,
946+
help="MSA server username for basic auth. Used only if --use_msa_server is set. Can also be set via BOLTZ_MSA_USERNAME environment variable.",
947+
default=None,
948+
)
949+
@click.option(
950+
"--msa_server_password",
951+
type=str,
952+
help="MSA server password for basic auth. Used only if --use_msa_server is set. Can also be set via BOLTZ_MSA_PASSWORD environment variable.",
953+
default=None,
954+
)
955+
@click.option(
956+
"--api_key_header",
957+
type=str,
958+
help="Custom header key for API key authentication (default: X-API-Key).",
959+
default=None,
960+
)
961+
@click.option(
962+
"--api_key_value",
963+
type=str,
964+
help="Custom header value for API key authentication.",
965+
default=None,
966+
)
872967
@click.option(
873968
"--use_potentials",
874969
is_flag=True,
@@ -967,6 +1062,10 @@ def predict( # noqa: C901, PLR0915, PLR0912
9671062
use_msa_server: bool = False,
9681063
msa_server_url: str = "https://api.colabfold.com",
9691064
msa_pairing_strategy: str = "greedy",
1065+
msa_server_username: Optional[str] = None,
1066+
msa_server_password: Optional[str] = None,
1067+
api_key_header: Optional[str] = None,
1068+
api_key_value: Optional[str] = None,
9701069
use_potentials: bool = False,
9711070
model: Literal["boltz1", "boltz2"] = "boltz2",
9721071
method: Optional[str] = None,
@@ -1011,6 +1110,23 @@ def predict( # noqa: C901, PLR0915, PLR0912
10111110
cache = Path(cache).expanduser()
10121111
cache.mkdir(parents=True, exist_ok=True)
10131112

1113+
# Get MSA server credentials from environment variables if not provided
1114+
if use_msa_server:
1115+
if msa_server_username is None:
1116+
msa_server_username = os.environ.get("BOLTZ_MSA_USERNAME")
1117+
if msa_server_password is None:
1118+
msa_server_password = os.environ.get("BOLTZ_MSA_PASSWORD")
1119+
if api_key_value is None:
1120+
api_key_value = os.environ.get("MSA_API_KEY_VALUE")
1121+
1122+
click.echo(f"MSA server enabled: {msa_server_url}")
1123+
if api_key_value:
1124+
click.echo("MSA server authentication: using API key header")
1125+
elif msa_server_username and msa_server_password:
1126+
click.echo("MSA server authentication: using basic auth")
1127+
else:
1128+
click.echo("MSA server authentication: no credentials provided")
1129+
10141130
# Create output directories
10151131
data = Path(data).expanduser()
10161132
out_dir = Path(out_dir).expanduser()
@@ -1050,6 +1166,10 @@ def predict( # noqa: C901, PLR0915, PLR0912
10501166
use_msa_server=use_msa_server,
10511167
msa_server_url=msa_server_url,
10521168
msa_pairing_strategy=msa_pairing_strategy,
1169+
msa_server_username=msa_server_username,
1170+
msa_server_password=msa_server_password,
1171+
api_key_header=api_key_header,
1172+
api_key_value=api_key_value,
10531173
boltz2=model == "boltz2",
10541174
preprocessing_threads=preprocessing_threads,
10551175
max_msa_seqs=max_msa_seqs,

0 commit comments

Comments
 (0)