Skip to content

Commit 69fcbf8

Browse files
authored
【PIR API adaptor No.250】 python/paddle/geometric/sampling/neighbors.py (#58783)
1 parent d4a612b commit 69fcbf8

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

python/paddle/geometric/sampling/neighbors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from paddle import _C_ops, _legacy_C_ops
1616
from paddle.base.data_feeder import check_variable_and_dtype
1717
from paddle.base.layer_helper import LayerHelper
18-
from paddle.framework import in_dynamic_mode
18+
from paddle.framework import in_dynamic_mode, in_dynamic_or_pir_mode
1919

2020
__all__ = []
2121

@@ -251,7 +251,7 @@ def weighted_sample_neighbors(
251251
"`eids` should not be None if `return_eids` is True."
252252
)
253253

254-
if in_dynamic_mode():
254+
if in_dynamic_or_pir_mode():
255255
(
256256
out_neighbors,
257257
out_count,

test/legacy_test/test_weighted_sample_neighbors.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numpy as np
1818

1919
import paddle
20+
from paddle.pir_utils import test_with_pir_api
2021

2122

2223
class TestWeightedSampleNeighbors(unittest.TestCase):
@@ -80,6 +81,7 @@ def test_sample_result(self):
8081
)
8182
self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0])
8283

84+
@test_with_pir_api
8385
def test_sample_result_static(self):
8486
paddle.enable_static()
8587
with paddle.static.program_guard(paddle.static.Program()):

0 commit comments

Comments
 (0)