Skip to content

Commit 2dfa0f7

Browse files
authored
【PIR API adaptor No.289】Migrate pca_lowrank to pir (#60320)
1 parent aec353c commit 2dfa0f7

2 files changed

Lines changed: 103 additions & 1 deletion

File tree

python/paddle/tensor/random.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ def gaussian(shape, mean=0.0, std=1.0, seed=0, dtype=None, name=None):
429429
op_type_for_check, supported_dtypes, dtype
430430
)
431431
)
432-
if not isinstance(dtype, core.VarDesc.VarType):
432+
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
433433
dtype = convert_np_dtype_to_dtype_(dtype)
434434

435435
if in_dynamic_or_pir_mode():

test/legacy_test/test_pca_lowrank.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,5 +133,107 @@ def test_niter_range():
133133
self.assertRaises(ValueError, test_niter_range)
134134

135135

136+
class TestStaticPcaLowrankAPI(unittest.TestCase):
    """Exercise ``paddle.linalg.pca_lowrank`` through static programs under PIR."""

    def transpose(self, x):
        """Return ``x`` with its last two axes exchanged."""
        axes = list(range(len(x.shape)))
        axes[-2], axes[-1] = axes[-1], axes[-2]
        return paddle.transpose(x, axes)

    def random_matrix(self, rows, columns, *batch_dims, **kwargs):
        """Build a random ``batch_dims + (rows, columns)`` matrix.

        The matrix is assembled from the SVD factors of a ``paddle.randn``
        sample, with the singular values replaced by a ``linspace`` from
        ``1 / (k + 1)`` to ``1`` where ``k = min(rows, columns)``.

        Only the ``dtype`` keyword is read from ``kwargs`` (default
        ``'float64'``).
        """
        dtype = kwargs.get('dtype', 'float64')

        sample = paddle.randn(batch_dims + (rows, columns), dtype=dtype)
        left, _, right_h = paddle.linalg.svd(sample, full_matrices=False)
        n_singular = min(rows, columns)
        spectrum = paddle.linspace(
            1 / (n_singular + 1), 1, n_singular, dtype=dtype
        )
        scaled_left = left * spectrum.unsqueeze(-2)
        return scaled_left @ right_h

    def random_lowrank_matrix(self, rank, rows, columns, *batch_dims, **kwargs):
        """Return the product of a ``rows x rank`` and a ``rank x columns`` random matrix."""
        tall = self.random_matrix(rows, rank, *batch_dims, **kwargs)
        wide = self.random_matrix(rank, columns, *batch_dims, **kwargs)
        return tall.matmul(wide)

    def run_subtest(
        self, guess_rank, actual_rank, matrix_size, batches, pca, **options
    ):
        """Build and run one static program checking ``pca`` on a low-rank input.

        Args:
            guess_rank: value passed to ``pca`` as ``q``.
            actual_rank: rank used to generate the input matrix.
            matrix_size: an int (square matrix) or a ``(rows, columns)`` pair.
            batches: tuple of leading batch dimensions.
            pca: callable invoked as ``pca(a, q=guess_rank, **options)``
                returning ``(u, s, v)``.
            **options: forwarded to ``pca`` unchanged.
        """
        main_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
            if isinstance(matrix_size, int):
                rows, columns = matrix_size, matrix_size
            else:
                rows, columns = matrix_size

            a = self.random_lowrank_matrix(
                actual_rank, rows, columns, *batches
            )

            u, s, v = pca(a, q=guess_rank, **options)

            # Shapes of the returned factors must follow the requested rank.
            self.assertEqual(s.shape[-1], guess_rank)
            self.assertEqual(u.shape[-2], rows)
            self.assertEqual(u.shape[-1], guess_rank)
            self.assertEqual(v.shape[-1], guess_rank)
            self.assertEqual(v.shape[-2], columns)

            # Reconstruction from the factors: u @ diag(s) @ v^T.
            reconstructed = u.matmul(paddle.diag_embed(s)).matmul(
                self.transpose(v)
            )

            # Input with its per-column mean subtracted.
            ones_column = paddle.ones(batches + (rows, 1), dtype=a.dtype)
            column_mean = a.sum(axis=-2) / rows
            column_mean = column_mean.reshape(batches + (1, columns))
            centered = a - ones_column.matmul(column_mean)

            # Number of singular values above the 1e-5 threshold.
            detected_rank = (s.abs() > 1e-5).sum(axis=-1)
            expected_rank = actual_rank * paddle.ones(
                batches, dtype=paddle.int64
            )

            reference_s = paddle.linalg.svd(centered, full_matrices=False)[1]
            leading_s = s[..., :actual_rank]
            leading_reference_s = reference_s[..., :actual_rank]

            executor = paddle.static.Executor()
            executor.run(startup_program)
            fetched = executor.run(
                main_program,
                feed={},
                fetch_list=[
                    reconstructed,
                    centered,
                    expected_rank,
                    detected_rank,
                    leading_s,
                    leading_reference_s,
                ],
            )
            (
                reconstructed_np,
                centered_np,
                expected_rank_np,
                detected_rank_np,
                leading_s_np,
                leading_reference_s_np,
            ) = fetched

            np.testing.assert_allclose(
                reconstructed_np, centered_np, atol=1e-5
            )
            if expected_rank_np.shape:
                np.testing.assert_allclose(expected_rank_np, detected_rank_np)
            else:
                # No batch dimensions: compare the scalar ranks as ints.
                np.testing.assert_allclose(
                    int(expected_rank_np), int(detected_rank_np)
                )
            np.testing.assert_allclose(leading_s_np, leading_reference_s_np)

    def test_forward(self):
        """Run ``run_subtest`` over a grid of ranks, sizes and batch shapes."""
        with paddle.pir_utils.IrGuard():
            pca_lowrank = paddle.linalg.pca_lowrank
            all_batches = [(), (1,), (3,), (2, 3)]
            cases = [
                (2, (17, 4)),
                (2, (100, 4)),
                (6, (100, 40)),
            ]
            for actual_rank, size in cases:
                smallest_side = min(size)
                for batches in all_batches:
                    for extra in (0, 2, 6):
                        guess_rank = actual_rank + extra
                        if guess_rank > smallest_side:
                            continue
                        # Check both the given orientation and its transpose.
                        for shape in (size, size[::-1]):
                            self.run_subtest(
                                guess_rank,
                                actual_rank,
                                shape,
                                batches,
                                pca_lowrank,
                            )
236+
237+
136238
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)