1+ /* * 快速取模运算类(MontgomeryModInt32 蒙哥马利模乘)
2+ * 2023-08-11: https://ac.nowcoder.com/acm/contest/view-submission?submissionId=63381475&returnHomeType=1&uid=815516497
3+ * * 感谢菜菜园子群友提供
4+ **/
5+ template <std::uint32_t P> struct MontgomeryModInt32 {
6+ public:
7+ using i32 = std::int32_t ;
8+ using u32 = std::uint32_t ;
9+ using i64 = std::int64_t ;
10+ using u64 = std::uint64_t ;
11+
12+ private:
13+ u32 v;
14+
15+ static constexpr u32 get_r () {
16+ u32 iv = P;
17+
18+ for (u32 i = 0 ; i != 4 ; ++i)
19+ iv *= 2U - P * iv;
20+
21+ return -iv;
22+ }
23+
24+ static constexpr u32 r = get_r(), r2 = -u64 (P) % P;
25+
26+ static_assert ((P & 1 ) == 1 );
27+ static_assert (-r * P == 1 );
28+ static_assert (P < (1 << 30 ));
29+
30+ public:
31+ static constexpr u32 pow_mod (u32 x, u64 y) {
32+ if ((y %= P - 1 ) < 0 )
33+ y += P - 1 ;
34+
35+ u32 res = 1 ;
36+
37+ for (; y != 0 ; y >>= 1 , x = u64 (x) * x % P)
38+ if (y & 1 )
39+ res = u64 (res) * x % P;
40+
41+ return res;
42+ }
43+
44+ static constexpr u32 get_pr () {
45+ u32 tmp[32 ] = {}, cnt = 0 ;
46+ const u64 phi = P - 1 ;
47+ u64 m = phi;
48+
49+ for (u64 i = 2 ; i * i <= m; ++i) {
50+ if (m % i == 0 ) {
51+ tmp[cnt++] = i;
52+
53+ while (m % i == 0 )
54+ m /= i;
55+ }
56+ }
57+
58+ if (m > 1 )
59+ tmp[cnt++] = m;
60+
61+ for (u64 res = 2 ; res <= phi; ++res) {
62+ bool flag = true ;
63+
64+ for (u32 i = 0 ; i != cnt && flag; ++i)
65+ flag &= pow_mod (res, phi / tmp[i]) != 1 ;
66+
67+ if (flag)
68+ return res;
69+ }
70+
71+ return 0 ;
72+ }
73+
74+ MontgomeryModInt32 () = default ;
75+ ~MontgomeryModInt32 () = default ;
76+ constexpr MontgomeryModInt32 (u32 v) : v(reduce(u64 (v) * r2)) {}
77+ constexpr MontgomeryModInt32 (const MontgomeryModInt32 &rhs) : v(rhs.v) {}
78+ static constexpr u32 reduce (u64 x) {
79+ return x + (u64 (u32 (x) * r) * P) >> 32 ;
80+ }
81+ static constexpr u32 norm (u32 x) {
82+ return x - (P & -(x >= P));
83+ }
84+ constexpr u32 get () const {
85+ u32 res = reduce (v) - P;
86+ return res + (P & -(res >> 31 ));
87+ }
88+ explicit constexpr operator u32 () const {
89+ return get ();
90+ }
91+ explicit constexpr operator i32 () const {
92+ return i32 (get ());
93+ }
94+ constexpr MontgomeryModInt32 &operator =(const MontgomeryModInt32 &rhs) {
95+ return v = rhs.v , *this ;
96+ }
97+ constexpr MontgomeryModInt32 operator -() const {
98+ MontgomeryModInt32 res;
99+ return res.v = (P << 1 & -(v != 0 )) - v, res;
100+ }
101+ constexpr MontgomeryModInt32 inv () const {
102+ return pow (-1 );
103+ }
104+ constexpr MontgomeryModInt32 &operator +=(const MontgomeryModInt32 &rhs) {
105+ return v += rhs.v - (P << 1 ), v += P << 1 & -(v >> 31 ), *this ;
106+ }
107+ constexpr MontgomeryModInt32 &operator -=(const MontgomeryModInt32 &rhs) {
108+ return v -= rhs.v , v += P << 1 & -(v >> 31 ), *this ;
109+ }
110+ constexpr MontgomeryModInt32 &operator *=(const MontgomeryModInt32 &rhs) {
111+ return v = reduce (u64 (v) * rhs.v ), *this ;
112+ }
113+ constexpr MontgomeryModInt32 &operator /=(const MontgomeryModInt32 &rhs) {
114+ return this ->operator *=(rhs.inv ());
115+ }
116+ friend MontgomeryModInt32 operator +(const MontgomeryModInt32 &lhs,
117+ const MontgomeryModInt32 &rhs) {
118+ return MontgomeryModInt32 (lhs) += rhs;
119+ }
120+ friend MontgomeryModInt32 operator -(const MontgomeryModInt32 &lhs,
121+ const MontgomeryModInt32 &rhs) {
122+ return MontgomeryModInt32 (lhs) -= rhs;
123+ }
124+ friend MontgomeryModInt32 operator *(const MontgomeryModInt32 &lhs,
125+ const MontgomeryModInt32 &rhs) {
126+ return MontgomeryModInt32 (lhs) *= rhs;
127+ }
128+ friend MontgomeryModInt32 operator /(const MontgomeryModInt32 &lhs,
129+ const MontgomeryModInt32 &rhs) {
130+ return MontgomeryModInt32 (lhs) /= rhs;
131+ }
132+ friend bool operator ==(const MontgomeryModInt32 &lhs, const MontgomeryModInt32 &rhs) {
133+ return norm (lhs.v ) == norm (rhs.v );
134+ }
135+ friend bool operator !=(const MontgomeryModInt32 &lhs, const MontgomeryModInt32 &rhs) {
136+ return norm (lhs.v ) != norm (rhs.v );
137+ }
138+ friend std::istream &operator >>(std::istream &is, MontgomeryModInt32 &rhs) {
139+ return is >> rhs.v , rhs.v = reduce (u64 (rhs.v ) * r2), is;
140+ }
141+ friend std::ostream &operator <<(std::ostream &os, const MontgomeryModInt32 &rhs) {
142+ return os << rhs.get ();
143+ }
144+ constexpr MontgomeryModInt32 pow (i64 y) const {
145+ if ((y %= P - 1 ) < 0 )
146+ y += P - 1 ; // phi(P) = P - 1, assume P is a prime number
147+
148+ MontgomeryModInt32 res (1 ), x (*this );
149+
150+ for (; y != 0 ; y >>= 1 , x *= x)
151+ if (y & 1 )
152+ res *= x;
153+
154+ return res;
155+ }
156+ };
157+
158+ template <std::uint32_t P> MontgomeryModInt32<P> sqrt (const MontgomeryModInt32<P> &x) {
159+ using value_type = MontgomeryModInt32<P>;
160+ static constexpr value_type negtive_one (P - 1 ), ZERO (0 );
161+
162+ if (x == ZERO || x.pow (P - 1 >> 1 ) == negtive_one)
163+ return ZERO;
164+
165+ if ((P & 3 ) == 3 )
166+ return x.pow (P + 1 >> 2 );
167+
168+ static value_type w2, ax;
169+ ax = x;
170+ static std::random_device rd;
171+ static std::mt19937 gen (rd ());
172+ static std::uniform_int_distribution<std::uint32_t > dis (1 , P - 1 );
173+ const value_type four (value_type (4 ) * x);
174+ static value_type t;
175+
176+ do
177+ t = value_type (dis (gen)), w2 = t * t - four;
178+
179+ while (w2.pow (P - 1 >> 1 ) != negtive_one);
180+
181+ struct Field_P2 { // (A + Bx)(C+Dx)=(AC-BDa)+(AD+BC+BDt)x
182+ public:
183+ value_type a, b;
184+ Field_P2 (const value_type &a, const value_type &b) : a(a), b(b) {}
185+ ~Field_P2 () = default ;
186+ Field_P2 &operator *=(const Field_P2 &rhs) {
187+ value_type tmp1 (b * rhs.b ), tmp2 (a * rhs.a - tmp1 * ax),
188+ tmp3 (a * rhs.b + b * rhs.a + tmp1 * t);
189+ return a = tmp2, b = tmp3, *this ;
190+ }
191+ Field_P2 pow (std::uint64_t y) const {
192+ Field_P2 res (value_type (1 ), ZERO), x (*this );
193+
194+ for (; y != 0 ; y >>= 1 , x *= x)
195+ if (y & 1 )
196+ res *= x;
197+
198+ return res;
199+ }
200+ } res (ZERO, value_type (1 ));
201+ return res.pow (P + 1 >> 1 ).a ;
202+ }
203+
204+ std::uint64_t get_len (std::uint64_t n) { // if n=0, boom
205+ return --n, n |= n >> 1 , n |= n >> 2 , n |= n >> 4 , n |= n >> 8 , n |= n >> 16 , n |= n >> 32 , ++n;
206+ }
0 commit comments