
[BUG] Training cuML single GPU models on dask dataframe objects uses client instead of worker #4406

@VibhuJawa

Describe the bug

With a recent PR we enabled training single-GPU cuML models on Dask DataFrames and Series, but we call compute there, which brings the data to the client.

This causes the following issues:

  1. On clusters where we use an RMM pool on the workers, not enough GPU memory is left for the data gathered to the client, causing OOM errors.
  2. It adds the overhead of transferring the data to the client.
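
A minimal sketch of the difference is below (this is not the actual cuML code path, just an illustration; it assumes dask_cudf and a single-GPU LocalCUDACluster): compute() materializes the full frame in the client process, whereas submitting the concatenation as a task keeps the result in worker memory.

import cudf
import dask_cudf
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

client = Client(LocalCUDACluster(CUDA_VISIBLE_DEVICES='0'))
ddf = dask_cudf.from_cudf(cudf.DataFrame({'x': list(range(1000))}), npartitions=4)

# What input_to_cuml_array effectively does today: gather every partition and
# materialize the full frame in the client process.
on_client = ddf.compute()

# A worker-side alternative: turn each partition into a future and run the
# concatenation as a task, so the result stays in a worker's memory (and its
# RMM pool) and is never copied to the client.
part_futures = client.compute(ddf.to_delayed())
on_worker = client.submit(cudf.concat, part_futures)  # Future pointing at worker-held data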

Steps/Code to reproduce bug

Example with dask-cudf

from cuml.linear_model import LinearRegression
from dask.datasets import timeseries
import cudf

from dask_cuda import LocalCUDACluster
from dask.distributed import Client

cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES='0', rmm_pool_size='14.5 GB')
client = Client(cluster)


df = timeseries(start='2000-01-01', end='2002-10-28', freq="1s")[['x', 'y']]
df = df.map_partitions(cudf.from_pandas).reset_index(drop=True)

model = LinearRegression()
model = model.fit(df, df['x'] * df['y'])

Trace:

distributed.preloading - INFO - Import preload module: dask_cuda.initialize
distributed.utils - ERROR - std::bad_alloc: CUDA error at: /datasets/vjawa/miniconda3/envs/rapids-21.12/include/rmm/mr/device/cuda_memory_resource.hpp
Traceback (most recent call last):
  File "/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/utils.py", line 653, in log_errors
    yield
  File "/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/cudf/comm/serialize.py", line 27, in dask_deserialize_cudf_object
    return Serializable.host_deserialize(header, frames)
  File "/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/cudf/core/abc.py", line 186, in host_deserialize
    frames = [
  File "/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/cudf/core/abc.py", line 187, in <listcomp>
    rmm.DeviceBuffer.to_device(f) if c else f
  File "rmm/_lib/device_buffer.pyx", line 146, in rmm._lib.device_buffer.DeviceBuffer.to_device
  File "rmm/_lib/device_buffer.pyx", line 335, in rmm._lib.device_buffer.to_device
  File "rmm/_lib/device_buffer.pyx", line 86, in rmm._lib.device_buffer.DeviceBuffer.__cinit__
MemoryError: std::bad_alloc: CUDA error at: /datasets/vjawa/miniconda3/envs/rapids-21.12/include/rmm/mr/device/cuda_memory_resource.hpp
distributed.protocol.core - CRITICAL - Failed to deserialize
Traceback (most recent call last):
  File "/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/protocol/core.py", line 111, in loads
    return msgpack.loads(
  File "msgpack/_unpacker.pyx", line 195, in msgpack._cmsgpack.unpackb
  File "/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/protocol/core.py", line 103, in _decode_default
    return merge_and_deserialize(
  File "/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/protocol/serialize.py", line 488, in merge_and_deserialize
    return deserialize(header, merged_frames, deserializers=deserializers)
  File "/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/protocol/serialize.py", line 417, in deserialize
    return loads(header, frames)
  File "/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/protocol/serialize.py", line 57, in dask_loads
    return loads(header["sub-header"], frames)
  File "/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/cudf/comm/serialize.py", line 27, in dask_deserialize_cudf_object
    return Serializable.host_deserialize(header, frames)
  File "/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/cudf/core/abc.py", line 186, in host_deserialize
    frames = [
  File "/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/cudf/core/abc.py", line 187, in <listcomp>
    rmm.DeviceBuffer.to_device(f) if c else f
  File "rmm/_lib/device_buffer.pyx", line 146, in rmm._lib.device_buffer.DeviceBuffer.to_device
  File "rmm/_lib/device_buffer.pyx", line 335, in rmm._lib.device_buffer.to_device
  File "rmm/_lib/device_buffer.pyx", line 86, in rmm._lib.device_buffer.DeviceBuffer.__cinit__
MemoryError: std::bad_alloc: CUDA error at: /datasets/vjawa/miniconda3/envs/rapids-21.12/include/rmm/mr/device/cuda_memory_resource.hpp
---------------------------------------------------------------------------
MemoryError                               Traceback (most recent call last)
/tmp/ipykernel_1570/91331363.py in <module>
     14 
     15 model = LinearRegression()
---> 16 model = model.fit(df,df['x']*df['y'])

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/cuml/internals/api_decorators.py in inner_with_setters(*args, **kwargs)
    407                                 target_val=target_val)
    408 
--> 409                 return func(*args, **kwargs)
    410 
    411         @wraps(func)

cuml/linear_model/linear_regression.pyx in cuml.linear_model.linear_regression.LinearRegression.fit()

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/contextlib.py in inner(*args, **kwds)
     73         def inner(*args, **kwds):
     74             with self._recreate_cm():
---> 75                 return func(*args, **kwds)
     76         return inner
     77 

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/cuml/internals/api_decorators.py in inner(*args, **kwargs)
    358         def inner(*args, **kwargs):
    359             with self._recreate_cm(func, args):
--> 360                 return func(*args, **kwargs)
    361 
    362         return inner

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/cuml/common/input_utils.py in input_to_cuml_array(X, order, deepcopy, check_dtype, convert_to_dtype, safe_dtype_conversion, check_cols, check_rows, fail_on_order, force_contiguous)
    318     if isinstance(X, (dask_cudf.core.Series, dask_cudf.core.DataFrame)):
    319         # TODO: Warn, but not when using dask_sql
--> 320         X = X.compute()
    321 
    322     if (isinstance(X, cudf.Series)):

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/dask/base.py in compute(self, **kwargs)
    286         dask.base.compute
    287         """
--> 288         (result,) = compute(self, traverse=False, **kwargs)
    289         return result
    290 

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/dask/base.py in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
    569         postcomputes.append(x.__dask_postcompute__())
    570 
--> 571     results = schedule(dsk, keys, **kwargs)
    572     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    573 

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/client.py in get(self, dsk, keys, workers, allow_other_workers, resources, sync, asynchronous, direct, retries, priority, fifo_timeout, actors, **kwargs)
   2723                     should_rejoin = False
   2724             try:
-> 2725                 results = self.gather(packed, asynchronous=asynchronous, direct=direct)
   2726             finally:
   2727                 for f in futures.values():

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/client.py in gather(self, futures, errors, direct, asynchronous)
   1978             else:
   1979                 local_worker = None
-> 1980             return self.sync(
   1981                 self._gather,
   1982                 futures,

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
    866             return future
    867         else:
--> 868             return sync(
    869                 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
    870             )

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
    330     if error[0]:
    331         typ, exc, tb = error[0]
--> 332         raise exc.with_traceback(tb)
    333     else:
    334         return result[0]

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/utils.py in f()
    313             if callback_timeout is not None:
    314                 future = asyncio.wait_for(future, callback_timeout)
--> 315             result[0] = yield future
    316         except Exception:
    317             error[0] = sys.exc_info()

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/tornado/gen.py in run(self)
    760 
    761                     try:
--> 762                         value = future.result()
    763                     except Exception:
    764                         exc_info = sys.exc_info()

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/client.py in _gather(self, futures, errors, direct, local_worker)
   1872                 else:
   1873                     self._gather_future = future
-> 1874                 response = await future
   1875 
   1876             if response["status"] == "error":

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/client.py in _gather_remote(self, direct, local_worker)
   1923 
   1924             else:  # ask scheduler to gather data for us
-> 1925                 response = await retry_operation(self.scheduler.gather, keys=keys)
   1926 
   1927         return response

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/utils_comm.py in retry_operation(coro, operation, *args, **kwargs)
    383         dask.config.get("distributed.comm.retry.delay.max"), default="s"
    384     )
--> 385     return await retry(
    386         partial(coro, *args, **kwargs),
    387         count=retry_count,

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/utils_comm.py in retry(coro, count, delay_min, delay_max, jitter_fraction, retry_on_exceptions, operation)
    368                 delay *= 1 + random.random() * jitter_fraction
    369             await asyncio.sleep(delay)
--> 370     return await coro()
    371 
    372 

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/core.py in send_recv_from_rpc(**kwargs)
    893             name, comm.name = comm.name, "ConnectionPool." + key
    894             try:
--> 895                 result = await send_recv(comm=comm, op=key, **kwargs)
    896             finally:
    897                 self.pool.reuse(self.addr, comm)

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/core.py in send_recv(comm, reply, serializers, deserializers, **kwargs)
    670         await comm.write(msg, serializers=serializers, on_error="raise")
    671         if reply:
--> 672             response = await comm.read(deserializers=deserializers)
    673         else:
    674             response = None

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/comm/tcp.py in read(self, deserializers)
    231                 frames = unpack_frames(frames)
    232 
--> 233                 msg = await from_frames(
    234                     frames,
    235                     deserialize=self.deserialize,

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/comm/utils.py in from_frames(frames, deserialize, deserializers, allow_offload)
     74         size = sum(map(nbytes, frames))
     75     if allow_offload and deserialize and OFFLOAD_THRESHOLD and size > OFFLOAD_THRESHOLD:
---> 76         res = await offload(_from_frames)
     77     else:
     78         res = _from_frames()

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/utils.py in offload(fn, *args, **kwargs)
   1330     # Retain context vars while deserializing; see https://bugs.python.org/issue34014
   1331     context = contextvars.copy_context()
-> 1332     return await loop.run_in_executor(
   1333         _offload_executor, lambda: context.run(fn, *args, **kwargs)
   1334     )

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/concurrent/futures/thread.py in run(self)
     55 
     56         try:
---> 57             result = self.fn(*self.args, **self.kwargs)
     58         except BaseException as exc:
     59             self.future.set_exception(exc)

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/utils.py in <lambda>()
   1331     context = contextvars.copy_context()
   1332     return await loop.run_in_executor(
-> 1333         _offload_executor, lambda: context.run(fn, *args, **kwargs)
   1334     )
   1335 

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/comm/utils.py in _from_frames()
     59     def _from_frames():
     60         try:
---> 61             return protocol.loads(
     62                 frames, deserialize=deserialize, deserializers=deserializers
     63             )

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/protocol/core.py in loads(frames, deserialize, deserializers)
    109                 return msgpack_decode_default(obj)
    110 
--> 111         return msgpack.loads(
    112             frames[0], object_hook=_decode_default, use_list=False, **msgpack_opts
    113         )

msgpack/_unpacker.pyx in msgpack._cmsgpack.unpackb()

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/protocol/core.py in _decode_default(obj)
    101                     if "compression" in sub_header:
    102                         sub_frames = decompress(sub_header, sub_frames)
--> 103                     return merge_and_deserialize(
    104                         sub_header, sub_frames, deserializers=deserializers
    105                     )

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/protocol/serialize.py in merge_and_deserialize(header, frames, deserializers)
    486             merged_frames.append(merged)
    487 
--> 488     return deserialize(header, merged_frames, deserializers=deserializers)
    489 
    490 

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/protocol/serialize.py in deserialize(header, frames, deserializers)
    415         )
    416     dumps, loads, wants_context = families[name]
--> 417     return loads(header, frames)
    418 
    419 

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/distributed/protocol/serialize.py in dask_loads(header, frames)
     55     typ = pickle.loads(header["type-serialized"])
     56     loads = dask_deserialize.dispatch(typ)
---> 57     return loads(header["sub-header"], frames)
     58 
     59 

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/cudf/comm/serialize.py in dask_deserialize_cudf_object(header, frames)
     25     def dask_deserialize_cudf_object(header, frames):
     26         with log_errors():
---> 27             return Serializable.host_deserialize(header, frames)
     28 
     29 

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/cudf/core/abc.py in host_deserialize(cls, header, frames)
    184         :meta private:
    185         """
--> 186         frames = [
    187             rmm.DeviceBuffer.to_device(f) if c else f
    188             for c, f in zip(header["is-cuda"], map(memoryview, frames))

/datasets/vjawa/miniconda3/envs/rapids-21.12/lib/python3.8/site-packages/cudf/core/abc.py in <listcomp>(.0)
    185         """
    186         frames = [
--> 187             rmm.DeviceBuffer.to_device(f) if c else f
    188             for c, f in zip(header["is-cuda"], map(memoryview, frames))
    189         ]

rmm/_lib/device_buffer.pyx in rmm._lib.device_buffer.DeviceBuffer.to_device()

rmm/_lib/device_buffer.pyx in rmm._lib.device_buffer.to_device()

rmm/_lib/device_buffer.pyx in rmm._lib.device_buffer.DeviceBuffer.__cinit__()

MemoryError: std::bad_alloc: CUDA error at: /datasets/vjawa/miniconda3/envs/rapids-21.12/include/rmm/mr/device/cuda_memory_resource.hpp

Example with dask-sql

from dask_cuda import LocalCUDACluster
from dask.distributed import Client
from dask.datasets import timeseries
from dask_sql import Context
import cudf

cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES='0', rmm_pool_size='14.5 GB')
client = Client(cluster)

c = Context()

df = timeseries(start='2000-01-01', end='2002-10-28', freq="1s")[['x', 'y']]
df = df.map_partitions(cudf.from_pandas).reset_index(drop=True)
c.create_table("timeseries", input_table=df)

model_query = """
    CREATE OR REPLACE MODEL my_model WITH (
        model_class = 'cuml.linear_model.LinearRegression',
        wrap_predict = True,
        target_column = 'target'
    ) AS (
        SELECT x, y, x*y AS target
        FROM timeseries
    )
    """
c.sql(model_query)

Expected Behaviour:

I expect this to succeed, just as it does when we do the same thing with cuDF DataFrames:

from dask.datasets import timeseries
import cudf
from cuml.linear_model import LinearRegression


df = timeseries(start='2000-01-01', end='2002-10-28', freq="1s")[['x', 'y']]
df = df.map_partitions(cudf.from_pandas).reset_index(drop=True).compute()

model = LinearRegression()
model = model.fit(df, df['x'] * df['y'])

CC: @dantegd, @ChrisJar

Expected Solution

I am unsure where we should push a fix for this.

For the dask-sql case it might be better to fix it in dask-sql and train there via a map_partitions call directly, and to just error/warn when a standalone dask-cuDF object is passed to cuML; a sketch of that approach is below.
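
One possible shape of that worker-side training is sketched below, under the assumption that the concatenated data fits on a single worker GPU. It uses to_delayed/dask.delayed rather than map_partitions (since the per-partition function returns a model rather than a frame), and fit_on_worker is an illustrative helper, not an existing cuML or dask-sql API.

import cudf
import dask
from cuml.linear_model import LinearRegression
from dask.datasets import timeseries
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES='0', rmm_pool_size='14.5 GB')
client = Client(cluster)

df = timeseries(start='2000-01-01', end='2002-10-28', freq="1s")[['x', 'y']]
df = df.map_partitions(cudf.from_pandas).reset_index(drop=True)


def fit_on_worker(parts):
    # Runs as a single task on a worker: both the concatenated frame and the
    # fit allocate from the worker's RMM pool instead of the client's GPU context.
    gdf = cudf.concat(parts)
    model = LinearRegression()
    return model.fit(gdf[['x', 'y']], gdf['x'] * gdf['y'])


# One delayed fit over all partitions; only the fitted model is gathered back
# to the client.
model = dask.delayed(fit_on_worker)(df.to_delayed()).compute()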
