Skip to content

Commit e6f0790

Browse files
committed
Bug fixes for VIPRSGrid models and in BayesPRSModel
1 parent 62e682e commit e6f0790

6 files changed

Lines changed: 220 additions & 72 deletions

File tree

bin/viprs_fit

Lines changed: 161 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,127 @@ Usage:
2929
"""
3030

3131
import os.path as osp
32+
import psutil
33+
import threading
34+
import time
35+
36+
37+
class PeakMemoryProfiler:
    """
    A context manager that monitors and tracks the peak memory usage (RSS) of a
    process (and optionally its children) by polling from a background thread.
    The memory usage can be reported in various units (bytes, MB, or GB).

    Example:

    ```
    with PeakMemoryProfiler() as profiler:
        # Code block to monitor memory usage
        ...
    print(profiler.get_peak_memory())
    ```

    Class Attributes:
    :ivar pid: The PID of the process being monitored. Defaults to the current process.
    :ivar interval: Time interval (in seconds) between memory checks. Defaults to 0.1.
    :ivar include_children: Whether memory usage from child processes is included. Defaults to True.
    :ivar unit: The unit used to report memory usage (either 'bytes', 'MB', or 'GB'). Defaults to 'MB'.
    :ivar max_memory: The peak memory usage observed during the monitoring period.
    :ivar monitoring_thread: Thread used for monitoring memory usage.
    :ivar _stop_monitoring: Event used to signal when to stop monitoring.
    """

    # The set of supported reporting units; anything else is rejected up-front
    # instead of silently falling back to bytes in `get_curr_memory`.
    _VALID_UNITS = ('bytes', 'MB', 'GB')

    def __init__(self, pid=None, interval=0.1, include_children=True, unit="MB"):
        """
        Initializes the PeakMemoryProfiler instance with the provided parameters.

        :param pid: The PID of the process to monitor. Defaults to None (current process).
        :param interval: The interval (in seconds) between memory checks. Defaults to 0.1.
        :param include_children: Whether to include memory usage from child processes. Defaults to True.
        :param unit: The unit in which to report memory usage. Options are 'bytes', 'MB', or 'GB'. Defaults to 'MB'.
        :raises ValueError: If `unit` is not one of 'bytes', 'MB', or 'GB'.
        """

        if unit not in self._VALID_UNITS:
            # Fail fast: an unrecognized unit would otherwise be reported as
            # raw bytes, yielding numbers off by orders of magnitude.
            raise ValueError(f"Unsupported memory unit: '{unit}'. "
                             f"Expected one of: {list(self._VALID_UNITS)}")

        self.pid = pid or psutil.Process().pid  # Default to current process if no PID is provided
        self.interval = interval
        self.include_children = include_children
        self.unit = unit
        self.max_memory = 0
        self.monitoring_thread = None
        self._stop_monitoring = threading.Event()

    def __enter__(self):
        """
        Starts monitoring memory usage when entering the context block.

        :return: Returns the instance of PeakMemoryProfiler, so that we can access peak memory later.
        """
        self.process = psutil.Process(self.pid)
        self.max_memory = 0
        self._stop_monitoring.clear()  # Clear the stop flag to begin monitoring
        self.monitoring_thread = threading.Thread(target=self._monitor_memory)
        self.monitoring_thread.start()
        return self  # Return the instance so that the caller can access max_memory

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Stops the memory monitoring when exiting the context block.

        :param exc_type: The exception type if an exception was raised in the block.
        :param exc_value: The exception instance if an exception was raised.
        :param traceback: The traceback object if an exception was raised.
        """
        self._stop_monitoring.set()  # Signal the thread to stop monitoring
        self.monitoring_thread.join()  # Wait for the monitoring thread to finish

    def get_curr_memory(self):
        """
        Get the current memory usage (RSS) of the monitored process and,
        if `include_children` is set, its children (recursively).

        :return: The current memory usage in the specified unit (bytes, MB, or GB).
        :rtype: float
        """

        memory = self.process.memory_info().rss

        if self.include_children:
            # Include memory usage of child processes recursively. Children may
            # exit or become inaccessible between listing and querying, so those
            # two errors are skipped rather than propagated.
            for child in self.process.children(recursive=True):
                try:
                    memory += child.memory_info().rss
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    continue

        if self.unit == "MB":
            return memory / (1024 ** 2)  # Convert to MB
        elif self.unit == "GB":
            return memory / (1024 ** 3)  # Convert to GB
        else:
            return memory  # 'bytes': return the raw RSS count

    def _monitor_memory(self):
        """
        Monitors the memory usage of the process and its children continuously
        until the monitoring is stopped.

        This method runs in a separate thread and updates the peak memory usage
        as long as the monitoring flag is not set.
        """
        while not self._stop_monitoring.is_set():
            try:
                curr_memory = self.get_curr_memory()

                # Update max memory if a new peak is found
                self.max_memory = max(self.max_memory, curr_memory)
                time.sleep(self.interval)
            except psutil.NoSuchProcess:
                break  # Process no longer exists, stop monitoring

    def get_peak_memory(self):
        """
        Get the peak memory usage observed during the monitoring period.

        :return: The peak memory usage in the specified unit (bytes, MB, or GB).
        :rtype: float
        """
        return self.max_memory
