77from __future__ import annotations
88
99import datetime
10+ import functools
1011import itertools
12+ import operator
1113from abc import ABC , abstractmethod
1214from collections import defaultdict
1315from collections .abc import Mapping , Sequence
@@ -670,7 +672,12 @@ class SeasonResampler(Resampler):
670672
671673 def __post_init__ (self ):
672674 self .season_inds = season_to_month_tuple (self .seasons )
673- self .season_tuples = dict (zip (self .seasons , self .season_inds , strict = False ))
675+ all_inds = functools .reduce (operator .add , self .season_inds )
676+ if len (all_inds ) > len (set (all_inds )):
677+ raise ValueError (
678+ f"Overlapping seasons are not allowed. Received { self .seasons !r} "
679+ )
680+ self .season_tuples = dict (zip (self .seasons , self .season_inds , strict = True ))
674681
675682 def factorize (self , group ):
676683 if group .ndim != 1 :
@@ -696,12 +703,22 @@ def factorize(self, group):
696703 season_label [month .isin (season_ind )] = season_str
697704 if "DJ" in season_str :
698705 after_dec = season_ind [season_str .index ("D" ) + 1 :]
706+ # important this is assuming non-overlapping seasons
699707 year [month .isin (after_dec )] -= 1
700708
709+ # Allow users to skip one or more months?
710+ # present_seasons is a mask that is True for months that are requestsed in the output
711+ present_seasons = season_label != ""
712+ if present_seasons .all ():
713+ present_seasons = slice (None )
701714 frame = pd .DataFrame (
702- data = {"index" : np .arange (group .size ), "month" : month },
715+ data = {
716+ "index" : np .arange (group [present_seasons ].size ),
717+ "month" : month [present_seasons ],
718+ },
703719 index = pd .MultiIndex .from_arrays (
704- [year .data , season_label ], names = ["year" , "season" ]
720+ [year .data [present_seasons ], season_label [present_seasons ]],
721+ names = ["year" , "season" ],
705722 ),
706723 )
707724
@@ -727,19 +744,19 @@ def factorize(self, group):
727744
728745 sbins = first_items .values .astype (int )
729746 group_indices = [
730- slice (i , j ) for i , j in zip (sbins [:- 1 ], sbins [1 :], strict = False )
747+ slice (i , j ) for i , j in zip (sbins [:- 1 ], sbins [1 :], strict = True )
731748 ]
732749 group_indices += [slice (sbins [- 1 ], None )]
733750
734751 # Make sure the first and last timestamps
735752 # are for the correct months,if not we have incomplete seasons
736753 unique_codes = np .arange (len (unique_coord ))
737754 if self .drop_incomplete :
738- for idx , slicer in zip ([0 , - 1 ], (slice (1 , None ), slice (- 1 )), strict = False ):
755+ for idx , slicer in zip ([0 , - 1 ], (slice (1 , None ), slice (- 1 )), strict = True ):
739756 stamp_year , stamp_season = frame .index [idx ]
740757 code = seasons .index (stamp_season )
741758 stamp_month = season_inds [code ][idx ]
742- if stamp_month != month [idx ].item ():
759+ if stamp_month != month [present_seasons ][ idx ].item ():
743760 # we have an incomplete season!
744761 group_indices = group_indices [slicer ]
745762 unique_coord = unique_coord [slicer ]
@@ -769,7 +786,9 @@ def factorize(self, group):
769786 if not full_index .equals (unique_coord ):
770787 raise ValueError ("Are there seasons missing in the middle of the dataset?" )
771788
772- codes = group .copy (data = np .repeat (unique_codes , counts ), deep = False )
789+ final_codes = np .full (group .data .size , - 1 )
790+ final_codes [present_seasons ] = np .repeat (unique_codes , counts )
791+ codes = group .copy (data = final_codes , deep = False )
773792 unique_coord_var = Variable (group .name , unique_coord , group .attrs )
774793
775794 return EncodedGroups (
0 commit comments