5454 GroupKey = Any
5555 GroupIndex = Union [int , slice , list [int ]]
5656 T_GroupIndices = list [GroupIndex ]
57- T_FactorizeOut = tuple [
58- DataArray , T_GroupIndices , Union [pd .Index , "_DummyGroup" ], pd .Index , DataArray
59- ]
6057
6158
6259def check_reduce_dims (reduce_dims , dimensions ):
@@ -92,7 +89,7 @@ def _maybe_squeeze_indices(
9289
9390def unique_value_groups (
9491 ar , sort : bool = True
95- ) -> tuple [np .ndarray | pd .Index , T_GroupIndices , np .ndarray ]:
92+ ) -> tuple [np .ndarray | pd .Index , np .ndarray ]:
9693 """Group an array by its unique values.
9794
9895 Parameters
@@ -113,11 +110,11 @@ def unique_value_groups(
113110 inverse , values = pd .factorize (ar , sort = sort )
114111 if isinstance (values , pd .MultiIndex ):
115112 values .names = ar .names
116- groups = _codes_to_groups (inverse , len (values ))
117- return values , groups , inverse
113+ return values , inverse
118114
119115
120- def _codes_to_groups (inverse : np .ndarray , N : int ) -> T_GroupIndices :
116+ def _codes_to_group_indices (inverse : np .ndarray , N : int ) -> T_GroupIndices :
117+ assert inverse .ndim == 1
121118 groups : T_GroupIndices = [[] for _ in range (N )]
122119 for n , g in enumerate (inverse ):
123120 if g >= 0 :
@@ -341,16 +338,35 @@ def _apply_loffset(
341338
342339
343340@dataclass
344- class ResolvedGrouper :
341+ class EncodedGroups :
342+ """
343+ Parameters
344+ ----------
345+ codes:
346+ full_index:
347+ group_indices: optional,
348+ Inferred if not provided.
349+ unique_coord:
350+ Inferred if not provided
351+ """
352+
353+ codes : DataArray
354+ full_index : pd .Index
355+ group_indices : T_GroupIndices | None = field (default = None )
356+ unique_coord : IndexVariable | _DummyGroup | None = field (default = None )
357+
358+
359+ @dataclass
360+ class ResolvedGrouper (Generic [T_Xarray ]):
345361 grouper : Grouper
346362 group : T_Group
347363 obj : T_Xarray
348364
349- # Defined by factorize:
365+ # returned by factorize:
350366 codes : DataArray = field (init = False )
367+ full_index : pd .Index = field (init = False )
351368 group_indices : T_GroupIndices = field (init = False )
352369 unique_coord : IndexVariable | _DummyGroup = field (init = False )
353- full_index : pd .Index = field (init = False )
354370
355371 # _ensure_1d:
356372 group1d : T_Group = field (init = False )
@@ -394,20 +410,29 @@ def dims(self):
394410 return self .group1d .dims
395411
396412 def factorize (self ) -> None :
397- # This design makes it clear to mypy that
398- # codes, group_indices, unique_coord, and full_index
399- # are set by the factorize method on the derived class.
400- (
401- self .codes ,
402- self .group_indices ,
403- self .unique_coord ,
404- self .full_index ,
405- ) = self .grouper .factorize (self .group1d )
413+ encoded = self .grouper .factorize (self .group1d )
414+
415+ self .codes = encoded .codes
416+ self .full_index = encoded .full_index
417+
418+ if encoded .group_indices is not None :
419+ self .group_indices = encoded .group_indices
420+ else :
421+ self .group_indices = [
422+ g
423+ for g in _codes_to_group_indices (self .codes .data , len (self .full_index ))
424+ if g
425+ ]
426+ if encoded .unique_coord is None :
427+ # TODO
428+ raise NotImplementedError
429+ else :
430+ self .unique_coord = encoded .unique_coord
406431
407432
408433class Grouper (ABC ):
409434 @abstractmethod
410- def factorize (self , group ) -> T_FactorizeOut :
435+ def factorize (self , group : T_Group ) -> EncodedGroups :
411436 pass
412437
413438
@@ -437,34 +462,33 @@ def can_squeeze(self) -> bool:
437462 is_dimension = self .group .dims == (self .group .name ,)
438463 return is_dimension and self .is_unique_and_monotonic
439464
440- def factorize (self , group1d ) -> T_FactorizeOut :
465+ def factorize (self , group1d ) -> EncodedGroups :
441466 self .group = group1d
442467
443468 if self .can_squeeze :
444469 return self ._factorize_dummy ()
445470 else :
446471 return self ._factorize_unique ()
447472
448- def _factorize_unique (self ) -> T_FactorizeOut :
473+ def _factorize_unique (self ) -> EncodedGroups :
449474 # look through group to find the unique values
450475 sort = not isinstance (self .group_as_index , pd .MultiIndex )
451- unique_values , group_indices , codes_ = unique_value_groups (
452- self .group_as_index , sort = sort
453- )
454- if len (group_indices ) == 0 :
476+ unique_values , codes_ = unique_value_groups (self .group_as_index , sort = sort )
477+ if (codes_ == - 1 ).all ():
455478 raise ValueError (
456479 "Failed to group data. Are you grouping by a variable that is all NaN?"
457480 )
458481 codes = self .group .copy (data = codes_ )
459- group_indices = group_indices
460482 unique_coord = IndexVariable (
461483 self .group .name , unique_values , attrs = self .group .attrs
462484 )
463485 full_index = unique_coord
464486
465- return codes , group_indices , unique_coord , full_index
487+ return EncodedGroups (
488+ codes = codes , full_index = full_index , unique_coord = unique_coord
489+ )
466490
467- def _factorize_dummy (self ) -> T_FactorizeOut :
491+ def _factorize_dummy (self ) -> EncodedGroups :
468492 size = self .group .size
469493 # no need to factorize
470494 # use slices to do views instead of fancy indexing
@@ -479,8 +503,12 @@ def _factorize_dummy(self) -> T_FactorizeOut:
479503 full_index = IndexVariable (
480504 self .group .name , unique_coord .values , self .group .attrs
481505 )
482-
483- return codes , group_indices , unique_coord , full_index
506+ return EncodedGroups (
507+ codes = codes ,
508+ group_indices = group_indices ,
509+ full_index = full_index ,
510+ unique_coord = unique_coord ,
511+ )
484512
485513
486514@dataclass
@@ -494,7 +522,7 @@ def __post_init__(self) -> None:
494522 if duck_array_ops .isnull (self .bins ).all ():
495523 raise ValueError ("All bin edges are NaN." )
496524
497- def factorize (self , group ) -> T_FactorizeOut :
525+ def factorize (self , group ) -> EncodedGroups :
498526 from xarray .core .dataarray import DataArray
499527
500528 data = group .data
@@ -512,11 +540,7 @@ def factorize(self, group) -> T_FactorizeOut:
512540 full_index = binned .categories
513541 uniques = np .sort (pd .unique (binned_codes ))
514542 unique_values = full_index [uniques [uniques != - 1 ]]
515- group_indices = [
516- g for g in _codes_to_groups (binned_codes , len (full_index )) if g
517- ]
518-
519- if len (group_indices ) == 0 :
543+ if (binned_codes == - 1 ).all ():
520544 raise ValueError (
521545 f"None of the data falls within bins with edges { self .bins !r} "
522546 )
@@ -525,7 +549,9 @@ def factorize(self, group) -> T_FactorizeOut:
525549 binned_codes , getattr (group , "coords" , None ), name = new_dim_name
526550 )
527551 unique_coord = IndexVariable (new_dim_name , pd .Index (unique_values ), group .attrs )
528- return codes , group_indices , unique_coord , full_index
552+ return EncodedGroups (
553+ codes = codes , full_index = full_index , unique_coord = unique_coord
554+ )
529555
530556
531557@dataclass
@@ -628,7 +654,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
628654 _apply_loffset (self .loffset , first_items )
629655 return first_items , codes
630656
631- def factorize (self , group ) -> T_FactorizeOut :
657+ def factorize (self , group ) -> EncodedGroups :
632658 self ._init_properties (group )
633659 full_index , first_items , codes_ = self ._get_index_and_items ()
634660 sbins = first_items .values .astype (np .int64 )
@@ -640,7 +666,12 @@ def factorize(self, group) -> T_FactorizeOut:
640666 unique_coord = IndexVariable (group .name , first_items .index , group .attrs )
641667 codes = group .copy (data = codes_ )
642668
643- return codes , group_indices , unique_coord , full_index
669+ return EncodedGroups (
670+ codes = codes ,
671+ group_indices = group_indices ,
672+ full_index = full_index ,
673+ unique_coord = unique_coord ,
674+ )
644675
645676
646677def _validate_groupby_squeeze (squeeze : bool | None ) -> None :
0 commit comments