Skip to content
Draft
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,5 @@ scripts-local/
# pixi environments
.pixi
*.egg-info
.idea/
*.toml
23 changes: 23 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# GPU-enabled Ray base image (nightly channel, Python 3.12).
FROM rayproject/ray:nightly-py312-gpu

# The Ray image defaults to a non-root user; switch to root so we can
# install uv and write into /app.
USER root

# Install uv
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
# uv's installer drops the binary into ~/.local/bin (here: /root/.local/bin).
ENV PATH="/root/.local/bin:${PATH}"

# Create writable directory
WORKDIR /app

# Copy project
COPY . .

# Create venv + install deps
# NOTE(review): installs the CUDA 12.6 extra here, while the pixi README uses
# cuda128 — confirm which CUDA variant this image should ship.
RUN uv venv --python 3.12 --clear && \
. .venv/bin/activate && \
uv sync --extra cuda126

# Make the project venv the default interpreter for all subsequent commands
# (equivalent to `source .venv/bin/activate` at runtime).
ENV VIRTUAL_ENV=/app/.venv
ENV PATH="/app/.venv/bin:${PATH}"

# Start a Ray head node when the container runs.
CMD ["ray", "start", "--head"]
28 changes: 28 additions & 0 deletions README_pixi_shell.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Pixi shell instructions

If you are unable to install the NVIDIA-related dependencies directly, you will be
using a pixi shell. This requires a few extra steps before you can run the darts pipeline.

First, run this command:

`pixi shell -e cuda128`

`conda install -c nvidia cuda-toolkit=12 -y`

then

`uv sync --extra cuda128 --extra torchdeps --extra cuda12deps`

This will install the dependencies.

Once you run those, you will activate the environment using this command:

`source .venv/bin/activate`

Commands are then run like this. Note that there is no `uv run` at the start:

