From 668e5c6db06f43ae343f08de7e207bde81893c55 Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Fri, 11 Jul 2025 00:55:07 +0000 Subject: [PATCH 01/25] Implement round robin and separate ventilator and result queue --- petastorm/arrow_reader_worker.py | 20 +++- petastorm/reader.py | 15 ++- petastorm/workers_pool/thread_pool.py | 143 ++++++++++++++++++++++---- petastorm/workers_pool/ventilator.py | 50 ++++----- 4 files changed, 166 insertions(+), 62 deletions(-) diff --git a/petastorm/arrow_reader_worker.py b/petastorm/arrow_reader_worker.py index fedad01d8..895745b2a 100644 --- a/petastorm/arrow_reader_worker.py +++ b/petastorm/arrow_reader_worker.py @@ -101,7 +101,10 @@ def __init__(self, worker_id, publish_func, args): self._transformed_schema = args[7] self._arrow_filters = args[8] self._shuffle_rows = args[9] - self._random_state = np.random.RandomState(seed=args[10]) + self._random_seed = args[10] + + # Initialize random number generator + self._rng = np.random.default_rng(self._random_seed) if self._ngram: raise NotImplementedError('ngrams are not supported by ArrowReaderWorker') @@ -164,6 +167,7 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition): lambda: self._load_rows(parquet_file, piece, shuffle_row_drop_partition)) if all_cols: + # 3. Also pass the Worker ID to the publish_function self.publish_func(all_cols) @staticmethod @@ -289,9 +293,19 @@ def _read_with_shuffle_row_drop(self, piece, pq_file, column_names, shuffle_row_ # pyarrow would fail if we request a column names that the dataset is partitioned by table = piece.read(columns=column_names - partition_names, partitions=self._dataset.partitions) + + # Handle row shuffling based on shuffle_rows setting if self._shuffle_rows: - indices = self._random_state.permutation(table.num_rows) - table = table.take(indices) + if self._random_seed is not None and self._random_seed != 0: + # Deterministic randomization: use provided seed + indices = self._rng.permutation(table.num_rows) + else: + # Non-deterministic randomization: use np.random directly + indices = np.random.permutation(table.num_rows) + else: + # Deterministic natural order: shuffle_rows=False + indices = np.arange(table.num_rows) + table = table.take(indices) # Drop columns we did not explicitly request. This may happen when a table is partitioned. Besides columns # requested, pyarrow will also return partition values. Having these unexpected fields will break some diff --git a/petastorm/reader.py b/petastorm/reader.py index 8fa69935b..c9b6490fe 100644 --- a/petastorm/reader.py +++ b/petastorm/reader.py @@ -25,6 +25,7 @@ from petastorm.etl import dataset_metadata, rowgroup_indexing from petastorm.etl.dataset_metadata import PetastormMetadataError, infer_or_load_unischema from petastorm.fs_utils import get_filesystem_and_path_or_paths, normalize_dir_url +from petastorm.local_disk_arrow_table_cache import LocalDiskArrowTableCache from petastorm.local_disk_cache import LocalDiskCache from petastorm.ngram import NGram from petastorm.predicates import PredicateBase @@ -159,7 +160,7 @@ def make_reader(dataset_url, 'To read from a non-Petastorm Parquet store use make_batch_reader') if reader_pool_type == 'thread': - reader_pool = ThreadPool(workers_count, results_queue_size) + reader_pool = ThreadPool(workers_count, results_queue_size, shuffle_rows=shuffle_rows, seed=seed) elif reader_pool_type == 'process': if pyarrow_serialize: warnings.warn("pyarrow_serializer was deprecated and will be removed in future versions. 
" @@ -205,7 +206,6 @@ def make_reader(dataset_url, def make_batch_reader(dataset_url_or_urls, schema_fields=None, reader_pool_type='thread', workers_count=10, - results_queue_size=50, seed=None, shuffle_rows=False, shuffle_row_groups=True, shuffle_row_drop_partitions=1, predicate=None, @@ -244,8 +244,6 @@ def make_batch_reader(dataset_url_or_urls, denoting a thread pool, process pool, or running everything in the master thread. Defaults to 'thread' :param workers_count: An int for the number of workers to use in the reader pool. This only is used for the thread or process pool. Defaults to 10 - :param results_queue_size: Size of the results queue to store prefetched row-groups. Currently only applicable to - thread reader pool type. :param seed: Random seed specified for shuffle and sharding with reproducible outputs. Defaults to None :param shuffle_rows: Whether to shuffle inside a single row group. Defaults to False. :param shuffle_row_groups: Whether to shuffle row groups (the order in which full row groups are read) @@ -309,13 +307,13 @@ def make_batch_reader(dataset_url_or_urls, if cache_type is None or cache_type == NULL_CACHE: cache = NullCache() elif cache_type == LOCAL_DISK_CACHE: - cache = LocalDiskCache(cache_location, cache_size_limit, cache_row_size_estimate, - **cache_extra_settings or {}) + cache = LocalDiskArrowTableCache(cache_location, cache_size_limit, cache_row_size_estimate, + **cache_extra_settings or {}) else: raise ValueError('Unknown cache_type: {}'.format(cache_type)) if reader_pool_type == 'thread': - reader_pool = ThreadPool(workers_count, results_queue_size) + reader_pool = ThreadPool(workers_count, results_queue_size, shuffle_rows=shuffle_rows, seed=seed) elif reader_pool_type == 'process': serializer = ArrowTableSerializer() reader_pool = ProcessPool(workers_count, serializer, zmq_copy_buffers=zmq_copy_buffers) @@ -439,7 +437,7 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None, cache = cache or NullCache() - self._workers_pool = reader_pool or ThreadPool(10) + self._workers_pool = reader_pool or ThreadPool(10, shuffle_rows=shuffle_rows, seed=seed) # Make a schema view (a view is a Unischema containing only a subset of fields # Will raise an exception if invalid schema fields are in schema_fields @@ -665,7 +663,6 @@ def _create_ventilator(self, row_group_indexes, shuffle_row_groups, shuffle_row_ return ConcurrentVentilator(self._workers_pool.ventilate, items_to_ventilate, iterations=num_epochs, - max_ventilation_queue_size=max_ventilation_queue_size, randomize_item_order=shuffle_row_groups, random_seed=seed) diff --git a/petastorm/workers_pool/thread_pool.py b/petastorm/workers_pool/thread_pool.py index 649aa77f3..45efb817d 100644 --- a/petastorm/workers_pool/thread_pool.py +++ b/petastorm/workers_pool/thread_pool.py @@ -59,9 +59,15 @@ def run(self): # If the message came from work_receiver channel try: (args, kargs) = self._ventilator_queue.get(block=True, timeout=IO_TIMEOUT_INTERVAL_S) + # Mark worker as busy when processing + self._worker_impl.thread_pool._set_worker_busy(self._worker_impl.worker_id) self._worker_impl.process(*args, **kargs) self._worker_impl.publish_func(VentilatedItemProcessedMessage()) + # Mark worker as idle after processing + self._worker_impl.thread_pool._set_worker_idle(self._worker_impl.worker_id) except queue.Empty: + # Mark worker as idle when waiting + self._worker_impl.thread_pool._set_worker_idle(self._worker_impl.worker_id) pass except WorkerTerminationRequested: pass @@ -76,7 +82,7 @@ def run(self): 
class ThreadPool(object): - def __init__(self, workers_count, results_queue_size=50, profiling_enabled=False): + def __init__(self, workers_count, results_queue_size=50, profiling_enabled=False, shuffle_rows=False, seed=None): """Initializes a thread pool. TODO: consider using a standard thread pool @@ -86,20 +92,29 @@ def __init__(self, workers_count, results_queue_size=50, profiling_enabled=False implementation that would not use fork) :param workers_count: Number of threads - :param profile: Whether to run a profiler on the threads + :param profiling_enabled: Whether to run a profiler on the threads + :param shuffle_rows: Whether to shuffle rows (affects round-robin behavior) + :param seed: Random seed for deterministic behavior """ - self._seed = random.randint(0, 100000) self._workers = [] - self._ventilator_queue = None + self._ventilator_queues = None self.workers_count = workers_count self._results_queue_size = results_queue_size # Worker threads will watch this event and gracefully shutdown when the event is set self._stop_event = Event() self._profiling_enabled = profiling_enabled + self._shuffle_rows = shuffle_rows + self._seed = seed self._ventilated_items = 0 self._ventilated_items_processed = 0 self._ventilator = None + + # Round-robin consumer thread + self._round_robin_thread = None + + # Worker status tracking + self._worker_status = [False] * workers_count # False = idle, True = busy def start(self, worker_class, worker_args=None, ventilator=None): """Starts worker threads. @@ -116,13 +131,22 @@ class must implement :class:`.WorkerBase` protocol .format(len(self._workers), self._stop_event.is_set())) # Set up a channel to send work - self._ventilator_queue = queue.Queue() - self._results_queue = queue.Queue(self._results_queue_size) + self._ventilator_queues = [queue.Queue() for _ in range(self.workers_count)] + # Todo: Update the results queue size + self._results_queues = [queue.Queue(self._results_queue_size) for _ in range(self.workers_count)] + self._shared_results_queue = queue.Queue(self._results_queue_size) self._workers = [] for worker_id in range(self.workers_count): - worker_impl = worker_class(worker_id, self._stop_aware_put, worker_args) - new_thread = WorkerThread(worker_impl, self._stop_event, self._ventilator_queue, - self._results_queue, self._profiling_enabled) + # Create a closure that captures the worker_id for this specific worker + def make_publish_func(worker_id): + return lambda data: self._stop_aware_put(data, worker_id) + + worker_impl = worker_class(worker_id, make_publish_func(worker_id), worker_args) + # Add thread_pool reference to worker for status tracking + worker_impl.thread_pool = self + + new_thread = WorkerThread(worker_impl, self._stop_event, self._ventilator_queues[worker_id], + self._results_queues[worker_id], self._profiling_enabled) # Make the thread daemonic. Since it only reads it's ok to abort while running - no resource corruption # will occur. 
new_thread.daemon = True @@ -132,15 +156,32 @@ class must implement :class:`.WorkerBase` protocol for w in self._workers: w.start() + # Start round-robin consumer thread + self._round_robin_thread = Thread(target=self._round_robin_consumer, daemon=True) + self._round_robin_thread.start() + if ventilator: self._ventilator = ventilator self._ventilator.start() - def ventilate(self, *args, **kargs): + def _set_worker_busy(self, worker_id): + """Mark worker as busy (processing work).""" + self._worker_status[worker_id] = True + + def _set_worker_idle(self, worker_id): + """Mark worker as idle (waiting for work).""" + self._worker_status[worker_id] = False + + def _is_worker_idle(self, worker_id): + """Check if worker is idle.""" + return not self._worker_status[worker_id] + + def ventilate(self, items_to_ventilate): """Sends a work item to a worker process. Will result in ``worker.process(...)`` call with arbitrary arguments. """ - self._ventilated_items += 1 - self._ventilator_queue.put((args, kargs)) + for i, item in enumerate(items_to_ventilate): + self._ventilator_queues[i % self.workers_count].put(item) + self._ventilated_items += 1 def get_results(self): """Returns results from worker pool or re-raise worker's exception if any happen in worker thread. @@ -153,14 +194,25 @@ def get_results(self): """ while True: - # If there is no more work to do, raise an EmptyResultError - if self._results_queue.empty() and self._ventilated_items == self._ventilated_items_processed: - # We also need to check if we are using a ventilator and if it is completed + # Check termination condition: all workers are truly done + all_workers_done = True + for worker_id in range(self.workers_count): + worker_done = ( + self._results_queues[worker_id].empty() and # Worker result queue is empty + self._is_worker_idle(worker_id) and # Worker is idle + self._ventilator_queues[worker_id].empty() # Ventilator queue for worker is empty + ) + if not worker_done: + all_workers_done = False + break + + # If all workers are done and shared queue is empty, raise EmptyResultError + if all_workers_done and self._shared_results_queue.empty(): if not self._ventilator or self._ventilator.completed(): raise EmptyResultError() try: - result = self._results_queue.get(timeout=_VERIFY_END_OF_VENTILATION_PERIOD) + result = self._shared_results_queue.get(timeout=_VERIFY_END_OF_VENTILATION_PERIOD) if isinstance(result, VentilatedItemProcessedMessage): self._ventilated_items_processed += 1 if self._ventilator: @@ -197,7 +249,7 @@ def join(self): stats = pstats.Stats(w.prof) stats.sort_stats('cumulative').print_stats() - def _stop_aware_put(self, data): + def _stop_aware_put(self, data, worker_id): """This method is called to write the results to the results queue. We use ``put`` in a non-blocking way so we can gracefully terminate the worker thread without being stuck on :func:`Queue.put`. 
@@ -205,7 +257,7 @@ def _stop_aware_put(self, data): :func:`WorkerThread.run` which will gracefully terminate main worker loop.""" while True: try: - self._results_queue.put(data, block=True, timeout=IO_TIMEOUT_INTERVAL_S) + self._results_queues[worker_id].put(data, block=True, timeout=IO_TIMEOUT_INTERVAL_S) return except queue.Full: pass @@ -213,8 +265,61 @@ def _stop_aware_put(self, data): if self._stop_event.is_set(): raise WorkerTerminationRequested() + def _round_robin_consumer(self): + """Round-robin consumer that takes items from each worker's queue in strict round-robin order + and puts them into the shared results queue.""" + current_worker = 0 + + # Determine if we should use non-blocking behavior + use_non_blocking = self._shuffle_rows and (self._seed is None or self._seed == 0) + + while not self._stop_event.is_set(): + try: + # Check if current worker should be skipped + should_skip = ( + self._results_queues[current_worker].empty() and # Worker result queue is empty + self._is_worker_idle(current_worker) and # Worker is idle + self._ventilator_queues[current_worker].empty() # Ventilator queue for worker is empty + ) + + if should_skip: + # Skip this worker and move to next + current_worker = (current_worker + 1) % self.workers_count + continue + + # Try to get an item from the current worker's queue + if use_non_blocking: + # Non-blocking: try to get item without waiting + try: + item = self._results_queues[current_worker].get(block=False) + except queue.Empty: + # No item available, move to next worker immediately + current_worker = (current_worker + 1) % self.workers_count + continue + else: + # Blocking: wait for item (strict round-robin) + item = self._results_queues[current_worker].get(block=True, timeout=1.0) + + # Put the item into the shared results queue + self._shared_results_queue.put(item, block=True, timeout=1.0) + + # Move to next worker in round-robin fashion + current_worker = (current_worker + 1) % self.workers_count + + except queue.Empty: + # No item available from current worker, move to next + current_worker = (current_worker + 1) % self.workers_count + continue + except queue.Full: + # Shared queue is full, wait a bit and try again + continue + except Exception: + # Any other exception, continue to next worker + current_worker = (current_worker + 1) % self.workers_count + continue + def results_qsize(self): - return self._results_queue.qsize() + return self._shared_results_queue.qsize() @property def diagnostics(self): diff --git a/petastorm/workers_pool/ventilator.py b/petastorm/workers_pool/ventilator.py index 0f26bec13..dfa4c21f4 100644 --- a/petastorm/workers_pool/ventilator.py +++ b/petastorm/workers_pool/ventilator.py @@ -66,9 +66,7 @@ def __init__(self, items_to_ventilate, iterations=1, randomize_item_order=False, - random_seed=None, - max_ventilation_queue_size=None, - ventilation_interval=_VENTILATION_INTERVAL): + random_seed=None): """ Constructor for a concurrent ventilator. 
@@ -98,17 +96,11 @@ def __init__(self, self._items_to_ventilate = items_to_ventilate self._iterations_remaining = iterations self._randomize_item_order = randomize_item_order - self._random_state = np.random.RandomState(seed=random_seed) + self._random_seed = random_seed + self._rng = np.random.default_rng(self._random_seed) self._iterations = iterations - # For the default max ventilation queue size we will use the size of the items to ventilate - self._max_ventilation_queue_size = max_ventilation_queue_size or len(items_to_ventilate) - self._ventilation_interval = ventilation_interval - - self._current_item_to_ventilate = 0 self._ventilation_thread = None - self._ventilated_items_count = 0 - self._processed_items_count = 0 self._stop_requested = False def start(self): @@ -118,7 +110,7 @@ def start(self): self._ventilation_thread.start() def processed_item(self): - self._processed_items_count += 1 + pass def completed(self): assert self._iterations_remaining is None or self._iterations_remaining >= 0 @@ -141,25 +133,21 @@ def _ventilate(self): if self.completed(): break - # If we are ventilating the first item, we check if we would like to randomize the item order - if self._current_item_to_ventilate == 0 and self._randomize_item_order: - self._random_state.shuffle(self._items_to_ventilate) - - # Block until queue has room, but use continue to allow for checking if stop has been called - if self._ventilated_items_count - self._processed_items_count >= self._max_ventilation_queue_size: - sleep(self._ventilation_interval) - continue - - item_to_ventilate = self._items_to_ventilate[self._current_item_to_ventilate] - self._ventilate_fn(**item_to_ventilate) - self._current_item_to_ventilate += 1 - self._ventilated_items_count += 1 - - if self._current_item_to_ventilate >= len(self._items_to_ventilate): - self._current_item_to_ventilate = 0 - # If iterations was set to None, that means we will iterate until stop is called - if self._iterations_remaining is not None: - self._iterations_remaining -= 1 + if self._randomize_item_order: + if self._random_seed is not None and self._random_seed != 0: + # Deterministic randomization: use provided seed + items_to_ventilate = self._rng.permutation(self._items_to_ventilate) + else: + # Non-deterministic randomization: use np.random + items_to_ventilate = np.random.permutation(self._items_to_ventilate) + else: + # Deterministic natural order: randomize_item_order=False + items_to_ventilate = self._items_to_ventilate.copy() + + self._ventilate_fn(items_to_ventilate) + + if self._iterations_remaining is not None: + self._iterations_remaining -= 1 def stop(self): self._stop_requested = True From 3b02944388b3897db6a2f887f4d2951b18fe59a1 Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Tue, 15 Jul 2025 22:02:15 +0000 Subject: [PATCH 02/25] Fixed and tested locally --- petastorm/arrow_reader_worker.py | 1 - petastorm/reader.py | 25 ++--- petastorm/workers_pool/thread_pool.py | 151 +++++++++++++++----------- petastorm/workers_pool/ventilator.py | 24 ++-- 4 files changed, 110 insertions(+), 91 deletions(-) diff --git a/petastorm/arrow_reader_worker.py b/petastorm/arrow_reader_worker.py index 895745b2a..4a38a3e4c 100644 --- a/petastorm/arrow_reader_worker.py +++ b/petastorm/arrow_reader_worker.py @@ -167,7 +167,6 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition): lambda: self._load_rows(parquet_file, piece, shuffle_row_drop_partition)) if all_cols: - # 3. 
Also pass the Worker ID to the publish_function self.publish_func(all_cols) @staticmethod diff --git a/petastorm/reader.py b/petastorm/reader.py index c9b6490fe..b6c459d7d 100644 --- a/petastorm/reader.py +++ b/petastorm/reader.py @@ -14,18 +14,19 @@ import collections.abc import logging +import sys import warnings import six from pyarrow import parquet as pq -from petastorm.arrow_reader_worker import ArrowReaderWorker +# Import ArrowReaderWorker from local modified file instead of installed package +from arrow_reader_worker import ArrowReaderWorker from petastorm.cache import NullCache from petastorm.errors import NoDataAvailableError from petastorm.etl import dataset_metadata, rowgroup_indexing from petastorm.etl.dataset_metadata import PetastormMetadataError, infer_or_load_unischema from petastorm.fs_utils import get_filesystem_and_path_or_paths, normalize_dir_url -from petastorm.local_disk_arrow_table_cache import LocalDiskArrowTableCache from petastorm.local_disk_cache import LocalDiskCache from petastorm.ngram import NGram from petastorm.predicates import PredicateBase @@ -36,8 +37,8 @@ from petastorm.transform import transform_schema from petastorm.workers_pool.dummy_pool import DummyPool from petastorm.workers_pool.process_pool import ProcessPool -from petastorm.workers_pool.thread_pool import ThreadPool -from petastorm.workers_pool.ventilator import ConcurrentVentilator +from thread_pool import ThreadPool +from ventilator import ConcurrentVentilator logger = logging.getLogger(__name__) @@ -60,7 +61,7 @@ def normalize_dataset_url_or_urls(dataset_url_or_urls): def make_reader(dataset_url, schema_fields=None, - reader_pool_type='thread', workers_count=10, pyarrow_serialize=False, results_queue_size=50, + reader_pool_type='thread', workers_count=10, pyarrow_serialize=False, results_queue_size=25, seed=None, shuffle_rows=False, shuffle_row_groups=True, shuffle_row_drop_partitions=1, predicate=None, @@ -160,7 +161,7 @@ def make_reader(dataset_url, 'To read from a non-Petastorm Parquet store use make_batch_reader') if reader_pool_type == 'thread': - reader_pool = ThreadPool(workers_count, results_queue_size, shuffle_rows=shuffle_rows, seed=seed) + reader_pool = ThreadPool(workers_count, shuffle_rows=shuffle_rows, seed=seed) elif reader_pool_type == 'process': if pyarrow_serialize: warnings.warn("pyarrow_serializer was deprecated and will be removed in future versions. 
" @@ -206,6 +207,7 @@ def make_reader(dataset_url, def make_batch_reader(dataset_url_or_urls, schema_fields=None, reader_pool_type='thread', workers_count=10, + results_queue_size=25, seed=None, shuffle_rows=False, shuffle_row_groups=True, shuffle_row_drop_partitions=1, predicate=None, @@ -307,13 +309,13 @@ def make_batch_reader(dataset_url_or_urls, if cache_type is None or cache_type == NULL_CACHE: cache = NullCache() elif cache_type == LOCAL_DISK_CACHE: - cache = LocalDiskArrowTableCache(cache_location, cache_size_limit, cache_row_size_estimate, - **cache_extra_settings or {}) + cache = LocalDiskCache(cache_location, cache_size_limit, cache_row_size_estimate, + **cache_extra_settings or {}) else: raise ValueError('Unknown cache_type: {}'.format(cache_type)) - + if reader_pool_type == 'thread': - reader_pool = ThreadPool(workers_count, results_queue_size, shuffle_rows=shuffle_rows, seed=seed) + reader_pool = ThreadPool(workers_count, shuffle_rows=shuffle_rows, seed=seed) elif reader_pool_type == 'process': serializer = ArrowTableSerializer() reader_pool = ProcessPool(workers_count, serializer, zmq_copy_buffers=zmq_copy_buffers) @@ -475,13 +477,11 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None, self.num_epochs, worker_predicate, self._workers_pool.workers_count + _VENTILATE_EXTRA_ROWGROUPS, seed) - # 5. Start workers pool self._workers_pool.start(worker_class, (pyarrow_filesystem, dataset_path, storage_schema, self.ngram, row_groups, cache, transform_spec, self.schema, filters, shuffle_rows, seed), ventilator=self.ventilator) - logger.debug('Workers pool started') self.last_row_consumed = False self.stopped = False @@ -659,7 +659,6 @@ def _create_ventilator(self, row_group_indexes, shuffle_row_groups, shuffle_row_ 'worker_predicate': worker_predicate, 'shuffle_row_drop_partition': (shuffle_row_drop_partition, shuffle_row_drop_partitions)}) - return ConcurrentVentilator(self._workers_pool.ventilate, items_to_ventilate, iterations=num_epochs, diff --git a/petastorm/workers_pool/thread_pool.py b/petastorm/workers_pool/thread_pool.py index 45efb817d..2f104fffb 100644 --- a/petastorm/workers_pool/thread_pool.py +++ b/petastorm/workers_pool/thread_pool.py @@ -13,9 +13,11 @@ # limitations under the License. 
import cProfile +import logging import pstats import random import sys +import os from threading import Thread, Event from traceback import format_exc @@ -45,6 +47,7 @@ def __init__(self, worker_impl, stop_event, ventilator_queue, results_queue, pro self._ventilator_queue = ventilator_queue self._results_queue = results_queue self._profiling_enabled = profiling_enabled + self._items_processed = 0 if profiling_enabled: self.prof = cProfile.Profile() @@ -58,16 +61,11 @@ def run(self): break # If the message came from work_receiver channel try: - (args, kargs) = self._ventilator_queue.get(block=True, timeout=IO_TIMEOUT_INTERVAL_S) - # Mark worker as busy when processing - self._worker_impl.thread_pool._set_worker_busy(self._worker_impl.worker_id) - self._worker_impl.process(*args, **kargs) + item = self._ventilator_queue.get(block=True, timeout=IO_TIMEOUT_INTERVAL_S) + self._worker_impl.process(**item) self._worker_impl.publish_func(VentilatedItemProcessedMessage()) - # Mark worker as idle after processing - self._worker_impl.thread_pool._set_worker_idle(self._worker_impl.worker_id) + self._items_processed += 1 # Only increment for actual data items except queue.Empty: - # Mark worker as idle when waiting - self._worker_impl.thread_pool._set_worker_idle(self._worker_impl.worker_id) pass except WorkerTerminationRequested: pass @@ -79,10 +77,13 @@ def run(self): break if self._profiling_enabled: self.prof.disable() + + def is_worker_done(self): + return self._items_processed == self._worker_impl.thread_pool._items_per_worker[self._worker_impl.worker_id] class ThreadPool(object): - def __init__(self, workers_count, results_queue_size=50, profiling_enabled=False, shuffle_rows=False, seed=None): + def __init__(self, workers_count, results_queue_size=25, worker_results_queue_size=5, profiling_enabled=False, shuffle_rows=False, seed=None): """Initializes a thread pool. TODO: consider using a standard thread pool @@ -96,10 +97,12 @@ def __init__(self, workers_count, results_queue_size=50, profiling_enabled=False :param shuffle_rows: Whether to shuffle rows (affects round-robin behavior) :param seed: Random seed for deterministic behavior """ + self._workers = [] self._ventilator_queues = None self.workers_count = workers_count self._results_queue_size = results_queue_size + self._worker_results_queue_size = worker_results_queue_size # Worker threads will watch this event and gracefully shutdown when the event is set self._stop_event = Event() self._profiling_enabled = profiling_enabled @@ -112,10 +115,9 @@ def __init__(self, workers_count, results_queue_size=50, profiling_enabled=False # Round-robin consumer thread self._round_robin_thread = None + self._items_per_worker = [0] * workers_count - # Worker status tracking - self._worker_status = [False] * workers_count # False = idle, True = busy - + def start(self, worker_class, worker_args=None, ventilator=None): """Starts worker threads. @@ -127,15 +129,18 @@ class must implement :class:`.WorkerBase` protocol """ # Verify stop_event and raise exception if it's already set! if self._stop_event.is_set(): - raise RuntimeError('ThreadPool({}) cannot be reused! stop_event set? {}' - .format(len(self._workers), self._stop_event.is_set())) + error_msg = f'ThreadPool({len(self._workers)}) cannot be reused! stop_event set? 
{self._stop_event.is_set()}' + raise RuntimeError(error_msg) + # Set up a channel to send work self._ventilator_queues = [queue.Queue() for _ in range(self.workers_count)] - # Todo: Update the results queue size - self._results_queues = [queue.Queue(self._results_queue_size) for _ in range(self.workers_count)] + self._results_queues = [queue.Queue(self._worker_results_queue_size) for _ in range(self.workers_count)] self._shared_results_queue = queue.Queue(self._results_queue_size) self._workers = [] + self.thread_pool = self + + for worker_id in range(self.workers_count): # Create a closure that captures the worker_id for this specific worker def make_publish_func(worker_id): @@ -151,37 +156,46 @@ def make_publish_func(worker_id): # will occur. new_thread.daemon = True self._workers.append(new_thread) + # Spin up all worker threads - for w in self._workers: + for i, w in enumerate(self._workers): w.start() - # Start round-robin consumer thread self._round_robin_thread = Thread(target=self._round_robin_consumer, daemon=True) - self._round_robin_thread.start() - + if ventilator: self._ventilator = ventilator self._ventilator.start() - def _set_worker_busy(self, worker_id): - """Mark worker as busy (processing work).""" - self._worker_status[worker_id] = True - - def _set_worker_idle(self, worker_id): - """Mark worker as idle (waiting for work).""" - self._worker_status[worker_id] = False - - def _is_worker_idle(self, worker_id): - """Check if worker is idle.""" - return not self._worker_status[worker_id] - def ventilate(self, items_to_ventilate): """Sends a work item to a worker process. Will result in ``worker.process(...)`` call with arbitrary arguments. """ + for i, item in enumerate(items_to_ventilate): - self._ventilator_queues[i % self.workers_count].put(item) + worker_id = i % self.workers_count + self._ventilator_queues[worker_id].put(item) + self._items_per_worker[worker_id] += 1 self._ventilated_items += 1 + + # Start the round-robin consumer after ventilation has started + if self._round_robin_thread and not self._round_robin_thread.is_alive(): + self._round_robin_thread.start() + + + def all_workers_done(self): + for worker_id in range(self.workers_count): + if not self._results_queues[worker_id].empty() or not self._ventilator_queues[worker_id].empty() or not self._workers[worker_id].is_worker_done(): + return False + return True + + def completed(self): + # If all workers are done and shared queue is empty, raise EmptyResultError + if self.all_workers_done() and self._shared_results_queue.empty(): + if not self._ventilator or self._ventilator.completed(): + return True + return False + def get_results(self): """Returns results from worker pool or re-raise worker's exception if any happen in worker thread. 
@@ -194,31 +208,14 @@ def get_results(self): """ while True: - # Check termination condition: all workers are truly done - all_workers_done = True - for worker_id in range(self.workers_count): - worker_done = ( - self._results_queues[worker_id].empty() and # Worker result queue is empty - self._is_worker_idle(worker_id) and # Worker is idle - self._ventilator_queues[worker_id].empty() # Ventilator queue for worker is empty - ) - if not worker_done: - all_workers_done = False - break - - # If all workers are done and shared queue is empty, raise EmptyResultError - if all_workers_done and self._shared_results_queue.empty(): - if not self._ventilator or self._ventilator.completed(): - raise EmptyResultError() + # Check termination condition: all workers are truly done and shared queue is empty + if self.completed(): + raise EmptyResultError() try: + result = self._shared_results_queue.get(timeout=_VERIFY_END_OF_VENTILATION_PERIOD) - if isinstance(result, VentilatedItemProcessedMessage): - self._ventilated_items_processed += 1 - if self._ventilator: - self._ventilator.processed_item() - continue - elif isinstance(result, Exception): + if isinstance(result, Exception): self.stop() self.join() raise result @@ -229,18 +226,24 @@ def get_results(self): def stop(self): """Stops all workers (non-blocking).""" + if self._ventilator: self._ventilator.stop() self._stop_event.set() + def join(self): """Block until all workers are terminated.""" - for w in self._workers: + + for i, w in enumerate(self._workers): if w.is_alive(): w.join() + # Join the round-robin consumer thread + if self._round_robin_thread and self._round_robin_thread.is_alive(): + self._round_robin_thread.join() + if self._profiling_enabled: - # If we have profiling set, collect stats and print them stats = None for w in self._workers: if stats: @@ -248,6 +251,16 @@ def join(self): else: stats = pstats.Stats(w.prof) stats.sort_stats('cumulative').print_stats() + + if self._profiling_enabled: + stats = None + for w in self._workers: + if stats: + stats.add(w.prof) + else: + stats = pstats.Stats(w.prof) + stats.sort_stats('cumulative').print_stats() + def _stop_aware_put(self, data, worker_id): """This method is called to write the results to the results queue. 
We use ``put`` in a non-blocking way so we @@ -255,6 +268,11 @@ def _stop_aware_put(self, data, worker_id): The method raises :class:`.WorkerTerminationRequested` exception that should be passed through all the way up to :func:`WorkerThread.run` which will gracefully terminate main worker loop.""" + + # Skip control messages - they shouldn't go into the results queue + if isinstance(data, VentilatedItemProcessedMessage): + return + while True: try: self._results_queues[worker_id].put(data, block=True, timeout=IO_TIMEOUT_INTERVAL_S) @@ -268,6 +286,7 @@ def _stop_aware_put(self, data, worker_id): def _round_robin_consumer(self): """Round-robin consumer that takes items from each worker's queue in strict round-robin order and puts them into the shared results queue.""" + current_worker = 0 # Determine if we should use non-blocking behavior @@ -275,10 +294,12 @@ def _round_robin_consumer(self): while not self._stop_event.is_set(): try: + if self.all_workers_done() : + break # Check if current worker should be skipped should_skip = ( self._results_queues[current_worker].empty() and # Worker result queue is empty - self._is_worker_idle(current_worker) and # Worker is idle + self._workers[current_worker].is_worker_done() and # Worker is done self._ventilator_queues[current_worker].empty() # Ventilator queue for worker is empty ) @@ -298,14 +319,14 @@ def _round_robin_consumer(self): continue else: # Blocking: wait for item (strict round-robin) - item = self._results_queues[current_worker].get(block=True, timeout=1.0) - + item = self._results_queues[current_worker].get(block=True, timeout=5.0) # Put the item into the shared results queue - self._shared_results_queue.put(item, block=True, timeout=1.0) + if not isinstance(item, VentilatedItemProcessedMessage): + # Skip VentilatedItemProcessedMessage - it's just a control message + self._shared_results_queue.put(item, block=False) # Move to next worker in round-robin fashion current_worker = (current_worker + 1) % self.workers_count - except queue.Empty: # No item available from current worker, move to next current_worker = (current_worker + 1) % self.workers_count @@ -313,14 +334,14 @@ def _round_robin_consumer(self): except queue.Full: # Shared queue is full, wait a bit and try again continue - except Exception: + except Exception as e: # Any other exception, continue to next worker current_worker = (current_worker + 1) % self.workers_count continue - + def results_qsize(self): return self._shared_results_queue.qsize() @property def diagnostics(self): - return {'output_queue_size': self.results_qsize()} + return {'output_queue_size': self.results_qsize()} \ No newline at end of file diff --git a/petastorm/workers_pool/ventilator.py b/petastorm/workers_pool/ventilator.py index dfa4c21f4..b3752904c 100644 --- a/petastorm/workers_pool/ventilator.py +++ b/petastorm/workers_pool/ventilator.py @@ -16,6 +16,7 @@ import threading from abc import ABCMeta, abstractmethod from time import sleep +import sys import six @@ -128,26 +129,25 @@ def reset(self): self.start() def _ventilate(self): + if self._randomize_item_order: + if self._random_seed is not None and self._random_seed != 0: + # Deterministic randomization: use provided seed + self._items_to_ventilate = list(self._rng.permutation(self._items_to_ventilate)) + else: + # Non-deterministic randomization: use np.random + self._items_to_ventilate = list(np.random.permutation(self._items_to_ventilate)) + while True: # Stop condition is when no iterations are remaining or there are no items to ventilate if 
self.completed(): break - if self._randomize_item_order: - if self._random_seed is not None and self._random_seed != 0: - # Deterministic randomization: use provided seed - items_to_ventilate = self._rng.permutation(self._items_to_ventilate) - else: - # Non-deterministic randomization: use np.random - items_to_ventilate = np.random.permutation(self._items_to_ventilate) - else: - # Deterministic natural order: randomize_item_order=False - items_to_ventilate = self._items_to_ventilate.copy() - - self._ventilate_fn(items_to_ventilate) + self._ventilate_fn(self._items_to_ventilate) if self._iterations_remaining is not None: self._iterations_remaining -= 1 + elif self._iterations_remaining is None: + self._iterations_remaining = 0 def stop(self): self._stop_requested = True From ff076ec472b80a4b950d47abdc6e3fd0694a5741 Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Fri, 18 Jul 2025 18:01:17 +0000 Subject: [PATCH 03/25] Implement the alternate design to fix race condition for some tests --- petastorm/arrow_reader_worker.py | 2 +- petastorm/reader.py | 20 ++- petastorm/workers_pool/thread_pool.py | 209 +++++++------------------- petastorm/workers_pool/ventilator.py | 40 +++-- 4 files changed, 101 insertions(+), 170 deletions(-) diff --git a/petastorm/arrow_reader_worker.py b/petastorm/arrow_reader_worker.py index 4a38a3e4c..ad6d5dd39 100644 --- a/petastorm/arrow_reader_worker.py +++ b/petastorm/arrow_reader_worker.py @@ -292,7 +292,7 @@ def _read_with_shuffle_row_drop(self, piece, pq_file, column_names, shuffle_row_ # pyarrow would fail if we request a column names that the dataset is partitioned by table = piece.read(columns=column_names - partition_names, partitions=self._dataset.partitions) - + # Handle row shuffling based on shuffle_rows setting if self._shuffle_rows: if self._random_seed is not None and self._random_seed != 0: diff --git a/petastorm/reader.py b/petastorm/reader.py index b6c459d7d..a2fcc57a9 100644 --- a/petastorm/reader.py +++ b/petastorm/reader.py @@ -14,13 +14,11 @@ import collections.abc import logging -import sys import warnings import six from pyarrow import parquet as pq -# Import ArrowReaderWorker from local modified file instead of installed package from arrow_reader_worker import ArrowReaderWorker from petastorm.cache import NullCache from petastorm.errors import NoDataAvailableError @@ -61,7 +59,7 @@ def normalize_dataset_url_or_urls(dataset_url_or_urls): def make_reader(dataset_url, schema_fields=None, - reader_pool_type='thread', workers_count=10, pyarrow_serialize=False, results_queue_size=25, + reader_pool_type='thread', workers_count=10, pyarrow_serialize=False, results_queue_size=50, seed=None, shuffle_rows=False, shuffle_row_groups=True, shuffle_row_drop_partitions=1, predicate=None, @@ -161,7 +159,7 @@ def make_reader(dataset_url, 'To read from a non-Petastorm Parquet store use make_batch_reader') if reader_pool_type == 'thread': - reader_pool = ThreadPool(workers_count, shuffle_rows=shuffle_rows, seed=seed) + reader_pool = ThreadPool(workers_count, results_queue_size, shuffle_rows, seed) elif reader_pool_type == 'process': if pyarrow_serialize: warnings.warn("pyarrow_serializer was deprecated and will be removed in future versions. 
" @@ -207,7 +205,7 @@ def make_reader(dataset_url, def make_batch_reader(dataset_url_or_urls, schema_fields=None, reader_pool_type='thread', workers_count=10, - results_queue_size=25, + results_queue_size=50, seed=None, shuffle_rows=False, shuffle_row_groups=True, shuffle_row_drop_partitions=1, predicate=None, @@ -246,6 +244,8 @@ def make_batch_reader(dataset_url_or_urls, denoting a thread pool, process pool, or running everything in the master thread. Defaults to 'thread' :param workers_count: An int for the number of workers to use in the reader pool. This only is used for the thread or process pool. Defaults to 10 + :param results_queue_size: Size of the results queue to store prefetched row-groups. Currently only applicable to + thread reader pool type. :param seed: Random seed specified for shuffle and sharding with reproducible outputs. Defaults to None :param shuffle_rows: Whether to shuffle inside a single row group. Defaults to False. :param shuffle_row_groups: Whether to shuffle row groups (the order in which full row groups are read) @@ -313,9 +313,9 @@ def make_batch_reader(dataset_url_or_urls, **cache_extra_settings or {}) else: raise ValueError('Unknown cache_type: {}'.format(cache_type)) - + if reader_pool_type == 'thread': - reader_pool = ThreadPool(workers_count, shuffle_rows=shuffle_rows, seed=seed) + reader_pool = ThreadPool(workers_count, results_queue_size, shuffle_rows, seed) elif reader_pool_type == 'process': serializer = ArrowTableSerializer() reader_pool = ProcessPool(workers_count, serializer, zmq_copy_buffers=zmq_copy_buffers) @@ -477,11 +477,13 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None, self.num_epochs, worker_predicate, self._workers_pool.workers_count + _VENTILATE_EXTRA_ROWGROUPS, seed) + # 5. Start workers pool self._workers_pool.start(worker_class, (pyarrow_filesystem, dataset_path, storage_schema, self.ngram, row_groups, cache, transform_spec, self.schema, filters, shuffle_rows, seed), ventilator=self.ventilator) + logger.debug('Workers pool started') self.last_row_consumed = False self.stopped = False @@ -659,9 +661,11 @@ def _create_ventilator(self, row_group_indexes, shuffle_row_groups, shuffle_row_ 'worker_predicate': worker_predicate, 'shuffle_row_drop_partition': (shuffle_row_drop_partition, shuffle_row_drop_partitions)}) + return ConcurrentVentilator(self._workers_pool.ventilate, items_to_ventilate, iterations=num_epochs, + max_ventilation_queue_size=max_ventilation_queue_size, randomize_item_order=shuffle_row_groups, random_seed=seed) @@ -702,4 +706,4 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.stop() - self.join() + self.join() \ No newline at end of file diff --git a/petastorm/workers_pool/thread_pool.py b/petastorm/workers_pool/thread_pool.py index 2f104fffb..e34b91968 100644 --- a/petastorm/workers_pool/thread_pool.py +++ b/petastorm/workers_pool/thread_pool.py @@ -13,11 +13,9 @@ # limitations under the License. import cProfile -import logging import pstats import random import sys -import os from threading import Thread, Event from traceback import format_exc @@ -29,7 +27,7 @@ IO_TIMEOUT_INTERVAL_S = 0.001 # Amount of time we will wait on a the queue to get the next result. 
If no results received until then, we will # recheck if no more items are expected to be ventilated -_VERIFY_END_OF_VENTILATION_PERIOD = 0.1 +_VERIFY_END_OF_VENTILATION_PERIOD = 1 class WorkerTerminationRequested(Exception): @@ -47,7 +45,6 @@ def __init__(self, worker_impl, stop_event, ventilator_queue, results_queue, pro self._ventilator_queue = ventilator_queue self._results_queue = results_queue self._profiling_enabled = profiling_enabled - self._items_processed = 0 if profiling_enabled: self.prof = cProfile.Profile() @@ -61,10 +58,9 @@ def run(self): break # If the message came from work_receiver channel try: - item = self._ventilator_queue.get(block=True, timeout=IO_TIMEOUT_INTERVAL_S) - self._worker_impl.process(**item) + (args, kargs) = self._ventilator_queue.get(block=True, timeout=IO_TIMEOUT_INTERVAL_S) + self._worker_impl.process(*args, **kargs) self._worker_impl.publish_func(VentilatedItemProcessedMessage()) - self._items_processed += 1 # Only increment for actual data items except queue.Empty: pass except WorkerTerminationRequested: @@ -77,13 +73,10 @@ def run(self): break if self._profiling_enabled: self.prof.disable() - - def is_worker_done(self): - return self._items_processed == self._worker_impl.thread_pool._items_per_worker[self._worker_impl.worker_id] class ThreadPool(object): - def __init__(self, workers_count, results_queue_size=25, worker_results_queue_size=5, profiling_enabled=False, shuffle_rows=False, seed=None): + def __init__(self, workers_count, results_queue_size=50, shuffle_rows=False, seed=None, profiling_enabled=False): """Initializes a thread pool. TODO: consider using a standard thread pool @@ -93,31 +86,28 @@ def __init__(self, workers_count, results_queue_size=25, worker_results_queue_si implementation that would not use fork) :param workers_count: Number of threads - :param profiling_enabled: Whether to run a profiler on the threads - :param shuffle_rows: Whether to shuffle rows (affects round-robin behavior) - :param seed: Random seed for deterministic behavior + :param profile: Whether to run a profiler on the threads """ - + self._seed = random.randint(0, 100000) + self._shuffle_rows = shuffle_rows + self._seed = seed self._workers = [] - self._ventilator_queues = None + self._ventilator_queues = [] + self.workers_count = workers_count self._results_queue_size = results_queue_size - self._worker_results_queue_size = worker_results_queue_size # Worker threads will watch this event and gracefully shutdown when the event is set self._stop_event = Event() self._profiling_enabled = profiling_enabled - self._shuffle_rows = shuffle_rows - self._seed = seed self._ventilated_items = 0 - self._ventilated_items_processed = 0 + self._ventilated_items_by_worker = [0 for _ in range(self.workers_count)] + # self._ventilated_items_processed = 0 + self._ventilated_items_processed_by_worker = [0 for _ in range(self.workers_count)] self._ventilator = None - - # Round-robin consumer thread - self._round_robin_thread = None - self._items_per_worker = [0] * workers_count - - + + self._get_results_worker_id = 0 + def start(self, worker_class, worker_args=None, ventilator=None): """Starts worker threads. @@ -129,73 +119,52 @@ class must implement :class:`.WorkerBase` protocol """ # Verify stop_event and raise exception if it's already set! if self._stop_event.is_set(): - error_msg = f'ThreadPool({len(self._workers)}) cannot be reused! stop_event set? {self._stop_event.is_set()}' - raise RuntimeError(error_msg) + raise RuntimeError('ThreadPool({}) cannot be reused! 
stop_event set? {}' + .format(len(self._workers), self._stop_event.is_set())) - # Set up a channel to send work self._ventilator_queues = [queue.Queue() for _ in range(self.workers_count)] - self._results_queues = [queue.Queue(self._worker_results_queue_size) for _ in range(self.workers_count)] - self._shared_results_queue = queue.Queue(self._results_queue_size) - self._workers = [] - self.thread_pool = self - + + self._results_queues = [queue.Queue(self._results_queue_size / self.workers_count) for _ in range(self.workers_count)] + self._workers = [] for worker_id in range(self.workers_count): # Create a closure that captures the worker_id for this specific worker def make_publish_func(worker_id): - return lambda data: self._stop_aware_put(data, worker_id) + return lambda data: self._stop_aware_put(worker_id, data) worker_impl = worker_class(worker_id, make_publish_func(worker_id), worker_args) - # Add thread_pool reference to worker for status tracking - worker_impl.thread_pool = self - new_thread = WorkerThread(worker_impl, self._stop_event, self._ventilator_queues[worker_id], self._results_queues[worker_id], self._profiling_enabled) # Make the thread daemonic. Since it only reads it's ok to abort while running - no resource corruption # will occur. new_thread.daemon = True self._workers.append(new_thread) - # Spin up all worker threads - for i, w in enumerate(self._workers): + for w in self._workers: w.start() - self._round_robin_thread = Thread(target=self._round_robin_consumer, daemon=True) - if ventilator: self._ventilator = ventilator self._ventilator.start() - def ventilate(self, items_to_ventilate): + def ventilate(self, *args, **kargs): """Sends a work item to a worker process. Will result in ``worker.process(...)`` call with arbitrary arguments. """ + current_worker_id = self._ventilated_items % self.workers_count + self._ventilated_items += 1 + self._ventilated_items_by_worker[current_worker_id] += 1 + self._ventilator_queues[current_worker_id].put((args, kargs)) - for i, item in enumerate(items_to_ventilate): - worker_id = i % self.workers_count - self._ventilator_queues[worker_id].put(item) - self._items_per_worker[worker_id] += 1 - self._ventilated_items += 1 - - # Start the round-robin consumer after ventilation has started - if self._round_robin_thread and not self._round_robin_thread.is_alive(): - self._round_robin_thread.start() - + def current_worker_done(self, worker_id): + return self._ventilated_items_processed_by_worker[worker_id] == self._ventilated_items_by_worker[worker_id] and self._results_queues[worker_id].empty() def all_workers_done(self): - for worker_id in range(self.workers_count): - if not self._results_queues[worker_id].empty() or not self._ventilator_queues[worker_id].empty() or not self._workers[worker_id].is_worker_done(): + for i in range(self.workers_count): + if not self.current_worker_done(i): return False return True - - def completed(self): - # If all workers are done and shared queue is empty, raise EmptyResultError - if self.all_workers_done() and self._shared_results_queue.empty(): - if not self._ventilator or self._ventilator.completed(): - return True - return False - def get_results(self): """Returns results from worker pool or re-raise worker's exception if any happen in worker thread. @@ -206,16 +175,28 @@ def get_results(self): :return: arguments passed to ``publish_func(...)`` by a worker. If no more results are anticipated, :class:`.EmptyResultError`. 
""" - - while True: - # Check termination condition: all workers are truly done and shared queue is empty - if self.completed(): - raise EmptyResultError() + use_non_blocking_get = self._shuffle_rows and (self._seed is None or self._seed == 0) + while True: + # If there is no more work to do, raise an EmptyResultError + if self.all_workers_done(): + # We also need to check if we are using a ventilator and if it is completed + if not self._ventilator or self._ventilator.completed(): + raise EmptyResultError() + + # If the current worker is done, we need to get the result from the next worker + if self.current_worker_done(self._get_results_worker_id): + self._get_results_worker_id = (self._get_results_worker_id + 1) % self.workers_count + continue try: - - result = self._shared_results_queue.get(timeout=_VERIFY_END_OF_VENTILATION_PERIOD) - if isinstance(result, Exception): + result = self._results_queues[self._get_results_worker_id].get(block=not use_non_blocking_get, timeout=_VERIFY_END_OF_VENTILATION_PERIOD) + if isinstance(result, VentilatedItemProcessedMessage): + self._ventilated_items_processed_by_worker[self._get_results_worker_id] += 1 + if self._ventilator: + self._ventilator.processed_item() + self._get_results_worker_id = (self._get_results_worker_id + 1) % self.workers_count + continue + elif isinstance(result, Exception): self.stop() self.join() raise result @@ -224,35 +205,21 @@ def get_results(self): except queue.Empty: continue + def stop(self): """Stops all workers (non-blocking).""" - if self._ventilator: self._ventilator.stop() self._stop_event.set() - def join(self): """Block until all workers are terminated.""" - - for i, w in enumerate(self._workers): + for w in self._workers: if w.is_alive(): w.join() - # Join the round-robin consumer thread - if self._round_robin_thread and self._round_robin_thread.is_alive(): - self._round_robin_thread.join() - - if self._profiling_enabled: - stats = None - for w in self._workers: - if stats: - stats.add(w.prof) - else: - stats = pstats.Stats(w.prof) - stats.sort_stats('cumulative').print_stats() - if self._profiling_enabled: + # If we have profiling set, collect stats and print them stats = None for w in self._workers: if stats: @@ -261,18 +228,12 @@ def join(self): stats = pstats.Stats(w.prof) stats.sort_stats('cumulative').print_stats() - - def _stop_aware_put(self, data, worker_id): + def _stop_aware_put(self, worker_id, data): """This method is called to write the results to the results queue. We use ``put`` in a non-blocking way so we can gracefully terminate the worker thread without being stuck on :func:`Queue.put`. 
The method raises :class:`.WorkerTerminationRequested` exception that should be passed through all the way up to :func:`WorkerThread.run` which will gracefully terminate main worker loop.""" - - # Skip control messages - they shouldn't go into the results queue - if isinstance(data, VentilatedItemProcessedMessage): - return - while True: try: self._results_queues[worker_id].put(data, block=True, timeout=IO_TIMEOUT_INTERVAL_S) @@ -283,64 +244,8 @@ def _stop_aware_put(self, data, worker_id): if self._stop_event.is_set(): raise WorkerTerminationRequested() - def _round_robin_consumer(self): - """Round-robin consumer that takes items from each worker's queue in strict round-robin order - and puts them into the shared results queue.""" - - current_worker = 0 - - # Determine if we should use non-blocking behavior - use_non_blocking = self._shuffle_rows and (self._seed is None or self._seed == 0) - - while not self._stop_event.is_set(): - try: - if self.all_workers_done() : - break - # Check if current worker should be skipped - should_skip = ( - self._results_queues[current_worker].empty() and # Worker result queue is empty - self._workers[current_worker].is_worker_done() and # Worker is done - self._ventilator_queues[current_worker].empty() # Ventilator queue for worker is empty - ) - - if should_skip: - # Skip this worker and move to next - current_worker = (current_worker + 1) % self.workers_count - continue - - # Try to get an item from the current worker's queue - if use_non_blocking: - # Non-blocking: try to get item without waiting - try: - item = self._results_queues[current_worker].get(block=False) - except queue.Empty: - # No item available, move to next worker immediately - current_worker = (current_worker + 1) % self.workers_count - continue - else: - # Blocking: wait for item (strict round-robin) - item = self._results_queues[current_worker].get(block=True, timeout=5.0) - # Put the item into the shared results queue - if not isinstance(item, VentilatedItemProcessedMessage): - # Skip VentilatedItemProcessedMessage - it's just a control message - self._shared_results_queue.put(item, block=False) - - # Move to next worker in round-robin fashion - current_worker = (current_worker + 1) % self.workers_count - except queue.Empty: - # No item available from current worker, move to next - current_worker = (current_worker + 1) % self.workers_count - continue - except queue.Full: - # Shared queue is full, wait a bit and try again - continue - except Exception as e: - # Any other exception, continue to next worker - current_worker = (current_worker + 1) % self.workers_count - continue - def results_qsize(self): - return self._shared_results_queue.qsize() + return self._results_queues[0].qsize() @property def diagnostics(self): diff --git a/petastorm/workers_pool/ventilator.py b/petastorm/workers_pool/ventilator.py index b3752904c..f40e4e448 100644 --- a/petastorm/workers_pool/ventilator.py +++ b/petastorm/workers_pool/ventilator.py @@ -16,7 +16,6 @@ import threading from abc import ABCMeta, abstractmethod from time import sleep -import sys import six @@ -67,7 +66,9 @@ def __init__(self, items_to_ventilate, iterations=1, randomize_item_order=False, - random_seed=None): + random_seed=None, + max_ventilation_queue_size=None, + ventilation_interval=_VENTILATION_INTERVAL): """ Constructor for a concurrent ventilator. 
@@ -99,9 +100,17 @@ def __init__(self, self._randomize_item_order = randomize_item_order self._random_seed = random_seed self._rng = np.random.default_rng(self._random_seed) + # self._random_state = np.random.RandomState(seed=random_seed) self._iterations = iterations + # For the default max ventilation queue size we will use the size of the items to ventilate + self._max_ventilation_queue_size = max_ventilation_queue_size or len(items_to_ventilate) + self._ventilation_interval = ventilation_interval + + self._current_item_to_ventilate = 0 self._ventilation_thread = None + self._ventilated_items_count = 0 + self._processed_items_count = 0 self._stop_requested = False def start(self): @@ -111,7 +120,7 @@ def start(self): self._ventilation_thread.start() def processed_item(self): - pass + self._processed_items_count += 1 def completed(self): assert self._iterations_remaining is None or self._iterations_remaining >= 0 @@ -142,15 +151,28 @@ def _ventilate(self): if self.completed(): break - self._ventilate_fn(self._items_to_ventilate) + # If we are ventilating the first item, we check if we would like to randomize the item order + # if self._current_item_to_ventilate == 0: + # self._random_state.shuffle(self._items_to_ventilate) + + # Block until queue has room, but use continue to allow for checking if stop has been called + if self._ventilated_items_count - self._processed_items_count >= self._max_ventilation_queue_size: + sleep(self._ventilation_interval) + continue + + item_to_ventilate = self._items_to_ventilate[self._current_item_to_ventilate] + self._ventilate_fn(**item_to_ventilate) + self._current_item_to_ventilate += 1 + self._ventilated_items_count += 1 - if self._iterations_remaining is not None: - self._iterations_remaining -= 1 - elif self._iterations_remaining is None: - self._iterations_remaining = 0 + if self._current_item_to_ventilate >= len(self._items_to_ventilate): + self._current_item_to_ventilate = 0 + # If iterations was set to None, that means we will iterate until stop is called + if self._iterations_remaining is not None: + self._iterations_remaining -= 1 def stop(self): self._stop_requested = True if self._ventilation_thread: self._ventilation_thread.join() - self._ventilation_thread = None + self._ventilation_thread = None \ No newline at end of file From 555548b7cc3881a3f2a78b9bff6e5effd21dca76 Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Fri, 18 Jul 2025 20:31:44 +0000 Subject: [PATCH 04/25] do code cleanup --- petastorm/reader.py | 4 ++-- petastorm/workers_pool/thread_pool.py | 17 +++++++++++++---- petastorm/workers_pool/ventilator.py | 5 +---- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/petastorm/reader.py b/petastorm/reader.py index a2fcc57a9..00b4e0210 100644 --- a/petastorm/reader.py +++ b/petastorm/reader.py @@ -159,7 +159,7 @@ def make_reader(dataset_url, 'To read from a non-Petastorm Parquet store use make_batch_reader') if reader_pool_type == 'thread': - reader_pool = ThreadPool(workers_count, results_queue_size, shuffle_rows, seed) + reader_pool = ThreadPool(workers_count, results_queue_size, shuffle_rows=shuffle_rows, seed=seed) elif reader_pool_type == 'process': if pyarrow_serialize: warnings.warn("pyarrow_serializer was deprecated and will be removed in future versions. 
" @@ -315,7 +315,7 @@ def make_batch_reader(dataset_url_or_urls, raise ValueError('Unknown cache_type: {}'.format(cache_type)) if reader_pool_type == 'thread': - reader_pool = ThreadPool(workers_count, results_queue_size, shuffle_rows, seed) + reader_pool = ThreadPool(workers_count, results_queue_size, shuffle_rows=shuffle_rows, seed=seed) elif reader_pool_type == 'process': serializer = ArrowTableSerializer() reader_pool = ProcessPool(workers_count, serializer, zmq_copy_buffers=zmq_copy_buffers) diff --git a/petastorm/workers_pool/thread_pool.py b/petastorm/workers_pool/thread_pool.py index e34b91968..0508eb20b 100644 --- a/petastorm/workers_pool/thread_pool.py +++ b/petastorm/workers_pool/thread_pool.py @@ -101,8 +101,9 @@ def __init__(self, workers_count, results_queue_size=50, shuffle_rows=False, see self._profiling_enabled = profiling_enabled self._ventilated_items = 0 + # Count of items ventilated by each worker self._ventilated_items_by_worker = [0 for _ in range(self.workers_count)] - # self._ventilated_items_processed = 0 + # Count of items processed by each worker self._ventilated_items_processed_by_worker = [0 for _ in range(self.workers_count)] self._ventilator = None @@ -122,10 +123,11 @@ class must implement :class:`.WorkerBase` protocol raise RuntimeError('ThreadPool({}) cannot be reused! stop_event set? {}' .format(len(self._workers), self._stop_event.is_set())) - # Set up a channel to send work + # Set up a channel for each worker to send work self._ventilator_queues = [queue.Queue() for _ in range(self.workers_count)] - self._results_queues = [queue.Queue(self._results_queue_size / self.workers_count) for _ in range(self.workers_count)] + # Set up a channel for each worker to send results + self._results_queues = max(5, [queue.Queue(self._results_queue_size / self.workers_count) for _ in range(self.workers_count)]) self._workers = [] for worker_id in range(self.workers_count): @@ -152,15 +154,18 @@ def make_publish_func(worker_id): def ventilate(self, *args, **kargs): """Sends a work item to a worker process. Will result in ``worker.process(...)`` call with arbitrary arguments. """ + # Distribute work items in a round-robin manner across each worker ventilator queue current_worker_id = self._ventilated_items % self.workers_count self._ventilated_items += 1 self._ventilated_items_by_worker[current_worker_id] += 1 self._ventilator_queues[current_worker_id].put((args, kargs)) def current_worker_done(self, worker_id): + # Check if the current worker has processed all the items it was assigned and if the results queue is empty return self._ventilated_items_processed_by_worker[worker_id] == self._ventilated_items_by_worker[worker_id] and self._results_queues[worker_id].empty() def all_workers_done(self): + # Check if all workers have processed all the items they were assigned and if the results queues are empty for i in range(self.workers_count): if not self.current_worker_done(i): return False @@ -175,6 +180,7 @@ def get_results(self): :return: arguments passed to ``publish_func(...)`` by a worker. If no more results are anticipated, :class:`.EmptyResultError`. """ + # If shuffle_rows is enabled and the seed is not set, we need to use a non-blocking as we don't care about the strict round robin order use_non_blocking_get = self._shuffle_rows and (self._seed is None or self._seed == 0) while True: # If there is no more work to do, raise an EmptyResultError @@ -189,11 +195,14 @@ def get_results(self): continue try: + # Get the result from the current worker's results queue. 
Use blocking/strict round robin if shuffle_rows is disabled or the seed is set result = self._results_queues[self._get_results_worker_id].get(block=not use_non_blocking_get, timeout=_VERIFY_END_OF_VENTILATION_PERIOD) + # If the result is a VentilatedItemProcessedMessage, we need to increment the count of items processed by the current worker if isinstance(result, VentilatedItemProcessedMessage): self._ventilated_items_processed_by_worker[self._get_results_worker_id] += 1 if self._ventilator: self._ventilator.processed_item() + # Move to the next worker self._get_results_worker_id = (self._get_results_worker_id + 1) % self.workers_count continue elif isinstance(result, Exception): @@ -245,7 +254,7 @@ def _stop_aware_put(self, worker_id, data): raise WorkerTerminationRequested() def results_qsize(self): - return self._results_queues[0].qsize() + return sum(queue.qsize() for queue in self._results_queues) @property def diagnostics(self): diff --git a/petastorm/workers_pool/ventilator.py b/petastorm/workers_pool/ventilator.py index f40e4e448..a75c2efb0 100644 --- a/petastorm/workers_pool/ventilator.py +++ b/petastorm/workers_pool/ventilator.py @@ -138,6 +138,7 @@ def reset(self): self.start() def _ventilate(self): + # Randomize the item order before starting the ventilation if randomize_item_order is set if self._randomize_item_order: if self._random_seed is not None and self._random_seed != 0: # Deterministic randomization: use provided seed @@ -151,10 +152,6 @@ def _ventilate(self): if self.completed(): break - # If we are ventilating the first item, we check if we would like to randomize the item order - # if self._current_item_to_ventilate == 0: - # self._random_state.shuffle(self._items_to_ventilate) - # Block until queue has room, but use continue to allow for checking if stop has been called if self._ventilated_items_count - self._processed_items_count >= self._max_ventilation_queue_size: sleep(self._ventilation_interval) From e3bccb6e6848662e433d7d543ccd386e75046e8b Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Fri, 18 Jul 2025 23:04:12 +0000 Subject: [PATCH 05/25] Restore imports --- petastorm/workers_pool/ventilator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/petastorm/workers_pool/ventilator.py b/petastorm/workers_pool/ventilator.py index a75c2efb0..14ae85235 100644 --- a/petastorm/workers_pool/ventilator.py +++ b/petastorm/workers_pool/ventilator.py @@ -100,7 +100,6 @@ def __init__(self, self._randomize_item_order = randomize_item_order self._random_seed = random_seed self._rng = np.random.default_rng(self._random_seed) - # self._random_state = np.random.RandomState(seed=random_seed) self._iterations = iterations # For the default max ventilation queue size we will use the size of the items to ventilate From a96fa326507c11f91d98823c7b4e4e440b4abdb5 Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Fri, 18 Jul 2025 23:05:08 +0000 Subject: [PATCH 06/25] Restore imports --- petastorm/reader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/petastorm/reader.py b/petastorm/reader.py index 00b4e0210..c5acf1975 100644 --- a/petastorm/reader.py +++ b/petastorm/reader.py @@ -19,7 +19,7 @@ import six from pyarrow import parquet as pq -from arrow_reader_worker import ArrowReaderWorker +from petastorm.arrow_reader_worker import ArrowReaderWorker from petastorm.cache import NullCache from petastorm.errors import NoDataAvailableError from petastorm.etl import dataset_metadata, rowgroup_indexing @@ -35,8 +35,8 @@ from petastorm.transform import 
transform_schema from petastorm.workers_pool.dummy_pool import DummyPool from petastorm.workers_pool.process_pool import ProcessPool -from thread_pool import ThreadPool -from ventilator import ConcurrentVentilator +from petastorm.workers_pool.thread_pool import ThreadPool +from petastorm.workers_pool.ventilator import ConcurrentVentilator logger = logging.getLogger(__name__) From 6708b9ddb49b3e6cce9ad3f64d303a640cf10767 Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Wed, 23 Jul 2025 20:43:36 +0000 Subject: [PATCH 07/25] fix queue size --- petastorm/workers_pool/thread_pool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/petastorm/workers_pool/thread_pool.py b/petastorm/workers_pool/thread_pool.py index 0508eb20b..5c51be77c 100644 --- a/petastorm/workers_pool/thread_pool.py +++ b/petastorm/workers_pool/thread_pool.py @@ -127,7 +127,7 @@ class must implement :class:`.WorkerBase` protocol self._ventilator_queues = [queue.Queue() for _ in range(self.workers_count)] # Set up a channel for each worker to send results - self._results_queues = max(5, [queue.Queue(self._results_queue_size / self.workers_count) for _ in range(self.workers_count)]) + self._results_queues = [queue.Queue(max(5, self._results_queue_size // self.workers_count)) for _ in range(self.workers_count)] self._workers = [] for worker_id in range(self.workers_count): From 0ff55cdfbb7ebddc05934d6cc232cd1e12d60eb0 Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Thu, 24 Jul 2025 21:38:01 +0000 Subject: [PATCH 08/25] Change petastorm release version format to fix failure due to setuptools > 80.0 --- petastorm/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/petastorm/__init__.py b/petastorm/__init__.py index 8b210fd09..8689fbc78 100644 --- a/petastorm/__init__.py +++ b/petastorm/__init__.py @@ -16,4 +16,4 @@ from petastorm.reader import make_reader, make_batch_reader # noqa: F401 from petastorm.transform import TransformSpec # noqa: F401 -__version__ = '0.12.2rc0' +__version__ = '0.12.2.rc0' From 80e6d822273395dec9d33bef4a4185a0619c9624 Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Thu, 24 Jul 2025 21:46:54 +0000 Subject: [PATCH 09/25] Add constraint on setuptools version to prevent issue with new versions --- petastorm/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/petastorm/__init__.py b/petastorm/__init__.py index 8689fbc78..8b210fd09 100644 --- a/petastorm/__init__.py +++ b/petastorm/__init__.py @@ -16,4 +16,4 @@ from petastorm.reader import make_reader, make_batch_reader # noqa: F401 from petastorm.transform import TransformSpec # noqa: F401 -__version__ = '0.12.2.rc0' +__version__ = '0.12.2rc0' From 7ab26c477cef6d6de07e02fad62824e1005d8bf4 Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Thu, 24 Jul 2025 22:59:02 +0000 Subject: [PATCH 10/25] Fix lint issues --- petastorm/arrow_reader_worker.py | 2 +- petastorm/reader.py | 2 +- petastorm/workers_pool/thread_pool.py | 29 ++++++++++++++++----------- petastorm/workers_pool/ventilator.py | 4 ++-- 4 files changed, 21 insertions(+), 16 deletions(-) diff --git a/petastorm/arrow_reader_worker.py b/petastorm/arrow_reader_worker.py index ad6d5dd39..7959d9a09 100644 --- a/petastorm/arrow_reader_worker.py +++ b/petastorm/arrow_reader_worker.py @@ -102,7 +102,7 @@ def __init__(self, worker_id, publish_func, args): self._arrow_filters = args[8] self._shuffle_rows = args[9] self._random_seed = args[10] - + # Initialize random number generator self._rng = np.random.default_rng(self._random_seed) 
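The worker keeps both the raw seed and a generator built from it because the two shuffle paths behave differently: a non-None, non-zero seed yields a reproducible permutation, while an unseeded generator (or plain np.random) gives a different order on every call. A small standalone illustration of that numpy behaviour (not petastorm code):

import numpy as np

num_rows = 10

# Same non-zero seed -> identical permutation every time (reproducible shuffles).
assert (np.random.default_rng(42).permutation(num_rows)
        == np.random.default_rng(42).permutation(num_rows)).all()

# seed=None -> a freshly seeded generator, so the order differs between calls.
first = np.random.default_rng(None).permutation(num_rows)
second = np.random.default_rng(None).permutation(num_rows)
print((first == second).all())  # almost certainly False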
diff --git a/petastorm/reader.py b/petastorm/reader.py index c5acf1975..1ff4231f0 100644 --- a/petastorm/reader.py +++ b/petastorm/reader.py @@ -706,4 +706,4 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.stop() - self.join() \ No newline at end of file + self.join() diff --git a/petastorm/workers_pool/thread_pool.py b/petastorm/workers_pool/thread_pool.py index 5c51be77c..6d13b6606 100644 --- a/petastorm/workers_pool/thread_pool.py +++ b/petastorm/workers_pool/thread_pool.py @@ -93,7 +93,7 @@ def __init__(self, workers_count, results_queue_size=50, shuffle_rows=False, see self._seed = seed self._workers = [] self._ventilator_queues = [] - + self.workers_count = workers_count self._results_queue_size = results_queue_size # Worker threads will watch this event and gracefully shutdown when the event is set @@ -127,14 +127,15 @@ class must implement :class:`.WorkerBase` protocol self._ventilator_queues = [queue.Queue() for _ in range(self.workers_count)] # Set up a channel for each worker to send results - self._results_queues = [queue.Queue(max(5, self._results_queue_size // self.workers_count)) for _ in range(self.workers_count)] - + self._results_queues = [queue.Queue(max(5, self._results_queue_size // self.workers_count)) + for _ in range(self.workers_count)] + self._workers = [] for worker_id in range(self.workers_count): # Create a closure that captures the worker_id for this specific worker def make_publish_func(worker_id): return lambda data: self._stop_aware_put(worker_id, data) - + worker_impl = worker_class(worker_id, make_publish_func(worker_id), worker_args) new_thread = WorkerThread(worker_impl, self._stop_event, self._ventilator_queues[worker_id], self._results_queues[worker_id], self._profiling_enabled) @@ -162,7 +163,8 @@ def ventilate(self, *args, **kargs): def current_worker_done(self, worker_id): # Check if the current worker has processed all the items it was assigned and if the results queue is empty - return self._ventilated_items_processed_by_worker[worker_id] == self._ventilated_items_by_worker[worker_id] and self._results_queues[worker_id].empty() + return (self._ventilated_items_processed_by_worker[worker_id] == self._ventilated_items_by_worker[worker_id] + and self._results_queues[worker_id].empty()) def all_workers_done(self): # Check if all workers have processed all the items they were assigned and if the results queues are empty @@ -180,9 +182,10 @@ def get_results(self): :return: arguments passed to ``publish_func(...)`` by a worker. If no more results are anticipated, :class:`.EmptyResultError`. """ - # If shuffle_rows is enabled and the seed is not set, we need to use a non-blocking as we don't care about the strict round robin order + # If shuffle_rows is enabled and the seed is not set, we need to use a non-blocking + # as we don't care about the strict round robin order use_non_blocking_get = self._shuffle_rows and (self._seed is None or self._seed == 0) - while True: + while True: # If there is no more work to do, raise an EmptyResultError if self.all_workers_done(): # We also need to check if we are using a ventilator and if it is completed @@ -195,9 +198,12 @@ def get_results(self): continue try: - # Get the result from the current worker's results queue. 
Use blocking/strict round robin if shuffle_rows is disabled or the seed is set - result = self._results_queues[self._get_results_worker_id].get(block=not use_non_blocking_get, timeout=_VERIFY_END_OF_VENTILATION_PERIOD) - # If the result is a VentilatedItemProcessedMessage, we need to increment the count of items processed by the current worker + # Get the result from the current worker's results queue. + # Use blocking/strict round robin if shuffle_rows is disabled or the seed is set + result = self._results_queues[self._get_results_worker_id].get( + block=not use_non_blocking_get, timeout=_VERIFY_END_OF_VENTILATION_PERIOD) + # If the result is a VentilatedItemProcessedMessage, we need to increment the count of items + # processed by the current worker if isinstance(result, VentilatedItemProcessedMessage): self._ventilated_items_processed_by_worker[self._get_results_worker_id] += 1 if self._ventilator: @@ -214,7 +220,6 @@ def get_results(self): except queue.Empty: continue - def stop(self): """Stops all workers (non-blocking).""" if self._ventilator: @@ -258,4 +263,4 @@ def results_qsize(self): @property def diagnostics(self): - return {'output_queue_size': self.results_qsize()} \ No newline at end of file + return {'output_queue_size': self.results_qsize()} diff --git a/petastorm/workers_pool/ventilator.py b/petastorm/workers_pool/ventilator.py index 14ae85235..9f6bfb4f5 100644 --- a/petastorm/workers_pool/ventilator.py +++ b/petastorm/workers_pool/ventilator.py @@ -143,7 +143,7 @@ def _ventilate(self): # Deterministic randomization: use provided seed self._items_to_ventilate = list(self._rng.permutation(self._items_to_ventilate)) else: - # Non-deterministic randomization: use np.random + # Non-deterministic randomization: use np.random self._items_to_ventilate = list(np.random.permutation(self._items_to_ventilate)) while True: @@ -171,4 +171,4 @@ def stop(self): self._stop_requested = True if self._ventilation_thread: self._ventilation_thread.join() - self._ventilation_thread = None \ No newline at end of file + self._ventilation_thread = None From ba726f61cf1efd01a71a38c5a8f11e0cfb95d800 Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Thu, 24 Jul 2025 23:09:13 +0000 Subject: [PATCH 11/25] fix some more lint issues --- petastorm/workers_pool/thread_pool.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/petastorm/workers_pool/thread_pool.py b/petastorm/workers_pool/thread_pool.py index 6d13b6606..4ab6b7772 100644 --- a/petastorm/workers_pool/thread_pool.py +++ b/petastorm/workers_pool/thread_pool.py @@ -127,8 +127,10 @@ class must implement :class:`.WorkerBase` protocol self._ventilator_queues = [queue.Queue() for _ in range(self.workers_count)] # Set up a channel for each worker to send results - self._results_queues = [queue.Queue(max(5, self._results_queue_size // self.workers_count)) - for _ in range(self.workers_count)] + self._results_queues = [ + queue.Queue(max(5, self._results_queue_size // self.workers_count)) + for _ in range(self.workers_count) + ] self._workers = [] for worker_id in range(self.workers_count): @@ -163,7 +165,7 @@ def ventilate(self, *args, **kargs): def current_worker_done(self, worker_id): # Check if the current worker has processed all the items it was assigned and if the results queue is empty - return (self._ventilated_items_processed_by_worker[worker_id] == self._ventilated_items_by_worker[worker_id] + return (self._ventilated_items_processed_by_worker[worker_id] == self._ventilated_items_by_worker[worker_id] and 
self._results_queues[worker_id].empty()) def all_workers_done(self): @@ -198,11 +200,11 @@ def get_results(self): continue try: - # Get the result from the current worker's results queue. + # Get the result from the current worker's results queue. # Use blocking/strict round robin if shuffle_rows is disabled or the seed is set result = self._results_queues[self._get_results_worker_id].get( block=not use_non_blocking_get, timeout=_VERIFY_END_OF_VENTILATION_PERIOD) - # If the result is a VentilatedItemProcessedMessage, we need to increment the count of items + # If the result is a VentilatedItemProcessedMessage, we need to increment the count of items # processed by the current worker if isinstance(result, VentilatedItemProcessedMessage): self._ventilated_items_processed_by_worker[self._get_results_worker_id] += 1 From 16764599da87a10adfb21739c1424e04e5fda628 Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Sun, 27 Jul 2025 22:24:33 +0000 Subject: [PATCH 12/25] Add logs for testing --- petastorm/tests/test_tf_dataset.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/petastorm/tests/test_tf_dataset.py b/petastorm/tests/test_tf_dataset.py index 17e812a92..6d69c2fe1 100644 --- a/petastorm/tests/test_tf_dataset.py +++ b/petastorm/tests/test_tf_dataset.py @@ -138,12 +138,14 @@ def test_with_dataset_repeat_after_cache(synthetic_dataset, reader_factory): with tf.Session() as sess: with pytest.warns(None): # Expect no warnings since cache() is called before repeat() - for _ in range(epochs): + for epoch in range(epochs): actual_res = [] - for _, _ in enumerate(synthetic_dataset.data): + for i, _ in enumerate(synthetic_dataset.data): actual = sess.run(it_op)._asdict() actual_res.append(actual["id"]) + print(f"iteration: {i} {actual['id']}") expected_res = list(range(len(synthetic_dataset.data))) + print(f"Epoch: {epoch} actual {sorted(actual_res)}, expected {expected_res}") # sort dataset output since row_groups are shuffled from reader. 
np.testing.assert_equal(sorted(actual_res), expected_res) From 47123911de4e309658a79d59d1c641ee5db48561 Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Sun, 27 Jul 2025 23:40:41 +0000 Subject: [PATCH 13/25] [Test] Remove -Y flag to force fresh dataset generation --- .github/workflows/unittest.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml index fbc380265..386f0ff76 100644 --- a/.github/workflows/unittest.yml +++ b/.github/workflows/unittest.yml @@ -59,13 +59,13 @@ jobs: $RUN pylint --rcfile=.pylintrc petastorm examples -f parseable -r n $RUN ulimit -c unlimited -S $RUN bash -c "cd /petastorm/docs/autodoc && pwd && make html" - $RUN $PYTEST -m "forked" --forked -Y \ + $RUN $PYTEST -m "forked" --forked \ --ignore=examples/mnist/tests/test_pytorch_mnist.py \ --ignore=petastorm/tests/test_pytorch_utils.py \ --ignore=petastorm/tests/test_pytorch_dataloader.py \ --ignore=petastorm/tests/test_tf_autograph.py \ petastorm examples - $RUN $PYTEST -m "not forked" -Y --cov-append \ + $RUN $PYTEST -m "not forked" --cov-append \ --ignore=examples/mnist/tests/test_pytorch_mnist.py \ --ignore=petastorm/tests/test_pytorch_utils.py \ --ignore=petastorm/tests/test_pytorch_dataloader.py \ @@ -75,7 +75,7 @@ jobs: examples/mnist/tests/test_pytorch_mnist.py \ petastorm/tests/test_pytorch_dataloader.py \ petastorm/tests/test_pytorch_utils.py - $RUN $PYTEST -Y --cov-append petastorm/tests/test_tf_autograph.py + $RUN $PYTEST --cov-append petastorm/tests/test_tf_autograph.py draft_release: needs: unittest From 08e98b6521e0220c234802efbbd5b6fc8924ac12 Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Mon, 28 Jul 2025 00:48:54 +0000 Subject: [PATCH 14/25] Update failing test --- petastorm/workers_pool/tests/test_workers_pool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/petastorm/workers_pool/tests/test_workers_pool.py b/petastorm/workers_pool/tests/test_workers_pool.py index 142c608a8..a7901e690 100644 --- a/petastorm/workers_pool/tests/test_workers_pool.py +++ b/petastorm/workers_pool/tests/test_workers_pool.py @@ -149,7 +149,7 @@ def test_stop_when_result_queue_is_full(self): pool.ventilate() cumulative_wait = 0 - while pool.results_qsize() != QUEUE_SIZE: + while pool.results_qsize() >= QUEUE_SIZE: time.sleep(SLEEP_DELTA) cumulative_wait += SLEEP_DELTA # Make sure we wait no longer than the timeout. 
Otherwise, something is very wrong From 18a0709f6ef0006ed1dbeff6e187915957a326af Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Mon, 28 Jul 2025 01:28:48 +0000 Subject: [PATCH 15/25] Fix test_stop_when_result_queue_is_full expected queue size as per the modified Thread pool --- petastorm/workers_pool/tests/test_workers_pool.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/petastorm/workers_pool/tests/test_workers_pool.py b/petastorm/workers_pool/tests/test_workers_pool.py index a7901e690..a30294763 100644 --- a/petastorm/workers_pool/tests/test_workers_pool.py +++ b/petastorm/workers_pool/tests/test_workers_pool.py @@ -141,15 +141,17 @@ def test_stop_when_result_queue_is_full(self): SLEEP_DELTA = 0.01 TIMEOUT = 20 QUEUE_SIZE = 2 + WORKERS_COUNT = 10 - pool = ThreadPool(10, results_queue_size=QUEUE_SIZE) + pool = ThreadPool(WORKERS_COUNT, results_queue_size=QUEUE_SIZE) pool.start(WorkerIdGeneratingWorker) - for _ in range(100): + for _ in range(1000): pool.ventilate() + expected_queue_size = WORKERS_COUNT * max(5, QUEUE_SIZE // WORKERS_COUNT) cumulative_wait = 0 - while pool.results_qsize() >= QUEUE_SIZE: + while pool.results_qsize() != expected_queue_size: time.sleep(SLEEP_DELTA) cumulative_wait += SLEEP_DELTA # Make sure we wait no longer than the timeout. Otherwise, something is very wrong From dc05685b86fda07620aa152f4bde6236317e0c9a Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Mon, 28 Jul 2025 01:56:28 +0000 Subject: [PATCH 16/25] Empty commit to trigger build From bafe06f810237e4b1fa24915c84852756c152ec2 Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Mon, 28 Jul 2025 02:18:02 +0000 Subject: [PATCH 17/25] Empty commit to trigger build From 410e07b39914c1bcec4eeba5a94faa4d761b48a0 Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Wed, 6 Aug 2025 21:08:12 +0000 Subject: [PATCH 18/25] [Revert] Adding debug logs --- petastorm/arrow_reader_worker.py | 18 ++++++++++++++++++ petastorm/reader.py | 9 +++++++-- petastorm/tests/test_tf_dataset.py | 5 ++++- petastorm/workers_pool/thread_pool.py | 4 ++++ petastorm/workers_pool/ventilator.py | 6 ++++++ 5 files changed, 39 insertions(+), 3 deletions(-) diff --git a/petastorm/arrow_reader_worker.py b/petastorm/arrow_reader_worker.py index 7959d9a09..9db871b67 100644 --- a/petastorm/arrow_reader_worker.py +++ b/petastorm/arrow_reader_worker.py @@ -15,6 +15,10 @@ import hashlib import operator +import logging + +# Initialize logger +logger = logging.getLogger(__name__) import numpy as np import pandas as pd @@ -91,6 +95,9 @@ class ArrowReaderWorker(WorkerBase): def __init__(self, worker_id, publish_func, args): super(ArrowReaderWorker, self).__init__(worker_id, publish_func, args) + # Add debug log in the constructor + logger.debug('Initializing ArrowReaderWorker with worker_id: %s', worker_id) + self._filesystem = args[0] self._dataset_path_or_paths = args[1] self._schema = args[2] @@ -131,12 +138,18 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition): :return: """ + # Add debug log in the process method + logger.debug('Processing piece_index: %s', piece_index) + if not self._dataset: self._dataset = pq.ParquetDataset( self._dataset_path_or_paths, filesystem=self._filesystem, validate_schema=False, filters=self._arrow_filters) + # Add debug log after dataset is initialized + logger.debug('ParquetDataset initialized with path: %s', self._dataset_path_or_paths) + piece = self._split_pieces[piece_index] # Create pyarrow file system @@ -163,11 +176,16 @@ def process(self, 
piece_index, worker_predicate, shuffle_row_drop_partition): path_str = self._dataset_path_or_paths cache_key = '{}:{}:{}'.format(hashlib.md5(path_str.encode('utf-8')).hexdigest(), piece.path, piece_index) + + # Add debug log for cache key + logger.debug('Cache key generated: %s', cache_key) + all_cols = self._local_cache.get(cache_key, lambda: self._load_rows(parquet_file, piece, shuffle_row_drop_partition)) if all_cols: self.publish_func(all_cols) + logger.debug('Published columns for piece_index: %s', piece_index) @staticmethod def _check_shape_and_ravel(x, field): diff --git a/petastorm/reader.py b/petastorm/reader.py index 1ff4231f0..83730b9cf 100644 --- a/petastorm/reader.py +++ b/petastorm/reader.py @@ -38,6 +38,7 @@ from petastorm.workers_pool.thread_pool import ThreadPool from petastorm.workers_pool.ventilator import ConcurrentVentilator +# Initialize logger logger = logging.getLogger(__name__) # Ventilator guarantees that no more than workers + _VENTILATE_EXTRA_ROWGROUPS are processed at a moment by a @@ -400,6 +401,7 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None, These will be applied when loading the parquet file with PyArrow. More information here: https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetDataset.html """ + logger.debug('DEBUG: Initializing Reader with dataset_path: %s, num_epochs: %s', dataset_path, num_epochs) self.num_epochs = num_epochs # 1. Open the parquet storage (dataset) @@ -437,6 +439,8 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None, raise NotImplementedError('Using timestamp_overlap=False is not implemented with' ' shuffle_options.shuffle_row_drop_partitions > 1') + logger.debug('DEBUG: Reader initialized with schema_fields: %s', schema_fields) + cache = cache or NullCache() self._workers_pool = reader_pool or ThreadPool(10, shuffle_rows=shuffle_rows, seed=seed) @@ -653,6 +657,7 @@ def _normalize_shuffle_options(shuffle_row_drop_partitions, dataset): def _create_ventilator(self, row_group_indexes, shuffle_row_groups, shuffle_row_drop_partitions, num_epochs, worker_predicate, max_ventilation_queue_size, seed): + logger.debug('DEBUG: Creating ventilator with row_group_indexes: %s', row_group_indexes) items_to_ventilate = [] for piece_index in row_group_indexes: for shuffle_row_drop_partition in range(shuffle_row_drop_partitions): @@ -670,12 +675,12 @@ def _create_ventilator(self, row_group_indexes, shuffle_row_groups, shuffle_row_ random_seed=seed) def stop(self): - """Stops all worker threads/processes.""" + logger.debug('Stopping Reader') self._workers_pool.stop() self.stopped = True def join(self): - """Joins all worker threads/processes. 
Will block until all worker workers have been fully terminated.""" + logger.debug('Joining Reader') self._workers_pool.join() @property diff --git a/petastorm/tests/test_tf_dataset.py b/petastorm/tests/test_tf_dataset.py index 6d69c2fe1..6ae748097 100644 --- a/petastorm/tests/test_tf_dataset.py +++ b/petastorm/tests/test_tf_dataset.py @@ -128,6 +128,7 @@ def test_with_dataset_repeat(synthetic_dataset, reader_factory): def test_with_dataset_repeat_after_cache(synthetic_dataset, reader_factory): """ Check if ``tf.data.Dataset``'s ``repeat`` works after ``tf.data.Dataset``'s ``cache``.""" epochs = 3 + print(f"Starting test_with_dataset_repeat_after_cache with {epochs} epochs") with reader_factory(synthetic_dataset.url, schema_fields=[TestSchema.id]) as reader: dataset = make_petastorm_dataset(reader) dataset = dataset.cache() @@ -139,6 +140,7 @@ def test_with_dataset_repeat_after_cache(synthetic_dataset, reader_factory): with pytest.warns(None): # Expect no warnings since cache() is called before repeat() for epoch in range(epochs): + print(f"Starting epoch {epoch}") actual_res = [] for i, _ in enumerate(synthetic_dataset.data): actual = sess.run(it_op)._asdict() @@ -148,10 +150,11 @@ def test_with_dataset_repeat_after_cache(synthetic_dataset, reader_factory): print(f"Epoch: {epoch} actual {sorted(actual_res)}, expected {expected_res}") # sort dataset output since row_groups are shuffled from reader. np.testing.assert_equal(sorted(actual_res), expected_res) - + print(f"Completed epoch {epoch}") # Exhausted all epochs. Fetching next value should trigger OutOfRangeError with pytest.raises(tf.errors.OutOfRangeError): sess.run(it_op) + print("Completed test_with_dataset_repeat_after_cache") @pytest.mark.forked diff --git a/petastorm/workers_pool/thread_pool.py b/petastorm/workers_pool/thread_pool.py index 4ab6b7772..0f3b31ffc 100644 --- a/petastorm/workers_pool/thread_pool.py +++ b/petastorm/workers_pool/thread_pool.py @@ -88,6 +88,7 @@ def __init__(self, workers_count, results_queue_size=50, shuffle_rows=False, see :param workers_count: Number of threads :param profile: Whether to run a profiler on the threads """ + logger.debug('Initializing ThreadPool with workers_count: %s', workers_count) self._seed = random.randint(0, 100000) self._shuffle_rows = shuffle_rows self._seed = seed @@ -118,6 +119,7 @@ class must implement :class:`.WorkerBase` protocol :class:`.WorkerBase` :return: ``None`` """ + logger.debug('Starting ThreadPool with worker_class: %s', worker_class) # Verify stop_event and raise exception if it's already set! if self._stop_event.is_set(): raise RuntimeError('ThreadPool({}) cannot be reused! stop_event set? {}' @@ -157,6 +159,7 @@ def make_publish_func(worker_id): def ventilate(self, *args, **kargs): """Sends a work item to a worker process. Will result in ``worker.process(...)`` call with arbitrary arguments. 
""" + logger.debug('Ventilating work item with args: %s, kargs: %s', args, kargs) # Distribute work items in a round-robin manner across each worker ventilator queue current_worker_id = self._ventilated_items % self.workers_count self._ventilated_items += 1 @@ -204,6 +207,7 @@ def get_results(self): # Use blocking/strict round robin if shuffle_rows is disabled or the seed is set result = self._results_queues[self._get_results_worker_id].get( block=not use_non_blocking_get, timeout=_VERIFY_END_OF_VENTILATION_PERIOD) + print('DEBUG: Result from worker %s: %s' % (self._get_results_worker_id, result)) # If the result is a VentilatedItemProcessedMessage, we need to increment the count of items # processed by the current worker if isinstance(result, VentilatedItemProcessedMessage): diff --git a/petastorm/workers_pool/ventilator.py b/petastorm/workers_pool/ventilator.py index 9f6bfb4f5..431a15252 100644 --- a/petastorm/workers_pool/ventilator.py +++ b/petastorm/workers_pool/ventilator.py @@ -18,6 +18,10 @@ from time import sleep import six +import logging + +# Initialize logger +logger = logging.getLogger(__name__) _VENTILATION_INTERVAL = 0.01 @@ -138,6 +142,7 @@ def reset(self): def _ventilate(self): # Randomize the item order before starting the ventilation if randomize_item_order is set + print('DEBUG: Items to ventilate before shuffle:', self._items_to_ventilate) if self._randomize_item_order: if self._random_seed is not None and self._random_seed != 0: # Deterministic randomization: use provided seed @@ -145,6 +150,7 @@ def _ventilate(self): else: # Non-deterministic randomization: use np.random self._items_to_ventilate = list(np.random.permutation(self._items_to_ventilate)) + print('DEBUG: Items to ventilate after shuffle:', self._items_to_ventilate) while True: # Stop condition is when no iterations are remaining or there are no items to ventilate From 43e3555fcc895a93a9265686d0d34d9e2954db17 Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Wed, 6 Aug 2025 21:16:03 +0000 Subject: [PATCH 19/25] [Revert] Adding debug logs --- petastorm/arrow_reader_worker.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/petastorm/arrow_reader_worker.py b/petastorm/arrow_reader_worker.py index 9db871b67..f1c67c01e 100644 --- a/petastorm/arrow_reader_worker.py +++ b/petastorm/arrow_reader_worker.py @@ -96,7 +96,7 @@ def __init__(self, worker_id, publish_func, args): super(ArrowReaderWorker, self).__init__(worker_id, publish_func, args) # Add debug log in the constructor - logger.debug('Initializing ArrowReaderWorker with worker_id: %s', worker_id) + logger.debug('DEBUG: Initializing ArrowReaderWorker with worker_id: %s', worker_id) self._filesystem = args[0] self._dataset_path_or_paths = args[1] @@ -139,7 +139,7 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition): """ # Add debug log in the process method - logger.debug('Processing piece_index: %s', piece_index) + logger.debug('DEBUG: Processing piece_index: %s', piece_index) if not self._dataset: self._dataset = pq.ParquetDataset( @@ -148,7 +148,7 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition): validate_schema=False, filters=self._arrow_filters) # Add debug log after dataset is initialized - logger.debug('ParquetDataset initialized with path: %s', self._dataset_path_or_paths) + logger.debug('DEBUG: ParquetDataset initialized with path: %s', self._dataset_path_or_paths) piece = self._split_pieces[piece_index] @@ -178,14 +178,14 @@ def process(self, piece_index, 
worker_predicate, shuffle_row_drop_partition): piece.path, piece_index) # Add debug log for cache key - logger.debug('Cache key generated: %s', cache_key) + logger.debug('DEBUG: Cache key generated: %s', cache_key) all_cols = self._local_cache.get(cache_key, lambda: self._load_rows(parquet_file, piece, shuffle_row_drop_partition)) if all_cols: self.publish_func(all_cols) - logger.debug('Published columns for piece_index: %s', piece_index) + logger.debug('DEBUG: Published columns for piece_index: %s', piece_index) @staticmethod def _check_shape_and_ravel(x, field): From a234e67ad48169b6cad9f812098028ad42f15ef3 Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Wed, 6 Aug 2025 21:26:03 +0000 Subject: [PATCH 20/25] [Revert] Adding debug logs --- petastorm/reader.py | 6 +++--- petastorm/workers_pool/thread_pool.py | 8 +++++--- petastorm/workers_pool/ventilator.py | 4 ++-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/petastorm/reader.py b/petastorm/reader.py index 83730b9cf..24e6c40e4 100644 --- a/petastorm/reader.py +++ b/petastorm/reader.py @@ -487,7 +487,7 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None, self.ngram, row_groups, cache, transform_spec, self.schema, filters, shuffle_rows, seed), ventilator=self.ventilator) - logger.debug('Workers pool started') + logger.debug('DEBUG: Workers pool started') self.last_row_consumed = False self.stopped = False @@ -675,12 +675,12 @@ def _create_ventilator(self, row_group_indexes, shuffle_row_groups, shuffle_row_ random_seed=seed) def stop(self): - logger.debug('Stopping Reader') + logger.debug('DEBUG: Stopping Reader') self._workers_pool.stop() self.stopped = True def join(self): - logger.debug('Joining Reader') + logger.debug('DEBUG: Joining Reader') self._workers_pool.join() @property diff --git a/petastorm/workers_pool/thread_pool.py b/petastorm/workers_pool/thread_pool.py index 0f3b31ffc..51473c509 100644 --- a/petastorm/workers_pool/thread_pool.py +++ b/petastorm/workers_pool/thread_pool.py @@ -29,6 +29,8 @@ # recheck if no more items are expected to be ventilated _VERIFY_END_OF_VENTILATION_PERIOD = 1 +logger = logging.getLogger(__name__) + class WorkerTerminationRequested(Exception): """This exception will be raised if a thread is being stopped while waiting to write to the results queue.""" @@ -88,7 +90,7 @@ def __init__(self, workers_count, results_queue_size=50, shuffle_rows=False, see :param workers_count: Number of threads :param profile: Whether to run a profiler on the threads """ - logger.debug('Initializing ThreadPool with workers_count: %s', workers_count) + logger.debug('DEBUG: Initializing ThreadPool with workers_count: %s', workers_count) self._seed = random.randint(0, 100000) self._shuffle_rows = shuffle_rows self._seed = seed @@ -119,7 +121,7 @@ class must implement :class:`.WorkerBase` protocol :class:`.WorkerBase` :return: ``None`` """ - logger.debug('Starting ThreadPool with worker_class: %s', worker_class) + logger.debug('DEBUG: Starting ThreadPool with worker_class: %s', worker_class) # Verify stop_event and raise exception if it's already set! if self._stop_event.is_set(): raise RuntimeError('ThreadPool({}) cannot be reused! stop_event set? {}' @@ -159,7 +161,7 @@ def make_publish_func(worker_id): def ventilate(self, *args, **kargs): """Sends a work item to a worker process. Will result in ``worker.process(...)`` call with arbitrary arguments. 
""" - logger.debug('Ventilating work item with args: %s, kargs: %s', args, kargs) + logger.debug('DEBUG: Ventilating work item with args: %s, kargs: %s', args, kargs) # Distribute work items in a round-robin manner across each worker ventilator queue current_worker_id = self._ventilated_items % self.workers_count self._ventilated_items += 1 diff --git a/petastorm/workers_pool/ventilator.py b/petastorm/workers_pool/ventilator.py index 431a15252..e05f93ca9 100644 --- a/petastorm/workers_pool/ventilator.py +++ b/petastorm/workers_pool/ventilator.py @@ -142,7 +142,7 @@ def reset(self): def _ventilate(self): # Randomize the item order before starting the ventilation if randomize_item_order is set - print('DEBUG: Items to ventilate before shuffle:', self._items_to_ventilate) + logger.debug('DEBUG: Items to ventilate before shuffle: %s', self._items_to_ventilate) if self._randomize_item_order: if self._random_seed is not None and self._random_seed != 0: # Deterministic randomization: use provided seed @@ -150,7 +150,7 @@ def _ventilate(self): else: # Non-deterministic randomization: use np.random self._items_to_ventilate = list(np.random.permutation(self._items_to_ventilate)) - print('DEBUG: Items to ventilate after shuffle:', self._items_to_ventilate) + logger.debug('DEBUG: Items to ventilate after shuffle: %s', self._items_to_ventilate) while True: # Stop condition is when no iterations are remaining or there are no items to ventilate From 30de3b53ab3a142ccfbd16b9da79d6ae6fc9a6d8 Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Wed, 6 Aug 2025 21:35:45 +0000 Subject: [PATCH 21/25] [Revert] restrict test runs --- .github/workflows/unittest.yml | 18 +----------------- petastorm/workers_pool/thread_pool.py | 1 + 2 files changed, 2 insertions(+), 17 deletions(-) diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml index 386f0ff76..178f58f0e 100644 --- a/.github/workflows/unittest.yml +++ b/.github/workflows/unittest.yml @@ -59,23 +59,7 @@ jobs: $RUN pylint --rcfile=.pylintrc petastorm examples -f parseable -r n $RUN ulimit -c unlimited -S $RUN bash -c "cd /petastorm/docs/autodoc && pwd && make html" - $RUN $PYTEST -m "forked" --forked \ - --ignore=examples/mnist/tests/test_pytorch_mnist.py \ - --ignore=petastorm/tests/test_pytorch_utils.py \ - --ignore=petastorm/tests/test_pytorch_dataloader.py \ - --ignore=petastorm/tests/test_tf_autograph.py \ - petastorm examples - $RUN $PYTEST -m "not forked" --cov-append \ - --ignore=examples/mnist/tests/test_pytorch_mnist.py \ - --ignore=petastorm/tests/test_pytorch_utils.py \ - --ignore=petastorm/tests/test_pytorch_dataloader.py \ - --ignore=petastorm/tests/test_tf_autograph.py \ - petastorm examples - $RUN $PYTEST --cov-append \ - examples/mnist/tests/test_pytorch_mnist.py \ - petastorm/tests/test_pytorch_dataloader.py \ - petastorm/tests/test_pytorch_utils.py - $RUN $PYTEST --cov-append petastorm/tests/test_tf_autograph.py + $RUN $PYTEST --cov-append petastorm/tests/test_tf_dataset.py draft_release: needs: unittest diff --git a/petastorm/workers_pool/thread_pool.py b/petastorm/workers_pool/thread_pool.py index 51473c509..6925055fb 100644 --- a/petastorm/workers_pool/thread_pool.py +++ b/petastorm/workers_pool/thread_pool.py @@ -22,6 +22,7 @@ from six.moves import queue from petastorm.workers_pool import EmptyResultError, VentilatedItemProcessedMessage +import logging # Defines how frequently will we check the stop event while waiting on a blocking queue IO_TIMEOUT_INTERVAL_S = 0.001 From 8ce4b042b65831d030ceb56b51363d02c57e9174 
Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Wed, 6 Aug 2025 21:44:28 +0000 Subject: [PATCH 22/25] [Revert] fix logger import --- petastorm/arrow_reader_worker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/petastorm/arrow_reader_worker.py b/petastorm/arrow_reader_worker.py index f1c67c01e..9d4e7cdd5 100644 --- a/petastorm/arrow_reader_worker.py +++ b/petastorm/arrow_reader_worker.py @@ -17,9 +17,6 @@ import operator import logging -# Initialize logger -logger = logging.getLogger(__name__) - import numpy as np import pandas as pd import pyarrow as pa @@ -30,6 +27,9 @@ from petastorm.workers_pool import EmptyResultError from petastorm.workers_pool.worker_base import WorkerBase +# Initialize logger +logger = logging.getLogger(__name__) + class ArrowReaderWorkerResultsQueueReader(object): def __init__(self): From b214ca10e05a10ae8f7035b1c4556acf7fee1682 Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Wed, 6 Aug 2025 22:18:55 +0000 Subject: [PATCH 23/25] Revert back to enable all tests --- .github/workflows/unittest.yml | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml index 178f58f0e..fbc380265 100644 --- a/.github/workflows/unittest.yml +++ b/.github/workflows/unittest.yml @@ -59,7 +59,23 @@ jobs: $RUN pylint --rcfile=.pylintrc petastorm examples -f parseable -r n $RUN ulimit -c unlimited -S $RUN bash -c "cd /petastorm/docs/autodoc && pwd && make html" - $RUN $PYTEST --cov-append petastorm/tests/test_tf_dataset.py + $RUN $PYTEST -m "forked" --forked -Y \ + --ignore=examples/mnist/tests/test_pytorch_mnist.py \ + --ignore=petastorm/tests/test_pytorch_utils.py \ + --ignore=petastorm/tests/test_pytorch_dataloader.py \ + --ignore=petastorm/tests/test_tf_autograph.py \ + petastorm examples + $RUN $PYTEST -m "not forked" -Y --cov-append \ + --ignore=examples/mnist/tests/test_pytorch_mnist.py \ + --ignore=petastorm/tests/test_pytorch_utils.py \ + --ignore=petastorm/tests/test_pytorch_dataloader.py \ + --ignore=petastorm/tests/test_tf_autograph.py \ + petastorm examples + $RUN $PYTEST --cov-append \ + examples/mnist/tests/test_pytorch_mnist.py \ + petastorm/tests/test_pytorch_dataloader.py \ + petastorm/tests/test_pytorch_utils.py + $RUN $PYTEST -Y --cov-append petastorm/tests/test_tf_autograph.py draft_release: needs: unittest From acbddb38661379ee554a760296e69482cc6459b5 Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Wed, 6 Aug 2025 22:42:51 +0000 Subject: [PATCH 24/25] [Revert] change logs to print --- petastorm/arrow_reader_worker.py | 10 +++++----- petastorm/reader.py | 12 ++++++------ petastorm/workers_pool/thread_pool.py | 6 +++--- petastorm/workers_pool/ventilator.py | 4 ++-- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/petastorm/arrow_reader_worker.py b/petastorm/arrow_reader_worker.py index 9d4e7cdd5..7a86773f1 100644 --- a/petastorm/arrow_reader_worker.py +++ b/petastorm/arrow_reader_worker.py @@ -96,7 +96,7 @@ def __init__(self, worker_id, publish_func, args): super(ArrowReaderWorker, self).__init__(worker_id, publish_func, args) # Add debug log in the constructor - logger.debug('DEBUG: Initializing ArrowReaderWorker with worker_id: %s', worker_id) + print('DEBUG: Initializing ArrowReaderWorker with worker_id: %s', worker_id) self._filesystem = args[0] self._dataset_path_or_paths = args[1] @@ -139,7 +139,7 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition): """ # Add debug log in the process method - 
logger.debug('DEBUG: Processing piece_index: %s', piece_index) + print('DEBUG: Processing piece_index: %s', piece_index) if not self._dataset: self._dataset = pq.ParquetDataset( @@ -148,7 +148,7 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition): validate_schema=False, filters=self._arrow_filters) # Add debug log after dataset is initialized - logger.debug('DEBUG: ParquetDataset initialized with path: %s', self._dataset_path_or_paths) + print('DEBUG: ParquetDataset initialized with path: %s', self._dataset_path_or_paths) piece = self._split_pieces[piece_index] @@ -178,14 +178,14 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition): piece.path, piece_index) # Add debug log for cache key - logger.debug('DEBUG: Cache key generated: %s', cache_key) + print('DEBUG: Cache key generated: %s', cache_key) all_cols = self._local_cache.get(cache_key, lambda: self._load_rows(parquet_file, piece, shuffle_row_drop_partition)) if all_cols: self.publish_func(all_cols) - logger.debug('DEBUG: Published columns for piece_index: %s', piece_index) + print('DEBUG: Published columns for piece_index: %s', piece_index) @staticmethod def _check_shape_and_ravel(x, field): diff --git a/petastorm/reader.py b/petastorm/reader.py index 24e6c40e4..c4a2de84b 100644 --- a/petastorm/reader.py +++ b/petastorm/reader.py @@ -401,7 +401,7 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None, These will be applied when loading the parquet file with PyArrow. More information here: https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetDataset.html """ - logger.debug('DEBUG: Initializing Reader with dataset_path: %s, num_epochs: %s', dataset_path, num_epochs) + print('DEBUG: Initializing Reader with dataset_path: %s, num_epochs: %s', dataset_path, num_epochs) self.num_epochs = num_epochs # 1. 
Open the parquet storage (dataset) @@ -439,7 +439,7 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None, raise NotImplementedError('Using timestamp_overlap=False is not implemented with' ' shuffle_options.shuffle_row_drop_partitions > 1') - logger.debug('DEBUG: Reader initialized with schema_fields: %s', schema_fields) + print('DEBUG: Reader initialized with schema_fields: %s', schema_fields) cache = cache or NullCache() @@ -487,7 +487,7 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None, self.ngram, row_groups, cache, transform_spec, self.schema, filters, shuffle_rows, seed), ventilator=self.ventilator) - logger.debug('DEBUG: Workers pool started') + print('DEBUG: Workers pool started') self.last_row_consumed = False self.stopped = False @@ -657,7 +657,7 @@ def _normalize_shuffle_options(shuffle_row_drop_partitions, dataset): def _create_ventilator(self, row_group_indexes, shuffle_row_groups, shuffle_row_drop_partitions, num_epochs, worker_predicate, max_ventilation_queue_size, seed): - logger.debug('DEBUG: Creating ventilator with row_group_indexes: %s', row_group_indexes) + print('DEBUG: Creating ventilator with row_group_indexes: %s', row_group_indexes) items_to_ventilate = [] for piece_index in row_group_indexes: for shuffle_row_drop_partition in range(shuffle_row_drop_partitions): @@ -675,12 +675,12 @@ def _create_ventilator(self, row_group_indexes, shuffle_row_groups, shuffle_row_ random_seed=seed) def stop(self): - logger.debug('DEBUG: Stopping Reader') + print('DEBUG: Stopping Reader') self._workers_pool.stop() self.stopped = True def join(self): - logger.debug('DEBUG: Joining Reader') + print('DEBUG: Joining Reader') self._workers_pool.join() @property diff --git a/petastorm/workers_pool/thread_pool.py b/petastorm/workers_pool/thread_pool.py index 6925055fb..01e941f74 100644 --- a/petastorm/workers_pool/thread_pool.py +++ b/petastorm/workers_pool/thread_pool.py @@ -91,7 +91,7 @@ def __init__(self, workers_count, results_queue_size=50, shuffle_rows=False, see :param workers_count: Number of threads :param profile: Whether to run a profiler on the threads """ - logger.debug('DEBUG: Initializing ThreadPool with workers_count: %s', workers_count) + print('DEBUG: Initializing ThreadPool with workers_count: %s', workers_count) self._seed = random.randint(0, 100000) self._shuffle_rows = shuffle_rows self._seed = seed @@ -122,7 +122,7 @@ class must implement :class:`.WorkerBase` protocol :class:`.WorkerBase` :return: ``None`` """ - logger.debug('DEBUG: Starting ThreadPool with worker_class: %s', worker_class) + print('DEBUG: Starting ThreadPool with worker_class: %s', worker_class) # Verify stop_event and raise exception if it's already set! if self._stop_event.is_set(): raise RuntimeError('ThreadPool({}) cannot be reused! stop_event set? {}' @@ -162,7 +162,7 @@ def make_publish_func(worker_id): def ventilate(self, *args, **kargs): """Sends a work item to a worker process. Will result in ``worker.process(...)`` call with arbitrary arguments. 
""" - logger.debug('DEBUG: Ventilating work item with args: %s, kargs: %s', args, kargs) + print('DEBUG: Ventilating work item with args: %s, kargs: %s', args, kargs) # Distribute work items in a round-robin manner across each worker ventilator queue current_worker_id = self._ventilated_items % self.workers_count self._ventilated_items += 1 diff --git a/petastorm/workers_pool/ventilator.py b/petastorm/workers_pool/ventilator.py index e05f93ca9..0a2a9aa03 100644 --- a/petastorm/workers_pool/ventilator.py +++ b/petastorm/workers_pool/ventilator.py @@ -142,7 +142,7 @@ def reset(self): def _ventilate(self): # Randomize the item order before starting the ventilation if randomize_item_order is set - logger.debug('DEBUG: Items to ventilate before shuffle: %s', self._items_to_ventilate) + print('DEBUG: Items to ventilate before shuffle: %s', self._items_to_ventilate) if self._randomize_item_order: if self._random_seed is not None and self._random_seed != 0: # Deterministic randomization: use provided seed @@ -150,7 +150,7 @@ def _ventilate(self): else: # Non-deterministic randomization: use np.random self._items_to_ventilate = list(np.random.permutation(self._items_to_ventilate)) - logger.debug('DEBUG: Items to ventilate after shuffle: %s', self._items_to_ventilate) + print('DEBUG: Items to ventilate after shuffle: %s', self._items_to_ventilate) while True: # Stop condition is when no iterations are remaining or there are no items to ventilate From 516aa098628a4fd8f3ab70bf810586f3f5cd09b0 Mon Sep 17 00:00:00 2001 From: "arushi.arora" Date: Wed, 6 Aug 2025 23:01:16 +0000 Subject: [PATCH 25/25] [Revert] Modify debug logs --- petastorm/arrow_reader_worker.py | 10 +++++----- petastorm/reader.py | 6 +++--- petastorm/workers_pool/thread_pool.py | 8 ++++---- petastorm/workers_pool/ventilator.py | 4 ++-- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/petastorm/arrow_reader_worker.py b/petastorm/arrow_reader_worker.py index 7a86773f1..b7a3f090c 100644 --- a/petastorm/arrow_reader_worker.py +++ b/petastorm/arrow_reader_worker.py @@ -96,7 +96,7 @@ def __init__(self, worker_id, publish_func, args): super(ArrowReaderWorker, self).__init__(worker_id, publish_func, args) # Add debug log in the constructor - print('DEBUG: Initializing ArrowReaderWorker with worker_id: %s', worker_id) + print(f'DEBUG: Initializing ArrowReaderWorker with worker_id: {worker_id}') self._filesystem = args[0] self._dataset_path_or_paths = args[1] @@ -139,7 +139,7 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition): """ # Add debug log in the process method - print('DEBUG: Processing piece_index: %s', piece_index) + print(f'DEBUG: Processing piece_index: {piece_index}') if not self._dataset: self._dataset = pq.ParquetDataset( @@ -148,7 +148,7 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition): validate_schema=False, filters=self._arrow_filters) # Add debug log after dataset is initialized - print('DEBUG: ParquetDataset initialized with path: %s', self._dataset_path_or_paths) + print(f'DEBUG: ParquetDataset initialized with path: {self._dataset_path_or_paths}') piece = self._split_pieces[piece_index] @@ -178,14 +178,14 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition): piece.path, piece_index) # Add debug log for cache key - print('DEBUG: Cache key generated: %s', cache_key) + print(f'DEBUG: Cache key generated: {cache_key}') all_cols = self._local_cache.get(cache_key, lambda: self._load_rows(parquet_file, piece, 
shuffle_row_drop_partition)) if all_cols: self.publish_func(all_cols) - print('DEBUG: Published columns for piece_index: %s', piece_index) + print(f'DEBUG: Published columns for piece_index: {piece_index}') @staticmethod def _check_shape_and_ravel(x, field): diff --git a/petastorm/reader.py b/petastorm/reader.py index c4a2de84b..3b10625ba 100644 --- a/petastorm/reader.py +++ b/petastorm/reader.py @@ -401,7 +401,7 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None, These will be applied when loading the parquet file with PyArrow. More information here: https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetDataset.html """ - print('DEBUG: Initializing Reader with dataset_path: %s, num_epochs: %s', dataset_path, num_epochs) + print(f'DEBUG: Initializing Reader with dataset_path: {dataset_path}, num_epochs: {num_epochs}') self.num_epochs = num_epochs # 1. Open the parquet storage (dataset) @@ -439,7 +439,7 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None, raise NotImplementedError('Using timestamp_overlap=False is not implemented with' ' shuffle_options.shuffle_row_drop_partitions > 1') - print('DEBUG: Reader initialized with schema_fields: %s', schema_fields) + print(f'DEBUG: Reader initialized with schema_fields: {schema_fields}') cache = cache or NullCache() @@ -657,7 +657,7 @@ def _normalize_shuffle_options(shuffle_row_drop_partitions, dataset): def _create_ventilator(self, row_group_indexes, shuffle_row_groups, shuffle_row_drop_partitions, num_epochs, worker_predicate, max_ventilation_queue_size, seed): - print('DEBUG: Creating ventilator with row_group_indexes: %s', row_group_indexes) + print(f'DEBUG: Creating ventilator with row_group_indexes: {row_group_indexes}') items_to_ventilate = [] for piece_index in row_group_indexes: for shuffle_row_drop_partition in range(shuffle_row_drop_partitions): diff --git a/petastorm/workers_pool/thread_pool.py b/petastorm/workers_pool/thread_pool.py index 01e941f74..7ee1a8b91 100644 --- a/petastorm/workers_pool/thread_pool.py +++ b/petastorm/workers_pool/thread_pool.py @@ -91,7 +91,7 @@ def __init__(self, workers_count, results_queue_size=50, shuffle_rows=False, see :param workers_count: Number of threads :param profile: Whether to run a profiler on the threads """ - print('DEBUG: Initializing ThreadPool with workers_count: %s', workers_count) + print(f'DEBUG: Initializing ThreadPool with workers_count: {workers_count}') self._seed = random.randint(0, 100000) self._shuffle_rows = shuffle_rows self._seed = seed @@ -122,7 +122,7 @@ class must implement :class:`.WorkerBase` protocol :class:`.WorkerBase` :return: ``None`` """ - print('DEBUG: Starting ThreadPool with worker_class: %s', worker_class) + print(f'DEBUG: Starting ThreadPool with worker_class: {worker_class}') # Verify stop_event and raise exception if it's already set! if self._stop_event.is_set(): raise RuntimeError('ThreadPool({}) cannot be reused! stop_event set? {}' @@ -162,7 +162,7 @@ def make_publish_func(worker_id): def ventilate(self, *args, **kargs): """Sends a work item to a worker process. Will result in ``worker.process(...)`` call with arbitrary arguments. 
""" - print('DEBUG: Ventilating work item with args: %s, kargs: %s', args, kargs) + print(f'DEBUG: Ventilating work item with args: {args}, kargs: {kargs}') # Distribute work items in a round-robin manner across each worker ventilator queue current_worker_id = self._ventilated_items % self.workers_count self._ventilated_items += 1 @@ -210,7 +210,7 @@ def get_results(self): # Use blocking/strict round robin if shuffle_rows is disabled or the seed is set result = self._results_queues[self._get_results_worker_id].get( block=not use_non_blocking_get, timeout=_VERIFY_END_OF_VENTILATION_PERIOD) - print('DEBUG: Result from worker %s: %s' % (self._get_results_worker_id, result)) + print(f'DEBUG: Result from worker {self._get_results_worker_id}: {result}') # If the result is a VentilatedItemProcessedMessage, we need to increment the count of items # processed by the current worker if isinstance(result, VentilatedItemProcessedMessage): diff --git a/petastorm/workers_pool/ventilator.py b/petastorm/workers_pool/ventilator.py index 0a2a9aa03..d542b5868 100644 --- a/petastorm/workers_pool/ventilator.py +++ b/petastorm/workers_pool/ventilator.py @@ -142,7 +142,7 @@ def reset(self): def _ventilate(self): # Randomize the item order before starting the ventilation if randomize_item_order is set - print('DEBUG: Items to ventilate before shuffle: %s', self._items_to_ventilate) + print(f'DEBUG: Items to ventilate before shuffle: {self._items_to_ventilate}') if self._randomize_item_order: if self._random_seed is not None and self._random_seed != 0: # Deterministic randomization: use provided seed @@ -150,7 +150,7 @@ def _ventilate(self): else: # Non-deterministic randomization: use np.random self._items_to_ventilate = list(np.random.permutation(self._items_to_ventilate)) - print('DEBUG: Items to ventilate after shuffle: %s', self._items_to_ventilate) + print(f'DEBUG: Items to ventilate after shuffle: {self._items_to_ventilate}') while True: # Stop condition is when no iterations are remaining or there are no items to ventilate