
[Tracker] Improve cuml.accel proxy layer implementation #6502

@csadorf

Description


Current Implementation Limitations

The current proxy layer implementation (estimator_proxy.py) has architectural limitations that affect maintainability and extensibility:

1. Inheritance vs. Composition

The current design inherits from accelerated components rather than using composition:

class ProxyEstimator(class_b):  # Inherits from accelerated class
    def __init__(self, *args, **kwargs):
        self._cpu_model_class = original_class_a  # Stores original as reference

This creates tight coupling where:

  • Accelerated components must handle translation logic
  • Changes to either component affect the other
  • Adding new acceleration paths requires modifying existing code
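For contrast, here is a minimal composition-based sketch. CpuEstimator, GpuEstimator, ComposedProxy and the select_backend callable are hypothetical stand-ins, not cuml.accel classes. The proxy holds both implementations by reference, so neither implementation needs to know about the other. Adding another acceleration path would mean adding one more entry to the lookup table, not changing an existing class.

```python
class CpuEstimator:
    """Hypothetical stand-in for the original (CPU) estimator class."""

    def __init__(self, n_clusters=8):
        self.n_clusters = n_clusters

    def fit(self, X):
        self.fitted_on_ = "cpu"
        return self


class GpuEstimator:
    """Hypothetical stand-in for the accelerated class; knows nothing about the CPU one."""

    def __init__(self, n_clusters=8):
        self.n_clusters = n_clusters

    def fit(self, X):
        self.fitted_on_ = "gpu"
        return self


class ComposedProxy:
    """Holds both implementations by reference instead of inheriting from one."""

    def __init__(self, cpu_cls, gpu_cls, select_backend, **params):
        self._impls = {"cpu": cpu_cls, "gpu": gpu_cls}  # new paths = new entries
        self._select_backend = select_backend  # callable(X) -> "cpu" or "gpu"
        self._params = params
        self._backend = None  # concrete estimator, created in fit()

    def fit(self, X):
        impl = self._impls[self._select_backend(X)]
        self._backend = impl(**self._params).fit(X)
        return self

    def __getattr__(self, name):
        # Called only when normal lookup fails: forward to the fitted backend.
        backend = self.__dict__.get("_backend")
        if backend is None:
            raise AttributeError(name)
        return getattr(backend, name)


p = ComposedProxy(CpuEstimator, GpuEstimator,
                  lambda X: "gpu" if len(X) > 2 else "cpu", n_clusters=3)
```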

2. Early Parameter Translation

Parameter translation occurs at construction time rather than operation time:

def __init__(self, *args, **kwargs):
    translated_kwargs, self._gpuaccel = self._hyperparam_translator(**kwargs)
    super().__init__(*args, **translated_kwargs)

Issues:

  • Parameters are translated before the runtime context is known
  • Multiple translations may occur (original → accelerated → CPU)
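Here is a minimal sketch of deferring translation to operation time. It assumes a hypothetical translate_for_gpu helper and an illustrative rule that only the "auto" and "lloyd" values of an algorithm hyperparameter can be accelerated. The proxy keeps the user's parameters untouched and translates them once inside fit(). When it falls back to the CPU, it uses the original values directly.

```python
def translate_for_gpu(params):
    """Hypothetical translator from original hyperparameters to accelerated ones.

    Returns None when the parameters have no accelerated equivalent.
    """
    translated = dict(params)
    if translated.pop("algorithm", "auto") not in ("auto", "lloyd"):
        return None
    return translated


class DeferredTranslationProxy:
    def __init__(self, **params):
        self._params = dict(params)  # stored exactly as passed; nothing translated yet

    def get_params(self):
        return dict(self._params)  # always reports the original values

    def fit(self, X):
        # Translation happens here, once, when the runtime context is known.
        gpu_params = translate_for_gpu(self._params)
        if gpu_params is not None:
            self.backend_, self.backend_params_ = "gpu", gpu_params
        else:
            # The CPU path uses the original params directly, avoiding an
            # original -> accelerated -> CPU round trip.
            self.backend_, self.backend_params_ = "cpu", dict(self._params)
        return self
```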

3. Dispatch Limitations

Current dispatching is constrained by:

  • Limited support for branching at runtime
  • Tight coupling with parameter translation

Requirements

P0 (Critical)

  1. API Compatibility

    • Proxied objects must match the original API's behavior
    • The public interface (names not prefixed with _) must be identical
    • Must preserve the original error handling and warnings
  2. Context-Aware Dispatch

    • Dispatch decisions should be made at operation time (fit(), predict(), etc.)
    • Consider the full context:
      • Input data properties (type, shape, memory location)
      • Available hardware
      • Runtime state
      • Hyperparameters
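The context listed above could be bundled into a single object and evaluated on each call. The sketch below is illustrative only. DispatchContext, choose_backend, the sample-count threshold and the supported-value sets are assumptions, not cuml.accel's actual dispatch rules.

```python
from dataclasses import dataclass
from typing import Any, Mapping, Optional

# Hypothetical values for illustration only.
SUPPORTED_INPUT_TYPES = {"numpy", "pandas", "cupy", "cudf"}
GPU_SUPPORTED_ALGORITHMS = {"auto", "lloyd"}
MIN_SAMPLES_FOR_GPU = 10_000


@dataclass(frozen=True)
class DispatchContext:
    method: str                # operation being called: "fit", "predict", ...
    input_type: str            # input data type, e.g. "numpy" or "cupy"
    n_samples: int             # input data shape (rows)
    on_device: bool            # memory location: input already in GPU memory
    gpu_available: bool        # available hardware
    fitted_on: Optional[str]   # runtime state: backend used by a previous fit
    params: Mapping[str, Any]  # hyperparameters, as passed by the user


def choose_backend(ctx: DispatchContext) -> str:
    """Return "gpu" or "cpu" for one operation, given its full context."""
    if not ctx.gpu_available or ctx.input_type not in SUPPORTED_INPUT_TYPES:
        return "cpu"
    if ctx.method != "fit" and ctx.fitted_on is not None:
        # Stay on the backend that holds the fitted model (no model conversion).
        return ctx.fitted_on
    if ctx.params.get("algorithm", "auto") not in GPU_SUPPORTED_ALGORITHMS:
        return "cpu"
    if ctx.on_device:
        return "gpu"  # copying device data back to the host would add overhead
    return "gpu" if ctx.n_samples >= MIN_SAMPLES_FOR_GPU else "cpu"
```

Because the decision is a plain function of the context, each rule can be unit-tested on its own, and parameter translation can remain a separate step that runs only after a backend has been chosen.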

P1 (Important)

  1. Type System Integration

    # Should work correctly:
    isinstance(proxy_estimator, OriginalEstimator)  # True
  2. Meta-Estimator Support

    • Support nested estimators
    • Preserve acceleration capabilities in pipelines
    • Handle parameter grids correctly
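Here is one way to satisfy both P1 items, sketched under several assumptions. The proxy is generated as a subclass of the original class, so isinstance holds without tricks, and it references the accelerated class by composition. OriginalEstimator, AcceleratedEstimator, make_proxy_class, the _use_gpu flag and the MiniPipeline toy meta-estimator are all hypothetical. The proxy exposes the original, untranslated parameters, so scikit-learn-style name__param routing keeps working.

```python
class OriginalEstimator:
    """Hypothetical stand-in for a CPU library estimator with sklearn-style params."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def get_params(self, deep=True):
        return {"alpha": self.alpha}

    def set_params(self, **params):
        for key, value in params.items():
            setattr(self, key, value)
        return self

    def fit(self, X):
        self.backend_ = "cpu"
        return self


class AcceleratedEstimator:
    """Hypothetical stand-in for the GPU implementation; unrelated by type."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def fit(self, X):
        self.backend_ = "gpu"
        return self


def make_proxy_class(original_cls, accelerated_cls):
    """Hypothetical factory: inherit from the original, compose the accelerated."""

    class Proxy(original_cls):
        _use_gpu = True  # placeholder for a real, context-aware dispatch decision

        def fit(self, X, *args, **kwargs):
            if not self._use_gpu:
                return super().fit(X, *args, **kwargs)
            # Parameters are read in their original form, at operation time.
            gpu = accelerated_cls(**self.get_params()).fit(X, *args, **kwargs)
            # Copy fitted attributes (trailing-underscore convention) back.
            for key, value in vars(gpu).items():
                if key.endswith("_") and not key.startswith("_"):
                    setattr(self, key, value)
            return self

    return Proxy


class MiniPipeline:
    """Toy meta-estimator that routes "step__param" keys, sklearn-style."""

    def __init__(self, steps):
        self.steps = steps

    def set_params(self, **params):
        named = dict(self.steps)
        for key, value in params.items():
            step, _, sub = key.partition("__")
            named[step].set_params(**{sub: value})
        return self

    def fit(self, X):
        for _, est in self.steps:
            est.fit(X)
        return self


ProxyEstimator = make_proxy_class(OriginalEstimator, AcceleratedEstimator)
pipe = MiniPipeline([("model", ProxyEstimator(alpha=1.0))])
for alpha in [0.1, 0.5]:  # a one-parameter "grid" over model__alpha
    pipe.set_params(model__alpha=alpha).fit([[0.0]])
```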

P2 (Nice-to-have)

  1. Serialization
    # Should work with the original library:
    import pickle
    loaded = pickle.loads(pickle.dumps(proxy_estimator))
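One possible reading of "work with original library" is that a pickled proxy should load where only the original library is installed. Under that assumption, a proxy could implement __reduce__ to serialize itself as the original class, carrying only public, host-side state. OriginalEstimator and ProxyEstimator below are hypothetical stand-ins, and the lambda represents device-resident state that cannot be pickled.

```python
import pickle


class OriginalEstimator:
    """Hypothetical stand-in for the original library's estimator class."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha


class ProxyEstimator:
    """Hypothetical proxy that pickles as a plain OriginalEstimator."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha
        self.coef_ = [1.0, 2.0]             # fitted state, assumed already on the host
        self._device_state = lambda: None   # stand-in for unpicklable GPU state

    def __reduce__(self):
        # Rebuild via OriginalEstimator() and restore only public attributes,
        # so loading requires neither the proxy module nor a GPU.
        state = {k: v for k, v in vars(self).items() if not k.startswith("_")}
        return (OriginalEstimator, (), state)


loaded = pickle.loads(pickle.dumps(ProxyEstimator(alpha=0.5)))
```

A fuller version would first convert any fitted GPU arrays to host-side equivalents before dropping proxy-specific attributes.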

Metadata


Assignees

Labels

Tech Debt (Issues related to debt), cuml-accel (Issues related to cuml.accel), improvement (Improvement / enhancement to an existing function)

Projects

No projects

Milestone

No milestone