`darts inference sentinel2-ray --aoi-shapefile
/taiga/toddn/rts-files/tiles_nwt_2010_2016_small.geojson
--start-date 2024-07 --max-cloud-cover 100 --max-snow-cover 100
--end-date 2024-09 --verbose`
58 changes: 51 additions & 7 deletions darts-utils/src/darts_utils/cuda.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Utility functions around cuda, e.g. memory management."""

import gc
import logging
from typing import Literal

import xarray as xr
from xrspatial.utils import has_cuda_and_cupy
from typing import Any

Check failure on line 9 in darts-utils/src/darts_utils/cuda.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (I001)

darts-utils/src/darts_utils/cuda.py:3:1: I001 Import block is un-sorted or un-formatted

logger = logging.getLogger(__name__.replace("darts_", "darts."))

Expand Down Expand Up @@ -54,19 +55,62 @@
return tile


def move_to_host(
    tile: xr.Dataset | xr.DataArray | Any,
) -> xr.Dataset | xr.DataArray | Any:
    """Ensure data are moved from GPU (CuPy) memory to CPU (NumPy) memory.

    Converts CuPy-backed arrays inside an xarray Dataset or DataArray into
    NumPy arrays, ensuring full CPU compatibility for serialization or
    downstream processing (e.g. handing tiles between Ray workers).

    Handled cases:
        1. Raw CuPy array -> NumPy array via ``cp.asnumpy``.
        2. ``xarray.DataArray`` backed by CuPy -> copy with NumPy data.
        3. ``xarray.Dataset`` with CuPy-backed variables -> copy where every
           variable is NumPy-backed; coordinates and attrs are preserved.

    If the input is already CPU-backed, or CUDA/CuPy is unavailable, the
    object is returned unchanged.

    Args:
        tile: The data object to move. May be a ``cupy.ndarray``, an
            ``xarray.DataArray`` or an ``xarray.Dataset``; any other type is
            returned as-is.

    Returns:
        The same kind of object, backed by NumPy arrays on the CPU.
    """
    # Nothing to move without CUDA/CuPy — and `cupy` may not even be
    # importable on CPU-only Ray workers, so bail out before touching it.
    if not has_cuda_and_cupy():
        return tile

    import cupy as cp  # deferred: only resolvable on GPU-capable nodes

    # Case 1: raw CuPy array.
    if isinstance(tile, cp.ndarray):
        return cp.asnumpy(tile)

    # Case 2: DataArray backed by CuPy. `__cuda_array_interface__` is the
    # standard marker for device-resident arrays.
    if isinstance(tile, xr.DataArray):
        if hasattr(tile.data, "__cuda_array_interface__"):
            return tile.copy(data=cp.asnumpy(tile.data))
        return tile

    # Case 3: Dataset containing CuPy-backed variables. Copy variable-wise
    # via DataArray.copy so coordinates, dims and attrs all survive —
    # rebuilding the Dataset from (dims, data, attrs) tuples would silently
    # drop the coordinates (e.g. the tile's georeferencing).
    if isinstance(tile, xr.Dataset):
        moved = tile.copy()
        for name, da in tile.data_vars.items():
            if hasattr(da.data, "__cuda_array_interface__"):
                moved[name] = da.copy(data=cp.asnumpy(da.data))
        return moved

    # Unknown type: assume already CPU-backed and return unchanged.
    return tile


Expand Down
26 changes: 19 additions & 7 deletions darts/src/darts/pipelines/_ray_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from darts_ensemble import EnsembleV1
from darts_export import export_tile
from darts_postprocessing import prepare_export
from darts_preprocessing import preprocess_legacy_fast
from darts_preprocessing import preprocess_v2

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -41,6 +41,8 @@
tilekey: Any # The key to identify the tile, e.g. a path or a tile id
outpath: str # The path to the output directory
tile_id: str # The id of the tile, e.g. the name of the file or the tile id
metadata: dict | None
debug_data: bool | None


# @ray.remote(num_cpus=1, num_gpus=1)
Expand All @@ -60,7 +62,7 @@
write_model_outputs: bool,
) -> RayDataDict:
tile = row["tile"].dataset
tile = self.ensemble(
tile = self.ensemble.segment_tile(
tile,
patch_size=patch_size,
overlap=overlap,
Expand All @@ -69,6 +71,7 @@
keep_inputs=write_model_outputs,
)
row["tile"] = RayDataset(tile)
logger.info("Ensemble done", extra={"tile_id": row["tile_id"]})
return row


Expand All @@ -81,17 +84,20 @@
arcticdem_resolution: int,
buffer: int,
tcvis_dir: Path,
offline: bool = False,
) -> RayDataDict:
tile = row["tile"].dataset
arcticdem = load_arcticdem(
tile.odc.geobox,
data_dir=arcticdem_dir,
resolution=arcticdem_resolution,
buffer=buffer,
offline=offline,
)
tcvis = load_tcvis(tile.odc.geobox, tcvis_dir)
tcvis = load_tcvis(tile.odc.geobox, tcvis_dir, offline=offline)
row["adem"] = RayDataset(arcticdem)
row["tcvis"] = RayDataset(tcvis)
logger.info("Aux data loaded", extra={"tile_id": row["tile_id"]})
return row


Expand All @@ -103,9 +109,9 @@
device: int | Literal["cuda", "cpu"],
):
tile = row["tile"].dataset
arcticdem = row["adem"].dataset
tcvis = row["tcvis"].dataset
tile = preprocess_legacy_fast(
arcticdem = row["adem"].dataset if row["adem"] is not None else None
tcvis = row["tcvis"].dataset if row["tcvis"] is not None else None
tile = preprocess_v2(
tile,
arcticdem,
tcvis,
Expand All @@ -116,6 +122,7 @@
row["tile"] = RayDataset(tile)
row["adem"] = None
row["tcvis"] = None
logger.info("Preprocess done", extra={"tile_id": row["tile_id"]})
return row


Expand All @@ -129,19 +136,21 @@
models: dict[str, Any],
write_model_outputs: bool,
device: int | Literal["cuda", "cpu"],
edge_erosion_size: int | None = None,
):
tile = row["tile"].dataset
tile = prepare_export(
tile,
bin_threshold=binarization_threshold,
mask_erosion_size=mask_erosion_size,
# TODO: edge_erosion_size
edge_erosion_size=edge_erosion_size,
min_object_size=min_object_size,
quality_level=quality_level,
ensemble_subsets=models.keys() if write_model_outputs else [],
device=device,
)
row["tile"] = RayDataset(tile)
logger.info("Export done", extra={"tile_id": row["tile_id"]})
return row


Expand All @@ -159,6 +168,8 @@
outpath,
bands=export_bands,
ensemble_subsets=models.keys() if write_model_outputs else [],
metadata=row.get("metadata") or {},
debug=row.get("debug_data") or False,
)
del row["tile"]

Expand All @@ -170,3 +181,4 @@
"tile_id": tile_id,
"outpath": str(outpath),
}

Loading
Loading