@@ -133,5 +133,107 @@ def test_niter_range():
133133 self .assertRaises (ValueError , test_niter_range )
134134
135135
136+ class TestStaticPcaLowrankAPI (unittest .TestCase ):
137+ def transpose (self , x ):
138+ shape = x .shape
139+ perm = list (range (0 , len (shape )))
140+ perm = perm [:- 2 ] + [perm [- 1 ]] + [perm [- 2 ]]
141+ return paddle .transpose (x , perm )
142+
143+ def random_matrix (self , rows , columns , * batch_dims , ** kwargs ):
144+ dtype = kwargs .get ('dtype' , 'float64' )
145+
146+ x = paddle .randn (batch_dims + (rows , columns ), dtype = dtype )
147+ u , _ , vh = paddle .linalg .svd (x , full_matrices = False )
148+ k = min (rows , columns )
149+ s = paddle .linspace (1 / (k + 1 ), 1 , k , dtype = dtype )
150+ return (u * s .unsqueeze (- 2 )) @ vh
151+
152+ def random_lowrank_matrix (self , rank , rows , columns , * batch_dims , ** kwargs ):
153+ B = self .random_matrix (rows , rank , * batch_dims , ** kwargs )
154+ C = self .random_matrix (rank , columns , * batch_dims , ** kwargs )
155+ return B .matmul (C )
156+
157+ def run_subtest (
158+ self , guess_rank , actual_rank , matrix_size , batches , pca , ** options
159+ ):
160+ main = paddle .static .Program ()
161+ startup = paddle .static .Program ()
162+ with paddle .static .program_guard (main , startup ):
163+ if isinstance (matrix_size , int ):
164+ rows = columns = matrix_size
165+ else :
166+ rows , columns = matrix_size
167+ a_input = self .random_lowrank_matrix (
168+ actual_rank , rows , columns , * batches
169+ )
170+ a = a_input
171+
172+ u , s , v = pca (a_input , q = guess_rank , ** options )
173+
174+ self .assertEqual (s .shape [- 1 ], guess_rank )
175+ self .assertEqual (u .shape [- 2 ], rows )
176+ self .assertEqual (u .shape [- 1 ], guess_rank )
177+ self .assertEqual (v .shape [- 1 ], guess_rank )
178+ self .assertEqual (v .shape [- 2 ], columns )
179+
180+ A1 = u .matmul (paddle .diag_embed (s )).matmul (self .transpose (v ))
181+ ones_m1 = paddle .ones (batches + (rows , 1 ), dtype = a .dtype )
182+ c = a .sum (axis = - 2 ) / rows
183+ c = c .reshape (batches + (1 , columns ))
184+ A2 = a - ones_m1 .matmul (c )
185+ detect_rank = (s .abs () > 1e-5 ).sum (axis = - 1 )
186+ left1 = actual_rank * paddle .ones (batches , dtype = paddle .int64 )
187+ S = paddle .linalg .svd (A2 , full_matrices = False )[1 ]
188+ left2 = s [..., :actual_rank ]
189+ right = S [..., :actual_rank ]
190+
191+ exe = paddle .static .Executor ()
192+ exe .run (startup )
193+ A1 , A2 , left1 , detect_rank , left2 , right = exe .run (
194+ main ,
195+ feed = {},
196+ fetch_list = [A1 , A2 , left1 , detect_rank , left2 , right ],
197+ )
198+
199+ np .testing .assert_allclose (A1 , A2 , atol = 1e-5 )
200+ if not left1 .shape :
201+ np .testing .assert_allclose (int (left1 ), int (detect_rank ))
202+ else :
203+ np .testing .assert_allclose (left1 , detect_rank )
204+ np .testing .assert_allclose (left2 , right )
205+
206+ def test_forward (self ):
207+ with paddle .pir_utils .IrGuard ():
208+ pca_lowrank = paddle .linalg .pca_lowrank
209+ all_batches = [(), (1 ,), (3 ,), (2 , 3 )]
210+ for actual_rank , size in [
211+ (2 , (17 , 4 )),
212+ (2 , (100 , 4 )),
213+ (6 , (100 , 40 )),
214+ ]:
215+ for batches in all_batches :
216+ for guess_rank in [
217+ actual_rank ,
218+ actual_rank + 2 ,
219+ actual_rank + 6 ,
220+ ]:
221+ if guess_rank <= min (* size ):
222+ self .run_subtest (
223+ guess_rank ,
224+ actual_rank ,
225+ size ,
226+ batches ,
227+ pca_lowrank ,
228+ )
229+ self .run_subtest (
230+ guess_rank ,
231+ actual_rank ,
232+ size [::- 1 ],
233+ batches ,
234+ pca_lowrank ,
235+ )
236+
237+
136238if __name__ == "__main__" :
137239 unittest .main ()
0 commit comments