|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +import warnings |
| 3 | + |
| 4 | +import torch |
| 5 | + |
| 6 | +from ignite.metrics import Metric |
| 7 | +from ignite.engine import Events |
| 8 | + |
| 9 | + |
class GpuInfo(Metric):
    """Report GPU memory usage and GPU utilization as engine metrics.

    On every attached event, the NVML "smi" interface of ``pynvml`` is queried
    and, for each reported device index ``N``, two string entries are written
    into ``engine.state.metrics``:

    - ``'<name>:N memory'``, formatted as ``'<used>/<total> MiB'``
    - ``'<name>:N util'``, formatted as ``'<percent>%'``

    A reading that is missing or not numeric is skipped on its own; the other
    reading of the same device is still reported.

    Examples:

        .. code-block:: python

            # Default GPU measurement
            GpuInfo().attach(trainer)  # default metric names are 'gpu info:N memory', 'gpu info:N util'
            ProgressBar(persist=True).attach(trainer, metric_names=['gpu info:0 memory', 'gpu info:0 util'])

            # Progress bar will look similar to
            # Epoch [2/50]: [64/128]  50%|█████ , gpu info:0 memory=1120/11176 MiB, gpu info:0 util=50% [06:17<12:34]

    Raises:
        RuntimeError: if ``pynvml`` is not installed or ``torch.cuda.is_available()``
            returns False.
    """

    def __init__(self):
        try:
            import pynvml  # noqa: F401 -- imported only to check that the package is installed
        except ImportError:
            raise RuntimeError("This contrib module requires pynvml to be installed. "
                               "Please install it with command: \n pip install pynvml")
        # Let's check available devices
        if not torch.cuda.is_available():
            raise RuntimeError("This contrib module requires available GPU")

        from pynvml.smi import nvidia_smi
        # Let it fail if no libnvidia drivers or NVML library found
        self.nvsmi = nvidia_smi.getInstance()
        super(GpuInfo, self).__init__()

    def reset(self):
        """Do nothing: no state is accumulated between events."""
        pass

    def update(self, output):
        """Do nothing: the engine output is not used by this metric."""
        pass

    def compute(self):
        """Query the devices and return the per-GPU list of reports.

        Returns:
            The value stored under the ``'gpu'`` key of the query result, or an
            empty list (after emitting a warning) when the result is empty or
            has no ``'gpu'`` key.
        """
        data = self.nvsmi.DeviceQuery('memory.used, memory.total, utilization.gpu')
        if len(data) == 0 or ('gpu' not in data):
            warnings.warn("No GPU information available")
            return []
        return data['gpu']

    @staticmethod
    def _format_memory(data_by_rank):
        """Return ``'<used>/<total> MiB'`` for one device report, or None if unavailable."""
        if 'fb_memory_usage' not in data_by_rank:
            warnings.warn("No GPU memory usage information available in {}".format(data_by_rank))
            return None
        mem_report = data_by_rank['fb_memory_usage']
        if not ('used' in mem_report and 'total' in mem_report):
            warnings.warn("GPU memory usage information does not provide used/total "
                          "memory consumption information in {}".format(mem_report))
            return None
        try:
            return "{}/{} MiB".format(int(mem_report['used']), int(mem_report['total']))
        except (TypeError, ValueError):
            # The value is present but not numeric (e.g. a "N/A" placeholder):
            # skip the metric instead of raising inside an engine event handler.
            return None

    @staticmethod
    def _format_utilization(data_by_rank):
        """Return ``'<percent>%'`` for one device report, or None if unavailable."""
        if 'utilization' not in data_by_rank:
            warnings.warn("No GPU utilization information available in {}".format(data_by_rank))
            return None
        util_report = data_by_rank['utilization']
        if 'gpu_util' not in util_report:
            warnings.warn("GPU utilization information does not provide 'gpu_util' information in "
                          "{}".format(util_report))
            return None
        try:
            return "{:02d}%".format(int(util_report['gpu_util']))
        except (TypeError, ValueError):
            # Same reasoning as for memory: a non-numeric reading is skipped.
            return None

    def completed(self, engine, name):
        """Write the current GPU readings into ``engine.state.metrics``.

        Args:
            engine: engine whose ``state.metrics`` mapping receives the values.
            name: prefix of the metric names; entries are stored as
                ``'<name>:<i> memory'`` and ``'<name>:<i> util'`` where ``i`` is
                the position of the device in the list returned by ``compute``.
        """
        data = self.compute()
        if len(data) < 1:
            warnings.warn("No GPU information available")
            return

        for i, data_by_rank in enumerate(data):
            # Memory and utilization are reported independently, so a problem
            # with one reading does not hide the other one.
            mem_value = self._format_memory(data_by_rank)
            if mem_value is not None:
                engine.state.metrics["{}:{} memory".format(name, i)] = mem_value

            util_value = self._format_utilization(data_by_rank)
            if util_value is not None:
                engine.state.metrics["{}:{} util".format(name, i)] = util_value

    def attach(self, engine, name="gpu info", event_name=Events.ITERATION_COMPLETED):
        """Register ``completed`` on ``engine`` for ``event_name``.

        Args:
            engine: engine to attach the handler to.
            name: metric name prefix passed to ``completed``.
            event_name: event on which the GPU readings are refreshed.
        """
        engine.add_event_handler(event_name, self.completed, name)
0 commit comments