@@ -29,6 +29,127 @@ Usage:
2929"""
3030
3131import os .path as osp
32+ import psutil
33+ import threading
34+ import time
35+
36+
class PeakMemoryProfiler:
    """
    A context manager that monitors and tracks the peak resident-set (RSS)
    memory usage of a process (and optionally its children) over a period of
    time. The memory usage can be reported in various units (bytes, MB, or GB).

    Example:

    ```
    with PeakMemoryProfiler() as profiler:
        # Code block to monitor memory usage
        ...
    peak = profiler.get_peak_memory()
    ```

    Class Attributes:
    :ivar pid: The PID of the process being monitored. Defaults to the current process.
    :ivar interval: Time interval (in seconds) between memory checks. Defaults to 0.1.
    :ivar include_children: Whether memory usage from child processes is included. Defaults to True.
    :ivar unit: The unit used to report memory usage (either 'bytes', 'MB', or 'GB'). Defaults to 'MB'.
    :ivar max_memory: The peak memory usage observed during the monitoring period.
    :ivar monitoring_thread: Thread used for monitoring memory usage.
    :ivar _stop_monitoring: Event used to signal when to stop monitoring.
    """

    def __init__(self, pid=None, interval=0.1, include_children=True, unit="MB"):
        """
        Initializes the PeakMemoryProfiler instance with the provided parameters.

        :param pid: The PID of the process to monitor. Defaults to None (current process).
        :param interval: The interval (in seconds) between memory checks. Defaults to 0.1.
        :param include_children: Whether to include memory usage from child processes. Defaults to True.
        :param unit: The unit in which to report memory usage. Options are 'bytes', 'MB', or 'GB'. Defaults to 'MB'.
        """
        # Default to the current process if no PID is provided:
        self.pid = pid or psutil.Process().pid
        self.interval = interval
        self.include_children = include_children
        self.unit = unit
        self.max_memory = 0
        self.monitoring_thread = None
        self._stop_monitoring = threading.Event()

    def __enter__(self):
        """
        Starts monitoring memory usage when entering the context block.

        :return: Returns the instance of PeakMemoryProfiler, so that we can access peak memory later.
        """
        self.process = psutil.Process(self.pid)
        self.max_memory = 0
        self._stop_monitoring.clear()  # Clear the stop flag to begin monitoring
        # Daemon thread so a crashing caller cannot hang interpreter shutdown:
        self.monitoring_thread = threading.Thread(target=self._monitor_memory,
                                                  daemon=True)
        self.monitoring_thread.start()
        return self  # Return the instance so that the caller can access max_memory

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Stops the memory monitoring when exiting the context block.

        :param exc_type: The exception type if an exception was raised in the block.
        :param exc_value: The exception instance if an exception was raised.
        :param traceback: The traceback object if an exception was raised.
        """
        self._stop_monitoring.set()  # Signal the thread to stop monitoring
        # Guard against __exit__ being reached without a successful __enter__:
        if self.monitoring_thread is not None:
            self.monitoring_thread.join()  # Wait for the monitoring thread to finish

    def get_curr_memory(self):
        """
        Get the current memory usage of the monitored process and its children.

        :return: The current memory usage in the specified unit (bytes, MB, or GB).
        :rtype: float
        """

        memory = self.process.memory_info().rss

        if self.include_children:
            # Include memory usage of child processes recursively. Enumerating
            # children can itself race with process exit, so guard it as well
            # as the per-child sampling:
            try:
                children = self.process.children(recursive=True)
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                children = []

            for child in children:
                try:
                    memory += child.memory_info().rss
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    continue

        if self.unit == "MB":
            return memory / (1024 ** 2)  # Convert to MB
        elif self.unit == "GB":
            return memory / (1024 ** 3)  # Convert to GB
        else:
            return memory  # Return in bytes if no conversion is requested

    def _monitor_memory(self):
        """
        Monitors the memory usage of the process and its children continuously
        until the monitoring is stopped.

        This method runs in a separate thread and updates the peak memory usage
        as long as the stop event is not set.
        """
        while not self._stop_monitoring.is_set():
            try:
                curr_memory = self.get_curr_memory()

                # Update max memory if a new peak is found
                self.max_memory = max(self.max_memory, curr_memory)
            except psutil.NoSuchProcess:
                break  # Process no longer exists, stop monitoring
            # Use Event.wait() rather than time.sleep() so that __exit__ does
            # not block for up to a full interval: the wait returns immediately
            # once the stop flag is set.
            self._stop_monitoring.wait(self.interval)

    def get_peak_memory(self):
        """
        Get the peak memory usage observed during the monitoring period.

        :return: The peak memory usage in the specified unit (bytes, MB, or GB).
        :rtype: float
        """
        return self.max_memory
