1414
1515import os
1616import six
17- import sys
18- import textwrap
1917import copy
2018import re
2119
@@ -50,7 +48,7 @@ def setup(**attr):
5048 Its usage is almost same as `setuptools.setup` except for `ext_modules`
5149 arguments. For compiling multi custom operators, all necessary source files
5250 can be include into just one Extension (CppExtension/CUDAExtension).
53- Moreover, only one `name` argument is required in `setup` and no need to spcific
51+ Moreover, only one `name` argument is required in `setup` and no need to specify
5452 `name` in Extension.
5553
5654 Example:
@@ -60,11 +58,11 @@ def setup(**attr):
6058 ext_modules=CUDAExtension(
6159 sources=['relu_op.cc', 'relu_op.cu'],
6260 include_dirs=[], # specific user-defined include dirs
63- extra_compile_args=[]) # specific user-defined compil arguments.
61+ extra_compile_args=[]) # specific user-defined compiler arguments.
6462 """
6563 cmdclass = attr .get ('cmdclass' , {})
6664 assert isinstance (cmdclass , dict )
67- # if not specific cmdclass in setup, add it automaticaly .
65+ # if not specific cmdclass in setup, add it automatically .
6866 if 'build_ext' not in cmdclass :
6967 cmdclass ['build_ext' ] = BuildExtension .with_options (
7068 no_python_abi_suffix = True )
@@ -81,18 +79,22 @@ def setup(**attr):
8179 sources=['relu_op.cc', 'relu_op.cu'])
8280
8381 # After running `python setup.py install`
84- from custom_module import relue
82+ from custom_module import relu
8583 """
8684 # name argument is required
8785 if 'name' not in attr :
8886 raise ValueError (error_msg )
8987
88+ assert not attr ['name' ].endswith ('module' ), \
89+ "Please don't use 'module' as suffix in `name` argument, "
90+ "it will be stripped in setuptools.bdist_egg and cause import error."
91+
9092 ext_modules = attr .get ('ext_modules' , [])
9193 if not isinstance (ext_modules , list ):
9294 ext_modules = [ext_modules ]
9395 assert len (
9496 ext_modules
95- ) == 1 , "Required only one Extension, but received {}. If you want to compile multi operators, you can include all necessary source files in one Extenion ." .format (
97+ ) == 1 , "Required only one Extension, but received {}. If you want to compile multi operators, you can include all necessary source files in one Extension ." .format (
9698 len (ext_modules ))
9799 # replace Extension.name with attr['name] to keep consistant with Package name.
98100 for ext_module in ext_modules :
@@ -233,12 +235,6 @@ def finalize_options(self):
233235
234236 def build_extensions (self ):
235237 self ._check_abi ()
236- for extension in self .extensions :
237- # check settings of compiler
238- if isinstance (extension .extra_compile_args , dict ):
239- for compiler in ['cxx' , 'nvcc' ]:
240- if compiler not in extension .extra_compile_args :
241- extension .extra_compile_args [compiler ] = []
242238
243239 # Consider .cu, .cu.cc as valid source extensions.
244240 self .compiler .src_extensions += ['.cu' , '.cu.cc' ]
@@ -248,8 +244,6 @@ def build_extensions(self):
248244 original_compile = self .compiler .compile
249245 original_spawn = self .compiler .spawn
250246 else :
251- # add determine compile flags
252- add_compile_flag (extension , '-std=c++11' )
253247 original_compile = self .compiler ._compile
254248
255249 def unix_custom_single_compiler (obj , src , ext , cc_args , extra_postargs ,
@@ -271,8 +265,8 @@ def unix_custom_single_compiler(obj, src, ext, cc_args, extra_postargs,
271265 # {'nvcc': {}, 'cxx: {}}
272266 if isinstance (cflags , dict ):
273267 cflags = cflags ['nvcc' ]
274- else :
275- cflags = prepare_unix_cudaflags (cflags )
268+
269+ cflags = prepare_unix_cudaflags (cflags )
276270 # cxx compile Cpp source
277271 elif isinstance (cflags , dict ):
278272 cflags = cflags ['cxx' ]
@@ -434,7 +428,7 @@ def _check_abi(self):
434428 compiler = os .environ .get ('CXX' , 'c++' )
435429
436430 check_abi_compatibility (compiler )
437- # Warn user if VC env is activated but `DISTUILS_USE_SDK ` is not set.
431+ # Warn user if VC env is activated but `DISTUTILS_USE_SDK ` is not set.
438432 if IS_WINDOWS and 'VSCMD_ARG_TGT_ARCH' in os .environ and 'DISTUTILS_USE_SDK' not in os .environ :
439433 msg = (
440434 'It seems that the VC environment is activated but DISTUTILS_USE_SDK is not set.'
@@ -444,7 +438,7 @@ def _check_abi(self):
444438
445439 def _record_op_info (self ):
446440 """
447- Record custum op inforomation .
441+ Record custom op information .
448442 """
449443 # parse shared library abs path
450444 outputs = self .get_outputs ()
@@ -535,7 +529,7 @@ def initialize_options(self):
535529
536530def load (name ,
537531 sources ,
538- extra_cflags = None ,
532+ extra_cxx_cflags = None ,
539533 extra_cuda_cflags = None ,
540534 extra_ldflags = None ,
541535 extra_include_paths = None ,
@@ -558,14 +552,14 @@ def load(name,
558552 Args:
559553 name(str): generated shared library file name.
560554 sources(list[str]): custom op source files name with .cc/.cu suffix.
561- extra_cflag (list[str]): additional flags used to compile CPP files. By default
555+ extra_cxx_cflags (list[str]): additional flags used to compile CPP files. By default
562556 all basic and framework related flags have been included.
563557 If your pre-insall Paddle supported MKLDNN, please add
564558 '-DPADDLE_WITH_MKLDNN'. Default None.
565- extra_cuda_cflags(list[str]): additonal flags used to compile CUDA files. See
559+ extra_cuda_cflags(list[str]): additional flags used to compile CUDA files. See
566560 https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html
567561 for details. Default None.
568- extra_ldflags(list[str]): additonal flags used to link shared library. See
562+ extra_ldflags(list[str]): additional flags used to link shared library. See
569563 https://gcc.gnu.org/onlinedocs/gcc/Link-Options.html for details.
570564 Default None.
571565 extra_include_paths(list[str]): additional include path used to search header files.
@@ -578,7 +572,7 @@ def load(name,
578572 verbose(bool): whether to verbose compiled log information
579573
580574 Returns:
581- custom api: A callable python function with same signature as CustomOp Kernel defination .
575+ custom api: A callable python function with same signature as CustomOp Kernel definition .
582576
583577 Example:
584578
@@ -603,18 +597,25 @@ def load(name,
603597 file_path = os .path .join (build_directory , "{}_setup.py" .format (name ))
604598 sources = [os .path .abspath (source ) for source in sources ]
605599
606- # TODO(Aurelius84): split cflags and cuda_flags
607- if extra_cflags is None : extra_cflags = []
600+ if extra_cxx_cflags is None : extra_cxx_cflags = []
608601 if extra_cuda_cflags is None : extra_cuda_cflags = []
609- compile_flags = extra_cflags + extra_cuda_cflags
610- log_v ("additonal compile_flags: [{}]" .format (' ' .join (compile_flags )),
611- verbose )
602+ assert isinstance (
603+ extra_cxx_cflags , list
604+ ), "Required type(extra_cxx_cflags) == list[str], but received {}" .format (
605+ extra_cxx_cflags )
606+ assert isinstance (
607+ extra_cuda_cflags , list
608+ ), "Required type(extra_cuda_cflags) == list[str], but received {}" .format (
609+ extra_cuda_cflags )
610+
611+ log_v ("additional extra_cxx_cflags: [{}], extra_cuda_cflags: [{}]" .format (
612+ ' ' .join (extra_cxx_cflags ), ' ' .join (extra_cuda_cflags )), verbose )
612613
613614 # write setup.py file and compile it
614615 build_base_dir = os .path .join (build_directory , name )
615616 _write_setup_file (name , sources , file_path , build_base_dir ,
616- extra_include_paths , compile_flags , extra_ldflags ,
617- verbose )
617+ extra_include_paths , extra_cxx_cflags , extra_cuda_cflags ,
618+ extra_ldflags , verbose )
618619 _jit_compile (file_path , interpreter , verbose )
619620
620621 # import as callable python api
0 commit comments