32153

33154

34155
def check_args(args):
@@ -110,15 +231,15 @@ def check_args(args):
110231
# provides either individual-level data or summary statistics for the validation set:
111232
if args.validation_bed is not None and args.validation_pheno is not None:
112233
pass
113-
elif args.validation_ld_panel is not None and args.validation_sumstats is not None:
234+
elif args.validation_ld_panel is not None and args.validation_sumstats_path is not None:
114235
ld_store_files = get_filenames(args.validation_ld_panel, extension='.zgroup')
115236
if len(ld_store_files) < 1:
116237
raise FileNotFoundError(f"No valid LD matrix files for the "
117238
f"validation set were found at: {args.ld_dir}")
118-
sumstats_files = get_filenames(args.validation_sumstats)
239+
sumstats_files = get_filenames(args.validation_sumstats_path)
119240
if len(sumstats_files) < 1:
120241
raise FileNotFoundError(f"No valid summary statistics files for the validation set "
121-
f"were found at: {args.sumstats_path}")
242+
f"were found at: {args.validation_sumstats_path}")
122243
else:
123244
raise ValueError("To perform pseudo-validation, you need to provide either individual-level data "
124245
"or summary statistics for the validation set.")
@@ -227,7 +348,7 @@ def init_data(args, verbose=True):
227348
else:
228349

229350
# Construct the validation GWADataLoader object using LD + summary statistics:
230-
validation_gdl = GWADataLoader(ld_store_files=args.validation_ld_panel_ld_dir,
351+
validation_gdl = GWADataLoader(ld_store_files=args.validation_ld_panel,
231352
temp_dir=args.temp_dir,
232353
verbose=verbose,
233354
threads=args.threads)
@@ -245,12 +366,13 @@ def init_data(args, verbose=True):
245366
validation_gdl.read_summary_statistics(args.validation_sumstats_path,
246367
sumstats_format=ss_format,
247368
parser=ss_parser)
248-
# Harmonize the data:
249-
validation_gdl.harmonize_data()
250369

251370
# Filter SNPs:
252371
validation_gdl.filter_snps(extract_snps)
253372

373+
# Harmonize the data:
374+
validation_gdl.harmonize_data()
375+
254376
# If overall GWAS sample size is provided, set it here:
255377
if args.validation_gwas_sample_size is not None:
256378
for ss in validation_gdl.sumstats_table.values():
@@ -289,7 +411,9 @@ def prepare_model(args, verbose=True):
289411
from viprs.model.VIPRS import VIPRS
290412
from viprs.model.VIPRSMix import VIPRSMix
291413

292-
if args.lambda_min == 'infer':
414+
if args.lambda_min is None:
415+
lambda_min = 0.
416+
elif args.lambda_min == 'infer':
293417
lambda_min = 'infer'
294418
else:
295419
lambda_min = float(args.lambda_min)
@@ -378,6 +502,7 @@ def fit_model(model, data_dict, args):
378502

379503
import time
380504
import numpy as np
505+
from viprs.utils.exceptions import OptimizationDivergence
381506

382507
# Set the random seed:
383508
np.random.seed(args.seed)
@@ -405,15 +530,14 @@ def fit_model(model, data_dict, args):
405530

406531
if args.pi_steps is not None:
407532
grid.n_snps = data_dict['train'].n_snps
408-
grid.generate_pi_grid(n_steps=args.pi_steps)
533+
grid.generate_pi_grid(steps=args.pi_steps)
409534

410535
if args.lambda_min_steps is not None:
411536

412537
ld_mat = list(data_dict['train'].ld.values())[0]
413-
# Cap it at 5. to avoid over-shrinkage:
414-
lambda_min = np.minimum(ld_mat.get_lambda_min(aggregate='min'), 5.)
538+
lambda_min = ld_mat.get_lambda_min(aggregate='min')
415539

416-
grid.generate_lambda_min_grid(n_steps=args.lambda_min_steps, emp_lambda_min=lambda_min)
540+
grid.generate_lambda_min_grid(steps=args.lambda_min_steps, emp_lambda_min=lambda_min)
417541

418542
from functools import partial
419543
model = partial(model.func, **{**model.keywords, 'grid': grid})
@@ -431,7 +555,21 @@ def fit_model(model, data_dict, args):
431555

432556
# Fit the model to data:
433557
fit_start_time = time.time()
434-
m = m.fit(max_iter=args.max_iter)
558+
try:
559+
m = m.fit(max_iter=args.max_iter)
560+
except OptimizationDivergence as e:
561+
if m._sigma_g < 0. and np.all(m.lambda_min == 0.):
562+
print("> Optimization diverged. Re-trying with setting regularization parameter lambda_min...")
563+
for c in m.shapes:
564+
m.lambda_min = m.gdl.ld[c].get_lambda_min(min_max_ratio=1e-3)
565+
m = m.fit(max_iter=args.max_iter)
566+
# If the optimization diverges with multi-threading, try a single thread:
567+
elif m.threads > 1:
568+
print("> Optimization diverged. Retrying with a single thread...")
569+
m.threads = 1
570+
m = m.fit(max_iter=args.max_iter)
571+
else:
572+
raise e
435573
fit_end_time = time.time()
436574

437575
# ----------------------------------------------------------
@@ -456,7 +594,7 @@ def fit_model(model, data_dict, args):
456594

457595
valid_end_time = time.time()
458596

459-
result_dict['ProfilerMetrics']['Validation time'] = valid_end_time - valid_start_time
597+
result_dict['ProfilerMetrics']['Validation_time'] = valid_end_time - valid_start_time
460598

461599
result_dict['Validation'] = m.to_validation_table()
462600

@@ -660,7 +798,7 @@ def main():
660798
from datetime import timedelta
661799
import pandas as pd
662800
import numpy as np
663-
from magenpy.utils.system_utils import makedir, get_peak_memory_usage
801+
from magenpy.utils.system_utils import makedir
664802
from joblib import Parallel, delayed
665803
from joblib.externals.loky import get_reusable_executor
666804

@@ -692,13 +830,15 @@ def main():
692830
# (4) Fit to data:
693831
print('\n{:-^62}\n'.format(' Inference '))
694832

695-
fit_results = Parallel(n_jobs=args.n_jobs)(
696-
delayed(fit_model)(model, dl, args)
697-
for idx, dl in enumerate(data_loaders)
698-
)
833+
with PeakMemoryProfiler() as peak_mem:
834+
835+
fit_results = Parallel(n_jobs=args.n_jobs)(
836+
delayed(fit_model)(model, dl, args)
837+
for idx, dl in enumerate(data_loaders)
838+
)
699839

700-
# Shut down the parallel executor:
701-
get_reusable_executor().shutdown(wait=True)
840+
# Shut down the parallel executor:
841+
get_reusable_executor().shutdown(wait=True)
702842

703843
# Record end time:
704844
total_end_time = time.time()
@@ -724,7 +864,7 @@ def main():
724864
for r in fit_results])
725865
profm_table['Total_WallClockTime'] = round(total_end_time - total_start_time, 2)
726866
profm_table['DataPrep_Time'] = round(data_prep_time - total_start_time, 2)
727-
profm_table['Peak_Memory_MB'] = round(get_peak_memory_usage(include_children=True) or np.nan, 2)
867+
profm_table['Peak_Memory_MB'] = round(peak_mem.get_peak_memory(), 2)
728868

