forward can't run parallelly using multi-gpus when custom operator using numpy #8884
Description
I found that after #6928, when numpy is used inside a custom operator, the multi-GPU forward passes no longer run in parallel; before #6928 they did. This can be reproduced with mxnet's image-classification example by replacing the Softmax operator with a custom softmax.
Environment info
----------Python Info----------
('Version :', '2.7.6')
('Compiler :', 'GCC 4.8.4')
('Build :', ('default', 'Oct 26 2016 20:30:19'))
('Arch :', ('64bit', 'ELF'))
------------Pip Info-----------
('Version :', '9.0.1')
('Directory :', '/usr/local/lib/python2.7/dist-packages/pip-9.0.1-py2.7.egg/pip')
----------MXNet Info-----------
('Version :', '0.10.0')
('Directory :', '/data/home/xizhou/incubator-mxnet/python/mxnet')
Traceback (most recent call last):
File "diagnose_new.py", line 171, in <module>
check_mxnet()
File "diagnose_new.py", line 113, in check_mxnet
except FileNotFoundError:
NameError: global name 'FileNotFoundError' is not defined
----------System Info----------
('Platform :', 'Linux-3.13.0-132-generic-x86_64-with-Ubuntu-14.04-trusty')
('system :', 'Linux')
('node :', 'msravcg10')
('release :', '3.13.0-132-generic')
('version :', '#181-Ubuntu SMP Wed Sep 13 13:25:03 UTC 2017')
----------Hardware Info----------
('machine :', 'x86_64')
('processor :', 'x86_64')
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 32
On-line CPU(s) list: 0-31
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 62
Stepping: 4
CPU MHz: 2600.000
BogoMIPS: 5188.77
Virtualization: VT-x
L1d cache: 32K
L1i cache: 32K
L2 cache: 256K
L3 cache: 20480K
NUMA node0 CPU(s): 0-7,16-23
NUMA node1 CPU(s): 8-15,24-31
----------Network Test----------
Setting timeout: 10
Timing for MXNet: https://github.com/apache/incubator-mxnet, DNS: 0.0017 sec, LOAD: 1.1562 sec.
Timing for PYPI: https://pypi.python.org/pypi/pip, DNS: 0.0012 sec, LOAD: 0.4335 sec.
Timing for FashionMNIST: https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/fashion-mnist/train-labels-idx1-ubyte.gz, DNS: 0.2028 sec, LOAD: 0.9514 sec.
Timing for Conda: https://repo.continuum.io/pkgs/free/, DNS: 0.0678 sec, LOAD: 0.4102 sec.
Timing for Gluon Tutorial(en): http://gluon.mxnet.io, DNS: 0.0684 sec, LOAD: 0.2063 sec.
Error open Gluon Tutorial(cn): https://zh.gluon.ai, <urlopen error [Errno 1] _ssl.c:510: error:14077410:SSL routines:SSL23_GET_SERVER_HELLO:sslv3 alert handshake failure>, DNS finished in 0.0689778327942 sec.
Package used (Python/R/Scala/Julia):
Python
Build info (Required if built from source)
Compiler (gcc/clang/mingw/visual studio):
gcc
MXNet commit hash:
ed19095
Minimum reproducible example
This custom softmax operator exists only to reproduce the issue, so I did not implement a real backward pass.
import mxnet as mx

class Softmax(mx.operator.CustomOp):
    def forward(self, is_train, req, in_data, out_data, aux):
        # asnumpy() forces a blocking device-to-host copy before the softmax runs
        self.assign(out_data[0], req[0],
                    mx.nd.softmax(mx.nd.array(in_data[0].asnumpy(),
                                              ctx=in_data[0].context)))

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        # Not a real gradient; backward is stubbed out for this reproduction
        self.assign(in_grad[0], req[0], 0)

@mx.operator.register("softmax")
class SoftmaxProp(mx.operator.CustomOpProp):
    def __init__(self):
        super(SoftmaxProp, self).__init__(need_top_grad=False)

    def list_arguments(self):
        return ['data', 'label']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        data_shape = in_shape[0]
        label_shape = (in_shape[0][0],)
        output_shape = in_shape[0]
        return [data_shape, label_shape], [output_shape], []

    def infer_type(self, in_type):
        return in_type, [in_type[0]], []

    def create_operator(self, ctx, shapes, dtypes):
        return Softmax()
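For reference, the computation the operator delegates to mx.nd.softmax is just a numerically stable row-wise softmax; a stdlib-only sketch is below. The point of the bug report is not this math but the asnumpy() device-to-host round-trip in forward.

```python
import math

def softmax(row):
    # Numerically stable softmax: shift by the max before exponentiating
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax([1.0, 2.0, 3.0])  # probabilities summing to 1
```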
Steps to reproduce
- run the original train_imagenet.py as a baseline
python train_imagenet.py --benchmark 1 --batch-size 128 --gpus 0,1,2,3
Training speed is:
INFO:root:Epoch[0] Batch [20] Speed: 217.27 samples/sec accuracy=0.113467
INFO:root:Epoch[0] Batch [40] Speed: 217.81 samples/sec accuracy=1.000000
- run train_imagenet.py using the custom softmax, which calls asnumpy inside the operator's forward
python train_imagenet.py --benchmark 1 --batch-size 128 --gpus 0,1,2,3
Training speed is:
INFO:root:Epoch[0] Batch [20] Speed: 114.91 samples/sec accuracy=0.000000
INFO:root:Epoch[0] Batch [40] Speed: 113.70 samples/sec accuracy=0.000000
What have you tried to solve it?
I have used the mxnet built-in profiler to get more detail about the execution time.
The original version: [profiler screenshot omitted]
The custom softmax version: [profiler screenshot omitted]
As the profiles show, when using the custom operator the forward passes on the multiple GPUs run sequentially rather than in parallel.
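The sequential behavior can be pictured with a plain-Python analogy (this is not MXNet code; it simply assumes, hypothetically, that every custom-op forward is funneled through a single lock while native ops are free to overlap):

```python
import threading
import time

LOCK = threading.Lock()  # stand-in for a single serialization point

def forward(serialize):
    """One simulated per-GPU forward pass taking ~50 ms."""
    if serialize:
        with LOCK:           # custom-op path: passes run one at a time
            time.sleep(0.05)
    else:
        time.sleep(0.05)     # native path: passes overlap freely

def run_step(serialize, n_gpus=4):
    # Launch one forward per GPU and time how long the step takes overall
    t0 = time.time()
    threads = [threading.Thread(target=forward, args=(serialize,))
               for _ in range(n_gpus)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return time.time() - t0

parallel_time = run_step(serialize=False)  # roughly one pass long
serial_time = run_step(serialize=True)     # roughly n_gpus passes long
```

A roughly 2x drop, as in the logs above (about 217 down to 114 samples/sec), is consistent with the forward passes queueing behind such a serialization point instead of overlapping.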
I have also tried a version of mxnet from before #6928; there, with or without the custom softmax operator, the speed is almost the same.
Original training speed with mxnet before #6928:
INFO:root:Epoch[0] Batch [20] Speed: 217.54 samples/sec accuracy=0.232515
INFO:root:Epoch[0] Batch [40] Speed: 214.66 samples/sec accuracy=1.000000
Using the custom softmax with mxnet before #6928:
INFO:root:Epoch[0] Batch [20] Speed: 217.28 samples/sec accuracy=0.000000
INFO:root:Epoch[0] Batch [40] Speed: 213.57 samples/sec accuracy=0.000000