11//! [`Uint`] square root operations.
22
33use super :: Uint ;
4- use crate :: CtChoice ;
54use subtle:: { ConstantTimeEq , CtOption } ;
65
76impl < const LIMBS : usize > Uint < LIMBS > {
8- /// See [`Self::sqrt_vartime`].
9- #[ deprecated(
10- since = "0.5.3" ,
11- note = "This functionality will be moved to `sqrt_vartime` in a future release."
12- ) ]
7+ /// Computes √(`self`) in constant time.
8+ /// Based on Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
9+ ///
10+ /// Callers can check if `self` is a square by squaring the result
1311 pub const fn sqrt ( & self ) -> Self {
14- self . sqrt_vartime ( )
12+ let max_bits = ( self . bits ( ) + 1 ) >> 1 ;
13+ let cap = Self :: ONE . shl ( max_bits) ;
14+ let mut guess = cap; // ≥ √(`self`)
15+ let mut xn = {
16+ let q = self . wrapping_div ( & guess) ;
17+ let t = guess. wrapping_add ( & q) ;
18+ t. shr_vartime ( 1 )
19+ } ;
20+
21+ // Repeat enough times to guarantee result has stabilized.
22+ // See Hast, "Note on computation of integer square roots" for a proof of this bound.
23+ // https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf
24+ let mut i = 0 ;
25+ while i < Self :: LOG2_BITS {
26+ guess = xn;
27+ xn = {
28+ let ( q, _, is_some) = self . const_div_rem ( & guess) ;
29+ let q = Self :: ct_select ( & Self :: ZERO , & q, is_some) ;
30+ let t = guess. wrapping_add ( & q) ;
31+ t. shr_vartime ( 1 )
32+ } ;
33+ i += 1 ;
34+ }
35+
36+ // at least one of `guess` and `xn` is now equal to √(`self`), so return the minimum
37+ Self :: ct_select ( & guess, & xn, Uint :: ct_gt ( & guess, & xn) )
1538 }
1639
1740 /// Computes √(`self`)
@@ -23,62 +46,49 @@ impl<const LIMBS: usize> Uint<LIMBS> {
2346 let cap = Self :: ONE . shl_vartime ( max_bits) ;
2447 let mut guess = cap; // ≥ √(`self`)
2548 let mut xn = {
26- let q = self . wrapping_div ( & guess) ;
49+ let q = self . wrapping_div_vartime ( & guess) ;
2750 let t = guess. wrapping_add ( & q) ;
2851 t. shr_vartime ( 1 )
2952 } ;
30-
31- // If guess increased, the initial guess was low.
32- // Repeat until reverse course.
33- while Uint :: ct_lt ( & guess, & xn) . is_true_vartime ( ) {
34- // Sometimes an increase is too far, especially with large
35- // powers, and then takes a long time to walk back. The upper
36- // bound is based on bit size, so saturate on that.
37- let le = CtChoice :: from_u32_le ( xn. bits_vartime ( ) , max_bits) ;
38- guess = Self :: ct_select ( & cap, & xn, le) ;
39- xn = {
40- let q = self . wrapping_div ( & guess) ;
41- let t = guess. wrapping_add ( & q) ;
42- t. shr_vartime ( 1 )
43- } ;
44- }
53+ // Note, xn <= guess at this point.
4554
4655 // Repeat while guess decreases.
47- while Uint :: ct_gt ( & guess , & xn) . is_true_vartime ( ) && xn. ct_is_nonzero ( ) . is_true_vartime ( ) {
56+ while guess . cmp_vartime ( & xn) . is_gt ( ) && ! xn. cmp_vartime ( & Self :: ZERO ) . is_eq ( ) {
4857 guess = xn;
4958 xn = {
50- let q = self . wrapping_div ( & guess) ;
59+ let q = self . wrapping_div_vartime ( & guess) ;
5160 let t = guess. wrapping_add ( & q) ;
5261 t. shr_vartime ( 1 )
5362 } ;
5463 }
5564
56- Self :: ct_select ( & Self :: ZERO , & guess, self . ct_is_nonzero ( ) )
65+ if self . ct_is_nonzero ( ) . is_true_vartime ( ) {
66+ guess
67+ } else {
68+ Self :: ZERO
69+ }
5770 }
5871
59- /// See [`Self::wrapping_sqrt_vartime`].
60- #[ deprecated(
61- since = "0.5.3" ,
62- note = "This functionality will be moved to `wrapping_sqrt_vartime` in a future release."
63- ) ]
72+ /// Wrapped sqrt is just normal √(`self`)
73+ /// There’s no way wrapping could ever happen.
74+ /// This function exists so that all operations are accounted for in the wrapping operations.
6475 pub const fn wrapping_sqrt ( & self ) -> Self {
65- self . wrapping_sqrt_vartime ( )
76+ self . sqrt ( )
6677 }
6778
6879 /// Wrapped sqrt is just normal √(`self`)
6980 /// There’s no way wrapping could ever happen.
70- /// This function exists, so that all operations are accounted for in the wrapping operations.
81+ /// This function exists so that all operations are accounted for in the wrapping operations.
7182 pub const fn wrapping_sqrt_vartime ( & self ) -> Self {
7283 self . sqrt_vartime ( )
7384 }
7485
75- /// See [`Self::checked_sqrt_vartime`].
76- #[ deprecated(
77- since = "0.5.3" ,
78- note = "This functionality will be moved to `checked_sqrt_vartime` in a future release."
79- ) ]
86+ /// Perform checked sqrt, returning a [`CtOption`] which `is_some`
87+ /// only if the √(`self`)² == self
8088 pub fn checked_sqrt ( & self ) -> CtOption < Self > {
81- self . checked_sqrt_vartime ( )
89+ let r = self . sqrt ( ) ;
90+ let s = r. wrapping_mul ( & r) ;
91+ CtOption :: new ( r, ConstantTimeEq :: ct_eq ( self , & s) )
8292 }
8393
8494 /// Perform checked sqrt, returning a [`CtOption`] which `is_some`
@@ -92,7 +102,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
92102
93103#[ cfg( test) ]
94104mod tests {
95- use crate :: { Limb , U256 } ;
105+ use crate :: { Limb , U192 , U256 } ;
96106
97107 #[ cfg( feature = "rand" ) ]
98108 use {
@@ -103,13 +113,35 @@ mod tests {
103113
104114 #[ test]
105115 fn edge ( ) {
116+ assert_eq ! ( U256 :: ZERO . sqrt( ) , U256 :: ZERO ) ;
117+ assert_eq ! ( U256 :: ONE . sqrt( ) , U256 :: ONE ) ;
118+ let mut half = U256 :: ZERO ;
119+ for i in 0 ..half. limbs . len ( ) / 2 {
120+ half. limbs [ i] = Limb :: MAX ;
121+ }
122+ assert_eq ! ( U256 :: MAX . sqrt( ) , half) ;
123+
124+ assert_eq ! (
125+ U192 :: from_be_hex( "055fa39422bd9f281762946e056535badbf8a6864d45fa3d" ) . sqrt( ) ,
126+ U192 :: from_be_hex( "0000000000000000000000002516f0832a538b2d98869e21" )
127+ ) ;
128+
129+ assert_eq ! (
130+ U256 :: from_be_hex( "4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597" )
131+ . sqrt( ) ,
132+ U256 :: from_be_hex( "000000000000000000000000000000008b3956339e8315cff66eb6107b610075" )
133+ ) ;
134+ }
135+
136+ #[ test]
137+ fn edge_vartime ( ) {
106138 assert_eq ! ( U256 :: ZERO . sqrt_vartime( ) , U256 :: ZERO ) ;
107139 assert_eq ! ( U256 :: ONE . sqrt_vartime( ) , U256 :: ONE ) ;
108140 let mut half = U256 :: ZERO ;
109141 for i in 0 ..half. limbs . len ( ) / 2 {
110142 half. limbs [ i] = Limb :: MAX ;
111143 }
112- assert_eq ! ( U256 :: MAX . sqrt_vartime( ) , half, ) ;
144+ assert_eq ! ( U256 :: MAX . sqrt_vartime( ) , half) ;
113145 }
114146
115147 #[ test]
@@ -131,13 +163,28 @@ mod tests {
131163 for ( a, e) in & tests {
132164 let l = U256 :: from ( * a) ;
133165 let r = U256 :: from ( * e) ;
166+ assert_eq ! ( l. sqrt( ) , r) ;
134167 assert_eq ! ( l. sqrt_vartime( ) , r) ;
168+ assert_eq ! ( l. checked_sqrt( ) . is_some( ) . unwrap_u8( ) , 1u8 ) ;
135169 assert_eq ! ( l. checked_sqrt_vartime( ) . is_some( ) . unwrap_u8( ) , 1u8 ) ;
136170 }
137171 }
138172
139173 #[ test]
140174 fn nonsquares ( ) {
175+ assert_eq ! ( U256 :: from( 2u8 ) . sqrt( ) , U256 :: from( 1u8 ) ) ;
176+ assert_eq ! ( U256 :: from( 2u8 ) . checked_sqrt( ) . is_some( ) . unwrap_u8( ) , 0 ) ;
177+ assert_eq ! ( U256 :: from( 3u8 ) . sqrt( ) , U256 :: from( 1u8 ) ) ;
178+ assert_eq ! ( U256 :: from( 3u8 ) . checked_sqrt( ) . is_some( ) . unwrap_u8( ) , 0 ) ;
179+ assert_eq ! ( U256 :: from( 5u8 ) . sqrt( ) , U256 :: from( 2u8 ) ) ;
180+ assert_eq ! ( U256 :: from( 6u8 ) . sqrt( ) , U256 :: from( 2u8 ) ) ;
181+ assert_eq ! ( U256 :: from( 7u8 ) . sqrt( ) , U256 :: from( 2u8 ) ) ;
182+ assert_eq ! ( U256 :: from( 8u8 ) . sqrt( ) , U256 :: from( 2u8 ) ) ;
183+ assert_eq ! ( U256 :: from( 10u8 ) . sqrt( ) , U256 :: from( 3u8 ) ) ;
184+ }
185+
186+ #[ test]
187+ fn nonsquares_vartime ( ) {
141188 assert_eq ! ( U256 :: from( 2u8 ) . sqrt_vartime( ) , U256 :: from( 1u8 ) ) ;
142189 assert_eq ! (
143190 U256 :: from( 2u8 ) . checked_sqrt_vartime( ) . is_some( ) . unwrap_u8( ) ,
@@ -163,14 +210,17 @@ mod tests {
163210 let t = rng. next_u32 ( ) as u64 ;
164211 let s = U256 :: from ( t) ;
165212 let s2 = s. checked_mul ( & s) . unwrap ( ) ;
213+ assert_eq ! ( s2. sqrt( ) , s) ;
166214 assert_eq ! ( s2. sqrt_vartime( ) , s) ;
215+ assert_eq ! ( s2. checked_sqrt( ) . is_some( ) . unwrap_u8( ) , 1 ) ;
167216 assert_eq ! ( s2. checked_sqrt_vartime( ) . is_some( ) . unwrap_u8( ) , 1 ) ;
168217 }
169218
170219 for _ in 0 ..50 {
171220 let s = U256 :: random ( & mut rng) ;
172221 let mut s2 = U512 :: ZERO ;
173222 s2. limbs [ ..s. limbs . len ( ) ] . copy_from_slice ( & s. limbs ) ;
223+ assert_eq ! ( s. square( ) . sqrt( ) , s2) ;
174224 assert_eq ! ( s. square( ) . sqrt_vartime( ) , s2) ;
175225 }
176226 }
0 commit comments