Skip to content

Commit 36aeeca

Browse files
committed
Updated test_gpu_memory.py
1 parent 3f23271 commit 36aeeca

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed
Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
1+
import sys
2+
13
import torch
24

5+
from ignite.engine import Engine, State
36
from ignite.contrib.metrics import GpuMemory
47

58
import pytest
69

710

811
@pytest.fixture
912
def no_site_packages():
10-
try:
11-
import pynvml
12-
except ImportError:
13-
yield "no_site_packages"
14-
return
13+
import pynvml
1514
import sys
1615
assert 'pynvml' in sys.modules
1716
pynvml_module = sys.modules['pynvml']
@@ -23,23 +22,42 @@ def no_site_packages():
2322
sys.modules['pynvml'] = pynvml_module
2423

2524

25+
@pytest.mark.skipif(sys.version_info[0] == 2, reason="No pynvml for python 2.7")
def test_no_pynvml_package(no_site_packages):
    """GpuMemory() must raise RuntimeError when pynvml cannot be imported.

    The ``no_site_packages`` fixture hides the installed ``pynvml`` module,
    so constructing the metric has to fail with the documented message.
    """
    # NOTE: sys.version_info is the structured version tuple; indexing the
    # sys.version *string* (the old code) is fragile and non-idiomatic.
    with pytest.raises(RuntimeError, match="This contrib module requires pynvml to be installed."):
        GpuMemory()
3030

3131

32-
@pytest.mark.skipif(torch.cuda.is_available(), reason="Skip if has GPU")
32+
@pytest.mark.skipif(sys.version_info[0] == 2 or torch.cuda.is_available(),
                    reason="No pynvml for python 2.7, and test requires a machine without GPU")
def test_no_gpu():
    """GpuMemory() must raise RuntimeError when no GPU is available.

    Skipped on python 2 (no pynvml) and on machines that *do* have a GPU,
    since the error path under test only triggers without one.  The old
    ``reason`` text mentioned only python 2.7, which was misleading.
    """
    with pytest.raises(RuntimeError, match="This contrib module requires available GPU"):
        GpuMemory()
3737

3838

39-
@pytest.mark.skipif(torch.cuda.is_available(), reason="Skip if has GPU")
39+
@pytest.mark.skipif(sys.version_info[0] == 2 or not torch.cuda.is_available(),
                    reason="No pynvml for python 2.7 and no GPU")
def test_gpu_mem_consumption():
    """End-to-end check of the GpuMemory metric.

    Allocates a tensor on the GPU, verifies that ``compute()`` returns a
    per-device report whose ``fb_memory_usage`` reflects at least that
    allocation, then checks ``completed()`` publishes a human-readable
    'used/total' string into ``engine.state.metrics``.
    """
    gpu_mem = GpuMemory()

    # Allocate on the GPU: the assertion below compares the *GPU* used
    # memory against this tensor's size, so a CPU-resident tensor (the old
    # code) would not contribute to the measurement at all.
    t = torch.rand(4, 10, 100, 100).cuda()
    data = gpu_mem.compute()
    assert len(data) > 0
    assert "fb_memory_usage" in data[0]
    report = data[0]['fb_memory_usage']
    assert 'used' in report and 'total' in report
    assert report['total'] > 0.0
    # NOTE(review): element count / 1024 / 1024 underestimates the tensor's
    # size in MiB (float32 is 4 bytes/element) — kept as a loose lower bound.
    assert report['used'] > t.shape[0] * t.shape[1] * t.shape[2] * t.shape[3] / 1024.0 / 1024.0

    # with Engine: completed() should format the report into the metrics dict
    engine = Engine(lambda engine, batch: 0.0)
    engine.state = State(metrics={})

    gpu_mem.completed(engine, name='gpu mem', local_rank=0)

    assert 'gpu mem' in engine.state.metrics
    assert isinstance(engine.state.metrics['gpu mem'], str)
    assert "{}".format(int(report['used'])) in engine.state.metrics['gpu mem']
    assert "{}".format(int(report['total'])) in engine.state.metrics['gpu mem']

0 commit comments

Comments
 (0)