32153
33154
34155def check_args (args ):
@@ -110,15 +231,15 @@ def check_args(args):
110231 # provides either individual-level data or summary statistics for the validation set:
111232 if args .validation_bed is not None and args .validation_pheno is not None :
112233 pass
113- elif args .validation_ld_panel is not None and args .validation_sumstats is not None :
234+ elif args .validation_ld_panel is not None and args .validation_sumstats_path is not None :
114235 ld_store_files = get_filenames (args .validation_ld_panel , extension = '.zgroup' )
115236 if len (ld_store_files ) < 1 :
116237 raise FileNotFoundError (f"No valid LD matrix files for the "
117238 f"validation set were found at: { args .ld_dir } " )
118- sumstats_files = get_filenames (args .validation_sumstats )
239+ sumstats_files = get_filenames (args .validation_sumstats_path )
119240 if len (sumstats_files ) < 1 :
120241 raise FileNotFoundError (f"No valid summary statistics files for the validation set "
121- f"were found at: { args .sumstats_path } " )
242+ f"were found at: { args .validation_sumstats_path } " )
122243 else :
123244 raise ValueError ("To perform pseudo-validation, you need to provide either individual-level data "
124245 "or summary statistics for the validation set." )
@@ -227,7 +348,7 @@ def init_data(args, verbose=True):
227348 else :
228349
229350 # Construct the validation GWADataLoader object using LD + summary statistics:
230- validation_gdl = GWADataLoader (ld_store_files = args .validation_ld_panel_ld_dir ,
351+ validation_gdl = GWADataLoader (ld_store_files = args .validation_ld_panel ,
231352 temp_dir = args .temp_dir ,
232353 verbose = verbose ,
233354 threads = args .threads )
@@ -245,12 +366,13 @@ def init_data(args, verbose=True):
245366 validation_gdl .read_summary_statistics (args .validation_sumstats_path ,
246367 sumstats_format = ss_format ,
247368 parser = ss_parser )
248- # Harmonize the data:
249- validation_gdl .harmonize_data ()
250369
251370 # Filter SNPs:
252371 validation_gdl .filter_snps (extract_snps )
253372
373+ # Harmonize the data:
374+ validation_gdl .harmonize_data ()
375+
254376 # If overall GWAS sample size is provided, set it here:
255377 if args .validation_gwas_sample_size is not None :
256378 for ss in validation_gdl .sumstats_table .values ():
@@ -289,7 +411,9 @@ def prepare_model(args, verbose=True):
289411 from viprs .model .VIPRS import VIPRS
290412 from viprs .model .VIPRSMix import VIPRSMix
291413
292- if args .lambda_min == 'infer' :
414+ if args .lambda_min is None :
415+ lambda_min = 0.
416+ elif args .lambda_min == 'infer' :
293417 lambda_min = 'infer'
294418 else :
295419 lambda_min = float (args .lambda_min )
@@ -378,6 +502,7 @@ def fit_model(model, data_dict, args):
378502
379503 import time
380504 import numpy as np
505+ from viprs .utils .exceptions import OptimizationDivergence
381506
382507 # Set the random seed:
383508 np .random .seed (args .seed )
@@ -405,15 +530,14 @@ def fit_model(model, data_dict, args):
405530
406531 if args .pi_steps is not None :
407532 grid .n_snps = data_dict ['train' ].n_snps
408- grid .generate_pi_grid (n_steps = args .pi_steps )
533+ grid .generate_pi_grid (steps = args .pi_steps )
409534
410535 if args .lambda_min_steps is not None :
411536
412537 ld_mat = list (data_dict ['train' ].ld .values ())[0 ]
413- # Cap it at 5. to avoid over-shrinkage:
414- lambda_min = np .minimum (ld_mat .get_lambda_min (aggregate = 'min' ), 5. )
538+ lambda_min = ld_mat .get_lambda_min (aggregate = 'min' )
415539
416- grid .generate_lambda_min_grid (n_steps = args .lambda_min_steps , emp_lambda_min = lambda_min )
540+ grid .generate_lambda_min_grid (steps = args .lambda_min_steps , emp_lambda_min = lambda_min )
417541
418542 from functools import partial
419543 model = partial (model .func , ** {** model .keywords , 'grid' : grid })
@@ -431,7 +555,21 @@ def fit_model(model, data_dict, args):
431555
432556 # Fit the model to data:
433557 fit_start_time = time .time ()
434- m = m .fit (max_iter = args .max_iter )
558+ try :
559+ m = m .fit (max_iter = args .max_iter )
560+ except OptimizationDivergence as e :
561+ if m ._sigma_g < 0. and np .all (m .lambda_min == 0. ):
562+ print ("> Optimization diverged. Re-trying with setting regularization parameter lambda_min..." )
563+ for c in m .shapes :
564+ m .lambda_min = m .gdl .ld [c ].get_lambda_min (min_max_ratio = 1e-3 )
565+ m = m .fit (max_iter = args .max_iter )
566+ # If the optimization diverges with multi-threading, try a single thread:
567+ elif m .threads > 1 :
568+ print ("> Optimization diverged. Retrying with a single thread..." )
569+ m .threads = 1
570+ m = m .fit (max_iter = args .max_iter )
571+ else :
572+ raise e
435573 fit_end_time = time .time ()
436574
437575 # ----------------------------------------------------------
@@ -456,7 +594,7 @@ def fit_model(model, data_dict, args):
456594
457595 valid_end_time = time .time ()
458596
459- result_dict ['ProfilerMetrics' ]['Validation time ' ] = valid_end_time - valid_start_time
597+ result_dict ['ProfilerMetrics' ]['Validation_time ' ] = valid_end_time - valid_start_time
460598
461599 result_dict ['Validation' ] = m .to_validation_table ()
462600
@@ -660,7 +798,7 @@ def main():
660798 from datetime import timedelta
661799 import pandas as pd
662800 import numpy as np
663- from magenpy .utils .system_utils import makedir , get_peak_memory_usage
801+ from magenpy .utils .system_utils import makedir
664802 from joblib import Parallel , delayed
665803 from joblib .externals .loky import get_reusable_executor
666804
@@ -692,13 +830,15 @@ def main():
692830 # (4) Fit to data:
693831 print ('\n {:-^62}\n ' .format (' Inference ' ))
694832
695- fit_results = Parallel (n_jobs = args .n_jobs )(
696- delayed (fit_model )(model , dl , args )
697- for idx , dl in enumerate (data_loaders )
698- )
833+ with PeakMemoryProfiler () as peak_mem :
834+
835+ fit_results = Parallel (n_jobs = args .n_jobs )(
836+ delayed (fit_model )(model , dl , args )
837+ for idx , dl in enumerate (data_loaders )
838+ )
699839
700- # Shut down the parallel executor:
701- get_reusable_executor ().shutdown (wait = True )
840+ # Shut down the parallel executor:
841+ get_reusable_executor ().shutdown (wait = True )
702842
703843 # Record end time:
704844 total_end_time = time .time ()
@@ -724,7 +864,7 @@ def main():
724864 for r in fit_results ])
725865 profm_table ['Total_WallClockTime' ] = round (total_end_time - total_start_time , 2 )
726866 profm_table ['DataPrep_Time' ] = round (data_prep_time - total_start_time , 2 )
727- profm_table ['Peak_Memory_MB' ] = round (get_peak_memory_usage ( include_children = True ) or np . nan , 2 )
867+ profm_table ['Peak_Memory_MB' ] = round (peak_mem . get_peak_memory () , 2 )
728868
729869 output_prefix = osp .join (args .output_dir , args .output_prefix + args .model + '_' + args .hyp_search )
730870
0 commit comments