1+ import sys
2+
13import torch
24
5+ from ignite .engine import Engine , State
36from ignite .contrib .metrics import GpuMemory
47
58import pytest
69
710
811@pytest .fixture
912def no_site_packages ():
10- try :
11- import pynvml
12- except ImportError :
13- yield "no_site_packages"
14- return
13+ import pynvml
1514 import sys
1615 assert 'pynvml' in sys .modules
1716 pynvml_module = sys .modules ['pynvml' ]
@@ -23,23 +22,42 @@ def no_site_packages():
2322 sys .modules ['pynvml' ] = pynvml_module
2423
2524
25+ @pytest .mark .skipif (sys .version [0 ] == "2" , reason = "No pynvml for python 2.7" )
2626def test_no_pynvml_package (no_site_packages ):
2727
2828 with pytest .raises (RuntimeError , match = "This contrib module requires pynvml to be installed." ):
2929 GpuMemory ()
3030
3131
32- @pytest .mark .skipif (torch .cuda .is_available (), reason = "Skip if has GPU " )
32+ @pytest .mark .skipif (sys . version [ 0 ] == "2" or torch .cuda .is_available (), reason = "No pynvml for python 2.7 " )
3333def test_no_gpu ():
3434
3535 with pytest .raises (RuntimeError , match = "This contrib module requires available GPU" ):
3636 GpuMemory ()
3737
3838
39- @pytest .mark .skipif (torch .cuda .is_available (), reason = "Skip if has GPU" )
39+ @pytest .mark .skipif (sys .version [0 ] == "2" or not (torch .cuda .is_available ()),
40+ reason = "No pynvml for python 2.7 and no GPU" )
4041def test_gpu_mem_consumption ():
4142
4243 gpu_mem = GpuMemory ()
4344
45+ t = torch .rand (4 , 10 , 100 , 100 )
4446 data = gpu_mem .compute ()
4547 assert len (data ) > 0
48+ assert "fb_memory_usage" in data [0 ]
49+ report = data [0 ]['fb_memory_usage' ]
50+ assert 'used' in report and 'total' in report
51+ assert report ['total' ] > 0.0
52+ assert report ['used' ] > t .shape [0 ] * t .shape [1 ] * t .shape [2 ] * t .shape [3 ] / 1024.0 / 1024.0
53+
54+ # with Engine
55+ engine = Engine (lambda engine , batch : 0.0 )
56+ engine .state = State (metrics = {})
57+
58+ gpu_mem .completed (engine , name = 'gpu mem' , local_rank = 0 )
59+
60+ assert 'gpu mem' in engine .state .metrics
61+ assert isinstance (engine .state .metrics ['gpu mem' ], str )
62+ assert "{}" .format (int (report ['used' ])) in engine .state .metrics ['gpu mem' ]
63+ assert "{}" .format (int (report ['total' ])) in engine .state .metrics ['gpu mem' ]
0 commit comments