diff --git a/dask_sql/physical/rel/logical/limit.py b/dask_sql/physical/rel/logical/limit.py
index 58cd68fe8..fb7c3a6dd 100644
--- a/dask_sql/physical/rel/logical/limit.py
+++ b/dask_sql/physical/rel/logical/limit.py
@@ -5,7 +5,6 @@
 from dask_sql.datacontainer import DataContainer
 from dask_sql.physical.rel.base import BaseRelPlugin
 from dask_sql.physical.rex import RexConverter
-from dask_sql.physical.utils.map import map_on_partition_index
 
 if TYPE_CHECKING:
     import dask_sql
@@ -38,25 +37,18 @@ def convert(
         if offset:
             end += offset
 
-        df = self._apply_offset(df, offset, end)
+        df = self._apply_limit(df, offset, end)
 
         cc = self.fix_column_to_row_type(cc, rel.getRowType())
         # No column type has changed, so no need to cast again
         return DataContainer(df, cc)
 
-    def _apply_offset(self, df: dd.DataFrame, offset: int, end: int) -> dd.DataFrame:
+    def _apply_limit(self, df: dd.DataFrame, offset: int, end: int) -> dd.DataFrame:
         """
         Limit the dataframe to the window [offset, end].
-        That is unfortunately, not so simple as we do not know how many
-        items we have in each partition. We have therefore no other way than to
-        calculate (!!!) the sizes of each partition.
-
-        After that, we can create a new dataframe from the old
-        dataframe by calculating for each partition if and how much
-        it should be used.
-        We do this via generating our own dask computation graph as
-        we need to pass the partition number to the selection
-        function, which is not possible with normal "map_partitions".
+
+        Unfortunately, Dask does not currently support row selection through `iloc`, so this must be done using a custom partition function.
+        However, it is sometimes possible to compute this window using `head` when an `offset` is not specified.
         """
         if not offset:
             # We do a (hopefully) very quick check: if the first partition
@@ -65,23 +57,19 @@ def _apply_offset(self, df: dd.DataFrame, offset: int, end: int) -> dd.DataFrame
             if first_partition_length >= end:
                 return df.head(end, compute=False)
 
-        # First, we need to find out which partitions we want to use.
-        # Therefore we count the total number of entries
+        # compute the size of each partition
+        # TODO: compute `cumsum` here when dask#9067 is resolved
         partition_borders = df.map_partitions(lambda x: len(x))
 
-        # Now we let each of the partitions figure out, how much it needs to return
-        # using these partition borders
-        # For this, we generate out own dask computation graph (as it does not really
-        # fit well with one of the already present methods).
-
-        # (a) we define a method to be calculated on each partition
-        # This method returns the part of the partition, which falls between [offset, fetch]
-        # Please note that the dask object "partition_borders", will be turned into
-        # its pandas representation at this point and we can calculate the cumsum
-        # (which is not possible on the dask object). Recalculating it should not cost
-        # us much, as we assume the number of partitions is rather small.
-        def select_from_to(df, partition_index, partition_borders):
+        def limit_partition_func(df, partition_borders, partition_info=None):
+            """Limit the partition to values contained within the specified window, returning an empty dataframe if there are none"""
+
+            # TODO: remove the `cumsum` call here when dask#9067 is resolved
             partition_borders = partition_borders.cumsum().to_dict()
+            partition_index = (
+                partition_info["number"] if partition_info is not None else 0
+            )
+
             this_partition_border_left = (
                 partition_borders[partition_index - 1] if partition_index > 0 else 0
             )
@@ -101,8 +89,7 @@ def select_from_to(df, partition_index, partition_borders):
 
             return df.iloc[from_index:to_index]
 
-        # (b) Now we just need to apply the function on every partition
-        # We do this via the delayed interface, which seems the easiest one.
-        return map_on_partition_index(
-            df, select_from_to, partition_borders, meta=df._meta
+        return df.map_partitions(
+            limit_partition_func,
+            partition_borders=partition_borders,
         )
diff --git a/dask_sql/physical/utils/map.py b/dask_sql/physical/utils/map.py
deleted file mode 100644
index 791342ccc..000000000
--- a/dask_sql/physical/utils/map.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from typing import Any, Callable
-
-import dask
-import dask.dataframe as dd
-
-
-def map_on_partition_index(
-    df: dd.DataFrame, f: Callable, *args: Any, **kwargs: Any
-) -> dd.DataFrame:
-    meta = kwargs.pop("meta", None)
-    return dd.from_delayed(
-        [
-            dask.delayed(f)(partition, partition_number, *args, **kwargs)
-            for partition_number, partition in enumerate(df.partitions)
-        ],
-        meta=meta,
-    )
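
For anyone who wants to try the new approach outside of dask-sql, here is a minimal, self-contained sketch of the same technique. Everything in it (`apply_limit`, `limit_partition`, the toy dataframe) is hypothetical and only mirrors the pattern in the diff above; it assumes, as the new code does, that `map_partitions` delivers a dask Series passed as a keyword argument to each task in full, and it passes `meta=df._meta` explicitly so that metadata inference never runs the partition function on empty dummy inputs:

```python
# Hypothetical standalone sketch of the technique in this PR; `apply_limit`
# and `limit_partition` are illustrative names, not dask-sql API.
import dask.dataframe as dd
import pandas as pd


def apply_limit(df: dd.DataFrame, offset: int, end: int) -> dd.DataFrame:
    """Select the global row window [offset, end) from a dask DataFrame."""
    if not offset:
        # Fast path: if the first partition already holds `end` rows,
        # `head` answers the query without touching other partitions.
        if len(df.partitions[0]) >= end:  # computes only this partition
            return df.head(end, compute=False)

    # One length per partition; its cumulative sum gives the global row
    # index at which each partition ends. The `cumsum` happens inside the
    # partition function (see the dask#9067 TODOs in the diff above).
    partition_borders = df.map_partitions(lambda x: len(x))

    def limit_partition(part, partition_borders, partition_info=None):
        # As a keyword argument, `partition_borders` arrives in each task
        # as the full computed pandas Series, not partition-aligned.
        partition_borders = partition_borders.cumsum().to_dict()
        partition_index = (
            partition_info["number"] if partition_info is not None else 0
        )
        left = partition_borders[partition_index - 1] if partition_index > 0 else 0
        right = partition_borders[partition_index]
        if offset >= right or end <= left:
            return part.iloc[0:0]  # window misses this partition entirely
        return part.iloc[max(offset - left, 0) : min(end, right) - left]

    # Explicit `meta` sidesteps metadata inference, which would otherwise
    # call `limit_partition` on empty dummy inputs.
    return df.map_partitions(
        limit_partition, partition_borders=partition_borders, meta=df._meta
    )


ddf = dd.from_pandas(pd.DataFrame({"x": range(100)}), npartitions=4)
print(apply_limit(ddf, offset=30, end=45).compute())  # global rows 30..44
```

The `head` fast path matters in practice: for the common `LIMIT n` query without an `OFFSET`, only the first partition's length is computed, and the partition-borders graph is never built at all.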