Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

import numpy as np

# Import CPU's kernel generation function
from aeon.transformations.collection.convolution_based._rocket import (
_generate_kernels,
)
from aeon.transformations.collection.convolution_based.rocketGPU.base import (
BaseROCKETGPU,
)
Expand All @@ -28,19 +32,17 @@ class ROCKETGPU(BaseROCKETGPU):
----------
n_kernels : int, default=10000
Number of random convolutional filters.
kernel_size : list, default = None
The list of possible kernel sizes, default is [7, 9, 11].
padding : list, default = None
The list of possible tensorflow padding, default is ["SAME", "VALID"].
use_dilation : bool, default = True
Whether or not to use dilation in convolution operations.
bias_range : Tuple, default = None
The min and max value of bias values, default is (-1.0, 1.0).
batch_size : int, default = 64
The batch to parallelize over GPU.
random_state : None or int, optional, default = None
Seed for random number generation.

Notes
-----
This GPU implementation uses the CPU's kernel generation logic
(from `_rocket._generate_kernels`) to ensure exact kernel parity
when using the same random seed.
Comment on lines +40 to +44
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would a user need to know this? I can see noting a difference in results but this seems a bit unnecessary.


References
----------
.. [1] Tan, Chang Wei and Dempster, Angus and Bergmeir, Christoph
Expand All @@ -54,62 +56,18 @@ class ROCKETGPU(BaseROCKETGPU):
def __init__(
self,
n_kernels=10000,
kernel_size=None,
padding=None,
use_dilation=True,
bias_range=None,
batch_size=64,
random_state=None,
):
super().__init__(n_kernels)

self.n_kernels = n_kernels
self.kernel_size = kernel_size
self.padding = padding
self.use_dilation = use_dilation
self.bias_range = bias_range
self.batch_size = batch_size
self.random_state = random_state

def _define_parameters(self):
"""Define the parameters of ROCKET."""
rng = np.random.default_rng(self.random_state)

self._list_of_kernels = []
self._list_of_dilations = []
self._list_of_paddings = []
self._list_of_biases = []

for _ in range(self.n_kernels):
_kernel_size = rng.choice(self._kernel_size, size=1)[0]
_convolution_kernel = rng.normal(size=(_kernel_size, self.n_channels, 1))
_convolution_kernel = _convolution_kernel - _convolution_kernel.mean(
axis=0, keepdims=True
)

if self.use_dilation:
_dilation_rate = 2 ** rng.uniform(
0, np.log2((self.input_length - 1) / (_kernel_size - 1))
)
else:
_dilation_rate = 1

_padding = rng.choice(self._padding, size=1)[0]
assert _padding in ["SAME", "VALID"]

_bias = rng.uniform(self._bias_range[0], self._bias_range[1])

self._list_of_kernels.append(_convolution_kernel)
self._list_of_dilations.append(_dilation_rate)
self._list_of_paddings.append(_padding)
self._list_of_biases.append(_bias)

def _fit(self, X, y=None):
"""Generate random kernels adjusted to time series shape.

Infers time series length and number of channels from input numpy array,
and generates random kernels.

Parameters
----------
X : 3D np.ndarray of shape = (n_cases, n_channels, n_timepoints)
Expand All @@ -123,13 +81,91 @@ def _fit(self, X, y=None):
self.input_length = X.shape[2]
self.n_channels = X.shape[1]

self._kernel_size = [7, 9, 11] if self.kernel_size is None else self.kernel_size
self._padding = ["VALID", "SAME"] if self.padding is None else self.padding
self._bias_range = (-1.0, 1.0) if self.bias_range is None else self.bias_range
self.kernels = _generate_kernels(
n_timepoints=self.input_length,
n_kernels=self.n_kernels,
n_channels=self.n_channels,
seed=self.random_state,
)
self._convert_cpu_kernels_to_gpu_format()
return self

def _convert_cpu_kernels_to_gpu_format(self):
"""Convert CPU's kernel format to GPU's TensorFlow-compatible format.

CPU kernels are stored compactly as:
(weights,lengths,biases,dilations,paddings,num_channel_indices,channel_indices)

