diff --git a/bindsnet/pipeline/action.py b/bindsnet/pipeline/action.py index 5d45c45b..fa5bfd84 100644 --- a/bindsnet/pipeline/action.py +++ b/bindsnet/pipeline/action.py @@ -43,7 +43,8 @@ def select_multinomial(pipeline: EnvironmentPipeline, **kwargs) -> int: [ spikes[(i * pop_size) : (i * pop_size) + pop_size].sum() for i in range(action_space.n) - ] + ], + device=spikes.device, ) action = torch.multinomial((pop_spikes.float() / _sum).view(-1), 1)[0].item()