@@ -134,8 +134,9 @@ def __init__(
             self.value = self.value.unsqueeze(0).repeat(self.batch_size, 1, 1)

             self.value = self.value.to_sparse()
-            assert not getattr(self, 'enforce_polarity', False), \
-                "enforce_polarity isn't supported for sparse tensors"
+            assert not getattr(
+                self, "enforce_polarity", False
+            ), "enforce_polarity isn't supported for sparse tensors"

     @abstractmethod
     def reset_state_variables(self) -> None:
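For reference, here is a minimal standalone sketch (not BindsNET code) of the pattern this hunk reformats: a 2-D per-synapse value gets a leading batch dimension and is then converted to a sparse COO tensor. The sizes and `batch_size` below are invented for illustration; sparse values then skip options such as `enforce_polarity`, which the assertion above rejects, presumably because they would need elementwise checks on the dense layout.

```python
import torch

# Illustrative sizes only: 3 source neurons, 5 target neurons, batch of 4.
batch_size = 4
value = torch.rand(3, 5)                             # (source_n, target_n)
value = value.unsqueeze(0).repeat(batch_size, 1, 1)  # (batch, source_n, target_n)
value = value.to_sparse()                            # sparse COO layout

print(value.shape, value.is_sparse)                  # torch.Size([4, 3, 5]) True
```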
@@ -179,9 +180,15 @@ def prime_feature(self, connection, device, **kwargs) -> None:
         # Check if values/norms are the correct shape
         if isinstance(self.value, torch.Tensor):
             if self.sparse:
-                assert tuple(self.value.shape[1:]) == (connection.source.n, connection.target.n)
+                assert tuple(self.value.shape[1:]) == (
+                    connection.source.n,
+                    connection.target.n,
+                )
             else:
-                assert tuple(self.value.shape) == (connection.source.n, connection.target.n)
+                assert tuple(self.value.shape) == (
+                    connection.source.n,
+                    connection.target.n,
+                )

         if self.norm is not None and isinstance(self.norm, torch.Tensor):
             assert self.norm.shape[0] == connection.target.n
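The wrapped assertions above encode one detail worth keeping in mind: sparse values carry a leading batch dimension, so only `shape[1:]` is compared against `(connection.source.n, connection.target.n)`, while dense values are compared directly. A small sketch with stand-in sizes (`source_n` and `target_n` are placeholders for the connection attributes):

```python
import torch

source_n, target_n, batch_size = 3, 5, 2

dense_value = torch.rand(source_n, target_n)
sparse_value = torch.rand(batch_size, source_n, target_n).to_sparse()

# Dense: full shape must match; sparse: skip the leading batch dimension.
assert tuple(dense_value.shape) == (source_n, target_n)
assert tuple(sparse_value.shape[1:]) == (source_n, target_n)
```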
@@ -325,10 +332,12 @@ def assert_feature_in_range(self):

     def assert_valid_shape(self, source_shape, target_shape, f):
         # Multidimensional feat
-        if (not self.sparse and len(f.shape) > 1) or (self.sparse and len(f.shape[1:]) > 1):
+        if (not self.sparse and len(f.shape) > 1) or (
+            self.sparse and len(f.shape[1:]) > 1
+        ):
             if self.sparse:
                 f_shape = f.shape[1:]
-                expected = (' batch_size', source_shape, target_shape)
+                expected = (" batch_size", source_shape, target_shape)
             else:
                 f_shape = f.shape
                 expected = (source_shape, target_shape)
@@ -352,7 +361,7 @@ def __init__(
         decay: float = 0.0,
         parent_feature=None,
         sparse: Optional[bool] = False,
-        batch_size: int = 1
+        batch_size: int = 1,
     ) -> None:
         # language=rst
         """
@@ -386,19 +395,15 @@ def __init__(
             decay=decay,
             parent_feature=parent_feature,
             sparse=sparse,
-            batch_size=batch_size
+            batch_size=batch_size,
         )

     def sparse_bernoulli(self):
         values = torch.bernoulli(self.value.values())
         mask = values != 0
         indices = self.value.indices()[:, mask]
         non_zero = values[mask]
-        return torch.sparse_coo_tensor(
-            indices,
-            non_zero,
-            self.value.size()
-        )
+        return torch.sparse_coo_tensor(indices, non_zero, self.value.size())

     def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
         if self.sparse:
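`sparse_bernoulli` is the only functional-looking spot in this hunk, even though the change is just a reflow of the return statement. As a rough standalone sketch of what it does, assuming `self.value` is a coalesced sparse COO tensor of spike probabilities: draw a Bernoulli sample for each stored probability, drop the entries that came up zero, and rebuild a sparse tensor of the original size. The function and variable names below are invented for illustration.

```python
import torch

def sparse_bernoulli_sample(value: torch.Tensor) -> torch.Tensor:
    # Draw a 0/1 sample for every stored probability.
    samples = torch.bernoulli(value.values())
    # Keep only the indices whose draw was non-zero, so the result stays sparse.
    mask = samples != 0
    indices = value.indices()[:, mask]
    non_zero = samples[mask]
    return torch.sparse_coo_tensor(indices, non_zero, value.size())

# Hypothetical usage with made-up probabilities:
probs = torch.tensor([[0.9, 0.0], [0.0, 0.5]]).to_sparse()
print(sparse_bernoulli_sample(probs).to_dense())
```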
@@ -448,7 +453,7 @@ def __init__(
         name: str,
         value: Union[torch.Tensor, float, int] = None,
         sparse: Optional[bool] = False,
-        batch_size: int = 1
+        batch_size: int = 1,
     ) -> None:
         # language=rst
         """
@@ -472,12 +477,7 @@ def __init__(
             # Send boolean to tensor (priming wont work if it's not a tensor)
             value = torch.tensor(value)

-        super().__init__(
-            name=name,
-            value=value,
-            sparse=sparse,
-            batch_size=batch_size
-        )
+        super().__init__(name=name, value=value, sparse=sparse, batch_size=batch_size)

     def compute(self, conn_spikes) -> torch.Tensor:
         return conn_spikes * self.value
@@ -561,7 +561,7 @@ def __init__(
         enforce_polarity: Optional[bool] = False,
         decay: float = 0.0,
         sparse: Optional[bool] = False,
-        batch_size: int = 1
+        batch_size: int = 1,
     ) -> None:
         # language=rst
         """
@@ -596,7 +596,7 @@ def __init__(
             reduction=reduction,
             decay=decay,
             sparse=sparse,
-            batch_size=batch_size
+            batch_size=batch_size,
         )

     def reset_state_variables(self) -> None:
@@ -651,7 +651,7 @@ def __init__(
         range: Optional[Sequence[float]] = None,
         norm: Optional[Union[torch.Tensor, float, int]] = None,
         sparse: Optional[bool] = False,
-        batch_size: int = 1
+        batch_size: int = 1,
     ) -> None:
         # language=rst
         """
@@ -671,7 +671,7 @@ def __init__(
             range=[-torch.inf, +torch.inf] if range is None else range,
             norm=norm,
             sparse=sparse,
-            batch_size=batch_size
+            batch_size=batch_size,
         )

     def reset_state_variables(self) -> None:
@@ -697,7 +697,7 @@ def __init__(
         value: Union[torch.Tensor, float, int] = None,
         range: Optional[Sequence[float]] = None,
         sparse: Optional[bool] = False,
-        batch_size: int = 1
+        batch_size: int = 1,
     ) -> None:
         # language=rst
         """
@@ -708,7 +708,9 @@ def __init__(
         :param batch_size: Mini-batch size.
         """

-        super().__init__(name=name, value=value, range=range, sparse=sparse, batch_size=batch_size)
+        super().__init__(
+            name=name, value=value, range=range, sparse=sparse, batch_size=batch_size
+        )

     def reset_state_variables(self) -> None:
         pass
@@ -738,7 +740,7 @@ def __init__(
         degrade_function: callable = None,
         parent_feature: Optional[AbstractFeature] = None,
         sparse: Optional[bool] = False,
-        batch_size: int = 1
+        batch_size: int = 1,
     ) -> None:
         # language=rst
         """
@@ -754,7 +756,13 @@ def __init__(
         """

         # Note: parent_feature will override value. See abstract constructor
-        super().__init__(name=name, value=value, parent_feature=parent_feature, sparse=sparse, batch_size=batch_size)
+        super().__init__(
+            name=name,
+            value=value,
+            parent_feature=parent_feature,
+            sparse=sparse,
+            batch_size=batch_size,
+        )

         self.degrade_function = degrade_function
@@ -774,7 +782,7 @@ def __init__(
         const_update_rate: float = 0.1,
         const_decay: float = 0.001,
         sparse: Optional[bool] = False,
-        batch_size: int = 1
+        batch_size: int = 1,
     ) -> None:
         # language=rst
         """
@@ -833,7 +841,9 @@ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
             flat_conn_spikes = conn_spikes.to_dense().flatten()
         else:
             flat_conn_spikes = conn_spikes.flatten()
-        self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = flat_conn_spikes
+        self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = (
+            flat_conn_spikes
+        )
         self.counter += 1

         # Update the masks
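The parenthesized right-hand side above is purely a black reflow; the underlying logic is a circular buffer of recent spikes, where the write column is `counter % buffer_width`. A minimal sketch with invented sizes (`n_synapses` and `buffer_width` are placeholders for the actual buffer shape):

```python
import torch

n_synapses, buffer_width = 6, 4
spike_buffer = torch.zeros(n_synapses, buffer_width)
counter = 0

for _ in range(10):
    conn_spikes = torch.randint(0, 2, (2, 3)).float()  # one step of fake spikes
    flat_conn_spikes = conn_spikes.flatten()
    # Overwrite the oldest column; the buffer keeps the last `buffer_width` steps.
    spike_buffer[:, counter % spike_buffer.shape[1]] = flat_conn_spikes
    counter += 1
```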
@@ -872,7 +882,7 @@ def __init__(
         const_update_rate: float = 0.1,
         const_decay: float = 0.01,
         sparse: Optional[bool] = False,
-        batch_size: int = 1
+        batch_size: int = 1,
     ) -> None:
         # language=rst
         """
@@ -931,7 +941,9 @@ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
             flat_conn_spikes = conn_spikes.to_dense().flatten()
         else:
             flat_conn_spikes = conn_spikes.flatten()
-        self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = flat_conn_spikes
+        self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = (
+            flat_conn_spikes
+        )
         self.counter += 1

         # Update the masks