@@ -516,6 +516,117 @@ private void Update(int j, TFloat origVal)
516516 }
517517 }
518518
519+ [ BestFriend ]
520+ internal static partial class MedianAggregatorUtils
521+ {
522+ /// <summary>
523+ /// Based on the algorithm on GeeksForGeeks https://www.geeksforgeeks.org/median-of-stream-of-integers-running-integers/.
524+ /// </summary>
525+ /// <param name="num">The new number to account for in our median calculation.</param>
526+ /// <param name="median">The current median.</param>
527+ /// <param name="belowMedianHeap">The MaxHeap that has all the numbers below the median.</param>
528+ /// <param name="aboveMedianHeap">The MinHeap that has all the numbers above the median.</param>
529+ [ BestFriend ]
530+ internal static void GetMedianSoFar ( in double num , ref double median , ref MaxHeap < double > belowMedianHeap , ref MinHeap < double > aboveMedianHeap )
531+ {
532+ int comparison = belowMedianHeap . Count ( ) . CompareTo ( aboveMedianHeap . Count ( ) ) ;
533+
534+ if ( comparison < 0 )
535+ { // More elements in aboveMedianHeap than belowMedianHeap.
536+ if ( num < median )
537+ { // Current element belongs in the belowMedianHeap.
538+ // Insert new number into belowMedianHeap
539+ belowMedianHeap . Add ( num ) ;
540+
541+ }
542+ else
543+ { // Current element belongs in aboveMedianHeap.
544+ // Need to move one to belowMedianHeap to keep heeps balanced.
545+ belowMedianHeap . Add ( aboveMedianHeap . Pop ( ) ) ;
546+
547+ aboveMedianHeap . Add ( num ) ;
548+ }
549+
550+ // Both heaps are balanced so median is the average of the 2 heaps.
551+ median = ( aboveMedianHeap . Peek ( ) + belowMedianHeap . Peek ( ) ) / 2 ;
552+
553+ }
554+ else if ( comparison == 0 )
555+ { // Both heaps have the same number of elements. Simple put the number where it belongs.
556+ if ( num < median )
557+ { // Current element belongs in the belowMedianHeap.
558+ belowMedianHeap . Add ( num ) ;
559+
560+ // Now we have an odd number of items, median is the new root of the belowMedianHeap
561+ median = belowMedianHeap . Peek ( ) ;
562+
563+ }
564+ else
565+ { // Current element belongs in above median heap.
566+ aboveMedianHeap . Add ( num ) ;
567+
568+ // Now we have an odd number of items, median is the new root of the aboveMedianHeap
569+ median = aboveMedianHeap . Peek ( ) ;
570+ }
571+
572+ }
573+ else
574+ { // More elements in belowMedianHeap than aboveMedianHeap.
575+ if ( num < median )
576+ { // Current element belongs in the belowMedianHeap.
577+ // Need to move one to aboveMedianHeap to keep heeps balanced.
578+ aboveMedianHeap . Add ( belowMedianHeap . Pop ( ) ) ;
579+
580+ // Insert new number into belowMedianHeap
581+ belowMedianHeap . Add ( num ) ;
582+
583+ }
584+ else
585+ { // Current element belongs in aboveMedianHeap.
586+ aboveMedianHeap . Add ( num ) ;
587+ }
588+
589+ // Both heaps are balanced so median is the average of the 2 heaps.
590+ median = ( aboveMedianHeap . Peek ( ) + belowMedianHeap . Peek ( ) ) / 2 ;
591+ }
592+ }
593+ }
594+
595+ /// <summary>
596+ /// Base class for tracking median values for a single valued column.
597+ /// It tracks median values of non-sparse values (vCount).
598+ /// NaNs are ignored when updating min and max.
599+ /// </summary>
600+ internal sealed class MedianDblAggregator : IColumnAggregator < double >
601+ {
602+ private MedianAggregatorUtils . MaxHeap < double > _belowMedianHeap ;
603+ private MedianAggregatorUtils . MinHeap < double > _aboveMedianHeap ;
604+ private double _median ;
605+
606+ public MedianDblAggregator ( int contatinerStartingSize = 1000 )
607+ {
608+ Contracts . Check ( contatinerStartingSize > 0 ) ;
609+ _belowMedianHeap = new MedianAggregatorUtils . MaxHeap < double > ( contatinerStartingSize ) ;
610+ _aboveMedianHeap = new MedianAggregatorUtils . MinHeap < double > ( contatinerStartingSize ) ;
611+ _median = default ;
612+ }
613+
614+ public double Median
615+ {
616+ get { return _median ; }
617+ }
618+
619+ public void ProcessValue ( in double value )
620+ {
621+ MedianAggregatorUtils . GetMedianSoFar ( value , ref _median , ref _belowMedianHeap , ref _aboveMedianHeap ) ;
622+ }
623+
624+ public void Finish ( )
625+ {
626+ // Finish is a no-op because we are updating the median continually as we go
627+ }
628+ }
629+
519630 internal sealed partial class NormalizeTransform
520631 {
521632 internal abstract partial class AffineColumnFunction
@@ -1912,6 +2023,144 @@ public static IColumnFunctionBuilder Create(NormalizingEstimator.SupervisedBinni
19122023 return new SupervisedBinVecColumnFunctionBuilder ( host , lim , fix , numBins , column . MininimumBinSize , valueColumnId , labelColumnId , dataRow ) ;
19132024 }
19142025 }
2026+
2027+ public sealed class RobustScalerOneColumnFunctionBuilder : OneColumnFunctionBuilderBase < double >
2028+ {
2029+ private readonly MinMaxDblAggregator _minMaxAggregator ;
2030+ private readonly MedianDblAggregator _medianAggregator ;
2031+ private readonly bool _centerData ;
2032+ private readonly uint _quantileMin ;
2033+ private readonly uint _quantileMax ;
2034+ private VBuffer < double > _buffer ;
2035+
2036+ private RobustScalerOneColumnFunctionBuilder ( IHost host , long lim , bool centerData , uint quantileMin , uint quantileMax , ValueGetter < double > getSrc )
2037+ : base ( host , lim , getSrc )
2038+ {
2039+ // Using the MinMax aggregator since that is what needs to be found here as well.
2040+ // The difference is how the min/max are used.
2041+ _minMaxAggregator = new MinMaxDblAggregator ( 1 ) ;
2042+ _medianAggregator = new MedianDblAggregator ( ) ;
2043+ _buffer = new VBuffer < double > ( 1 , new double [ 1 ] ) ;
2044+ _centerData = centerData ;
2045+ _quantileMin = quantileMin ;
2046+ _quantileMax = quantileMax ;
2047+ }
2048+
2049+ protected override bool ProcessValue ( in double val )
2050+ {
2051+ if ( ! base . ProcessValue ( in val ) )
2052+ return false ;
2053+ VBufferEditor . CreateFromBuffer ( ref _buffer ) . Values [ 0 ] = val ;
2054+ _minMaxAggregator . ProcessValue ( in _buffer ) ;
2055+ _medianAggregator . ProcessValue ( in val ) ;
2056+ return true ;
2057+ }
2058+
2059+ public static IColumnFunctionBuilder Create ( NormalizingEstimator . RobustScalingColumnOptions column , IHost host , DataViewType srcType ,
2060+ bool centerData , uint quantileMin , uint quantileMax , ValueGetter < double > getter )
2061+ {
2062+ host . CheckUserArg ( column . MaximumExampleCount > 1 , nameof ( column . MaximumExampleCount ) , "Must be greater than 1" ) ;
2063+ return new RobustScalerOneColumnFunctionBuilder ( host , column . MaximumExampleCount , centerData , quantileMin , quantileMax , getter ) ;
2064+ }
2065+
2066+ public override IColumnFunction CreateColumnFunction ( )
2067+ {
2068+ _minMaxAggregator . Finish ( ) ;
2069+ _medianAggregator . Finish ( ) ;
2070+
2071+ double median = _medianAggregator . Median ;
2072+ double range = _minMaxAggregator . Max [ 0 ] - _minMaxAggregator . Min [ 0 ] ;
2073+ // Divide the range by 100 because we need to make the number, i.e. 75, into a decimal, .75
2074+ double quantileRange = ( _quantileMax - _quantileMin ) / 100f ;
2075+ double scale = 1 / ( range * quantileRange ) ;
2076+
2077+ if ( _centerData )
2078+ return AffineColumnFunction . Create ( Host , scale , median ) ;
2079+ else
2080+ return AffineColumnFunction . Create ( Host , scale , 0 ) ;
2081+ }
2082+ }
2083+
2084+ public sealed class RobustScalerVecFunctionBuilder : OneColumnFunctionBuilderBase < VBuffer < double > >
2085+ {
2086+ private readonly MinMaxDblAggregator _minMaxAggregator ;
2087+ private readonly MedianDblAggregator [ ] _medianAggregators ;
2088+ private readonly bool _centerData ;
2089+ private readonly uint _quantileMin ;
2090+ private readonly uint _quantileMax ;
2091+
2092+ private RobustScalerVecFunctionBuilder ( IHost host , long lim , int vectorSize , bool centerData , uint quantileMin , uint quantileMax , ValueGetter < VBuffer < double > > getSrc )
2093+ : base ( host , lim , getSrc )
2094+ {
2095+ // Using the MinMax aggregator since that is what needs to be found here as well.
2096+ // The difference is how the min/max are used.
2097+ _minMaxAggregator = new MinMaxDblAggregator ( vectorSize ) ;
2098+
2099+ // If we aren't centering data dont need the median.
2100+ _medianAggregators = new MedianDblAggregator [ vectorSize ] ;
2101+
2102+ for ( int i = 0 ; i < vectorSize ; i ++ )
2103+ {
2104+ _medianAggregators [ i ] = new MedianDblAggregator ( ) ;
2105+ }
2106+
2107+ _centerData = centerData ;
2108+ _quantileMin = quantileMin ;
2109+ _quantileMax = quantileMax ;
2110+ }
2111+
2112+ protected override bool ProcessValue ( in VBuffer < double > val )
2113+ {
2114+ if ( ! base . ProcessValue ( in val ) )
2115+ return false ;
2116+ _minMaxAggregator . ProcessValue ( in val ) ;
2117+
2118+ // Have to calculate the median per slot
2119+ var span = val . GetValues ( ) ;
2120+ for ( int i = 0 ; i < _medianAggregators . Length ; i ++ )
2121+ {
2122+ _medianAggregators [ i ] . ProcessValue ( span [ i ] ) ;
2123+ }
2124+
2125+ return true ;
2126+ }
2127+
2128+ public static IColumnFunctionBuilder Create ( NormalizingEstimator . RobustScalingColumnOptions column , IHost host , VectorDataViewType srcType ,
2129+ bool centerData , uint quantileMin , uint quantileMax , ValueGetter < VBuffer < double > > getter )
2130+ {
2131+ host . CheckUserArg ( column . MaximumExampleCount > 1 , nameof ( column . MaximumExampleCount ) , "Must be greater than 1" ) ;
2132+ var vectorSize = srcType . Size ;
2133+ return new RobustScalerVecFunctionBuilder ( host , column . MaximumExampleCount , vectorSize , centerData , quantileMin , quantileMax , getter ) ;
2134+ }
2135+
2136+ public override IColumnFunction CreateColumnFunction ( )
2137+ {
2138+ _minMaxAggregator . Finish ( ) ;
2139+
2140+ double [ ] scale = new double [ _medianAggregators . Length ] ;
2141+ double [ ] median = new double [ _medianAggregators . Length ] ;
2142+
2143+ // Have to calculate the median per slot
2144+ for ( int i = 0 ; i < _medianAggregators . Length ; i ++ )
2145+ {
2146+ _medianAggregators [ i ] . Finish ( ) ;
2147+ median [ i ] = _medianAggregators [ i ] . Median ;
2148+
2149+ double range = _minMaxAggregator . Max [ i ] - _minMaxAggregator . Min [ i ] ;
2150+
2151+ // Divide the range by 100 because we need to make the number, i.e. 75, into a decimal, .75
2152+ double quantileRange = ( _quantileMax - _quantileMin ) / 100f ;
2153+ scale [ i ] = 1 / ( range * quantileRange ) ;
2154+
2155+ }
2156+
2157+ if ( _centerData )
2158+ return AffineColumnFunction . Create ( Host , scale , median , null ) ;
2159+ else
2160+ return AffineColumnFunction . Create ( Host , scale , null , null ) ;
2161+
2162+ }
2163+ }
19152164 }
19162165 }
19172166}
0 commit comments