729869
output_prefix = osp.join(args.output_dir, args.output_prefix + args.model + '_' + args.hyp_search)
730870

viprs/model/BayesPRSModel.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -191,16 +191,8 @@ def harmonize_data(self, gdl=None, parameter_table=None):
191191

192192
try:
193193
post_mean_cols = expand_column_names('BETA', self.post_mean_beta[c].shape)
194-
if isinstance(post_mean_cols, str):
195-
post_mean_cols = [post_mean_cols]
196-
197194
pip_cols = expand_column_names('PIP', self.post_mean_beta[c].shape)
198-
if isinstance(pip_cols, str):
199-
pip_cols = [pip_cols]
200-
201195
post_var_cols = expand_column_names('VAR_BETA', self.post_mean_beta[c].shape)
202-
if isinstance(post_var_cols, str):
203-
post_var_cols = [post_var_cols]
204196

205197
except (TypeError, KeyError):
206198
pip_cols = [col for col in parameter_table[c].columns if 'PIP' in col]
@@ -254,13 +246,26 @@ def to_table(self, col_subset=('CHR', 'SNP', 'POS', 'A1', 'A2'), per_chromosome=
254246

255247
for c in self.chromosomes:
256248

257-
tables[c][expand_column_names('BETA', self.post_mean_beta[c].shape)] = self.post_mean_beta[c]
249+
cols_to_add = []
250+
251+
mean_beta_df = pd.DataFrame(self.post_mean_beta[c],
252+
columns=expand_column_names('BETA', self.post_mean_beta[c].shape),
253+
index=tables[c].index)
254+
cols_to_add.append(mean_beta_df)
258255

259256
if self.pip is not None:
260-
tables[c][expand_column_names('PIP', self.pip[c].shape)] = self.pip[c]
257+
pip_df = pd.DataFrame(self.pip[c],
258+
columns=expand_column_names('PIP', self.pip[c].shape),
259+
index=tables[c].index)
260+
cols_to_add.append(pip_df)
261261

262262
if self.post_var_beta is not None:
263-
tables[c][expand_column_names('VAR_BETA', self.post_var_beta[c].shape)] = self.post_var_beta[c]
263+
var_beta_df = pd.DataFrame(self.post_var_beta[c],
264+
columns=expand_column_names('VAR_BETA', self.post_var_beta[c].shape),
265+
index=tables[c].index)
266+
cols_to_add.append(var_beta_df)
267+
268+
tables[c] = pd.concat([tables[c]] + cols_to_add, axis=1)
264269

265270
if per_chromosome:
266271
return tables

0 commit comments

Comments
 (0)