diff --git a/numba_cuda/numba/cuda/tests/nrt/test_nrt.py b/numba_cuda/numba/cuda/tests/nrt/test_nrt.py index 38d67f7ef..a621fe625 100644 --- a/numba_cuda/numba/cuda/tests/nrt/test_nrt.py +++ b/numba_cuda/numba/cuda/tests/nrt/test_nrt.py @@ -226,6 +226,37 @@ def test_nrt_explicit_stats_query_raises_exception_when_disabled(self): stats_func() self.assertIn("NRT stats are disabled.", str(raises.exception)) + def test_read_one_stat(self): + @cuda.jit + def foo(): + tmp = np.ones(3) + arr = np.arange(5 * tmp[0]) # noqa: F841 + return None + + with ( + override_config("CUDA_ENABLE_NRT", True), + override_config("CUDA_NRT_STATS", True), + ): + # Switch on stats + rtsys.memsys_enable_stats() + + # Launch the kernel a couple of times to increase stats + foo[1, 1]() + foo[1, 1]() + + # Get stats struct and individual stats + stats = rtsys.get_allocation_stats() + stats_alloc = rtsys.memsys_get_stats_alloc() + stats_mi_alloc = rtsys.memsys_get_stats_mi_alloc() + stats_free = rtsys.memsys_get_stats_free() + stats_mi_free = rtsys.memsys_get_stats_mi_free() + + # Check individual stats match stats struct + self.assertEqual(stats.alloc, stats_alloc) + self.assertEqual(stats.mi_alloc, stats_mi_alloc) + self.assertEqual(stats.free, stats_free) + self.assertEqual(stats.mi_free, stats_mi_free) + if __name__ == "__main__": unittest.main()