GPU needs:
- _list_of_kernels: List of (kernel_length, n_channels, 1) arrays
- _list_of_dilations: List of int dilation rates
- _list_of_paddings: List of "SAME" or "VALID" strings
- _list_of_biases: List of float bias values

The key conversion is handling CPU's selective channel indexing
by creating dense kernels with zero weights for unused channels.
"""
(
weights,
lengths,
biases,
dilations,
paddings,
num_channel_indices,
channel_indices,
) = self.kernels

self._list_of_kernels = []
self._list_of_dilations = []
self._list_of_paddings = []
self._list_of_biases = []

weight_idx = 0
channel_idx = 0

for i in range(self.n_kernels):
kernel_length = lengths[i]
n_kernel_channels = num_channel_indices[i]

# Extract this kernel's sparse weights from CPU format
n_weights = kernel_length * n_kernel_channels
sparse_weights = weights[weight_idx : weight_idx + n_weights]
sparse_weights = sparse_weights.reshape((n_kernel_channels, kernel_length))

# Get which channels this kernel operates on
selected_channels = channel_indices[
channel_idx : channel_idx + n_kernel_channels
]

# Create dense kernel tensor: (kernel_length, n_channels, 1)
# Unused channels have zero weights (no contribution to convolution)
dense_kernel = np.zeros(
(kernel_length, self.n_channels, 1), dtype=np.float32
)

# Place sparse weights in the corresponding channel positions
# Preserving the exact channel order from CPU
for c_idx, channel in enumerate(selected_channels):
dense_kernel[:, channel, 0] = sparse_weights[c_idx, :]

self._list_of_kernels.append(dense_kernel)

# Convert numeric padding to TensorFlow categorical padding
# CPU: 0 or (length-1)*dilation//2
# GPU: "VALID" or "SAME"
if paddings[i] == 0:
self._list_of_paddings.append("VALID")
else:
# Non-zero padding -> use SAME to approximate symmetric padding
self._list_of_paddings.append("SAME")

assert self._bias_range[0] <= self._bias_range[1]
# Convert dilation and bias to Python scalar types
self._list_of_dilations.append(int(dilations[i]))
self._list_of_biases.append(float(biases[i]))

self._define_parameters()
# Advance indices for next kernel
weight_idx += n_weights
channel_idx += n_kernel_channels

def _generate_batch_indices(self, n):
"""Generate the list of batches.
Expand Down Expand Up @@ -182,7 +218,8 @@ def _transform(self, X, y=None):

tf.random.set_seed(self.random_state)

X = X.transpose(0, 2, 1)
# Transpose and convert to float32 for TensorFlow compatibility
X = X.transpose(0, 2, 1).astype(np.float32)
Comment on lines +221 to +222
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why?


batch_indices_list = self._generate_batch_indices(n=len(X))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,13 @@ def test_base_rocketGPU_multivariate():
not _check_soft_dependencies("tensorflow", severity="none"),
reason="skip test if required soft dependency not available",
)
@pytest.mark.xfail(reason="Random numbers in Rocket and ROCKETGPU differ.")
@pytest.mark.parametrize("n_channels", [1, 3])
def test_rocket_cpu_gpu(n_channels):
"""Test consistency between CPU and GPU versions of ROCKET."""
"""Test feature parity between CPU and GPU versions of ROCKET.

GPU uses CPU's kernel generation to ensure identical kernels.
Feature outputs match within 1e-4 precision.
"""
random_state = 42
X, _ = make_example_3d_numpy(n_channels=n_channels, random_state=random_state)

Expand All @@ -152,4 +155,5 @@ def test_rocket_cpu_gpu(n_channels):

X_transform_cpu = rocket_cpu.transform(X)
X_transform_gpu = rocket_gpu.transform(X)
assert_array_almost_equal(X_transform_cpu, X_transform_gpu, decimal=8)
# Set decimal threshold here
assert_array_almost_equal(X_transform_cpu, X_transform_gpu, decimal=4)
Comment on lines 140 to +159
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the changes here? Not against the decimal changes but interested in hearing why. Docs changes seem unnecessary.