Skip to content

Commit d4ffeb3

Browse files
Aurelius84 authored and chenwhql committed
[CustomOp] Support specifying extra_cflags and extra_cuda_cflags independently (PaddlePaddle#31059)
* split cxx/nvcc compile flags * enhance input argument check * rename extra_cflags into extra_cxx_cflags * add name checking in setup * fix test_dispatch failed * fix word typo and rm useless import statement * refine import statement * fix unittest failed * fix cuda flags error
1 parent e2c1330 commit d4ffeb3

File tree

6 files changed

+60
-49
lines changed

6 files changed

+60
-49
lines changed

python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@
4040
'custom_relu_op.cc', 'custom_relu_op.cu', 'custom_relu_op_dup.cc'
4141
],
4242
extra_include_paths=paddle_includes, # add for Coverage CI
43-
extra_cflags=extra_compile_args, # add for Coverage CI
43+
extra_cxx_cflags=extra_compile_args, # add for Coverage CI
44+
extra_cuda_cflags=extra_compile_args, # add for Coverage CI
4445
verbose=True)
4546

4647

python/paddle/fluid/tests/custom_op/test_dispatch_jit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131
name='dispatch_op',
3232
sources=['dispatch_test_op.cc'],
3333
extra_include_paths=paddle_includes, # add for Coverage CI
34-
extra_cflags=extra_compile_args, # add for Coverage CI
34+
extra_cxx_cflags=extra_compile_args,
35+
extra_cuda_cflags=extra_compile_args, # add for Coverage CI
3536
verbose=True)
3637

3738

python/paddle/fluid/tests/custom_op/test_jit_load.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
sources=['relu_op.cc', 'relu_op.cu', 'relu_op3.cc', 'relu_op3.cu'],
3030
interpreter='python', # add for unittest
3131
extra_include_paths=paddle_includes, # add for Coverage CI
32-
extra_cflags=extra_compile_args, # add for Coverage CI
32+
extra_cxx_cflags=extra_compile_args, # add for Coverage CI,
33+
extra_cuda_cflags=extra_compile_args, # add for split cpp/cuda flags
3334
verbose=True # add for unittest
3435
)
3536

python/paddle/fluid/tests/custom_op/test_multi_out_jit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@
3535
name='multi_out_jit',
3636
sources=['multi_out_test_op.cc'],
3737
extra_include_paths=paddle_includes, # add for Coverage CI
38-
extra_cflags=extra_compile_args, # add for Coverage CI
38+
extra_cxx_cflags=extra_compile_args, # add for Coverage CI
39+
extra_cuda_cflags=extra_compile_args, # add for Coverage CI
3940
verbose=True)
4041

4142

python/paddle/utils/cpp_extension/cpp_extension.py

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414

1515
import os
1616
import six
17-
import sys
18-
import textwrap
1917
import copy
2018
import re
2119

@@ -50,7 +48,7 @@ def setup(**attr):
5048
Its usage is almost same as `setuptools.setup` except for `ext_modules`
5149
arguments. For compiling multi custom operators, all necessary source files
5250
can be include into just one Extension (CppExtension/CUDAExtension).
53-
Moreover, only one `name` argument is required in `setup` and no need to spcific
51+
Moreover, only one `name` argument is required in `setup` and no need to specify
5452
`name` in Extension.
5553
5654
Example:
@@ -60,11 +58,11 @@ def setup(**attr):
6058
ext_modules=CUDAExtension(
6159
sources=['relu_op.cc', 'relu_op.cu'],
6260
include_dirs=[], # specific user-defined include dirs
63-
extra_compile_args=[]) # specific user-defined compil arguments.
61+
extra_compile_args=[]) # specific user-defined compiler arguments.
6462
"""
6563
cmdclass = attr.get('cmdclass', {})
6664
assert isinstance(cmdclass, dict)
67-
# if not specific cmdclass in setup, add it automaticaly.
65+
# if not specific cmdclass in setup, add it automatically.
6866
if 'build_ext' not in cmdclass:
6967
cmdclass['build_ext'] = BuildExtension.with_options(
7068
no_python_abi_suffix=True)
@@ -81,18 +79,22 @@ def setup(**attr):
8179
sources=['relu_op.cc', 'relu_op.cu'])
8280
8381
# After running `python setup.py install`
84-
from custom_module import relue
82+
from custom_module import relu
8583
"""
8684
# name argument is required
8785
if 'name' not in attr:
8886
raise ValueError(error_msg)
8987

88+
assert not attr['name'].endswith('module'), \
89+
"Please don't use 'module' as suffix in `name` argument, "
90+
"it will be stripped in setuptools.bdist_egg and cause import error."
91+
9092
ext_modules = attr.get('ext_modules', [])
9193
if not isinstance(ext_modules, list):
9294
ext_modules = [ext_modules]
9395
assert len(
9496
ext_modules
95-
) == 1, "Required only one Extension, but received {}. If you want to compile multi operators, you can include all necessary source files in one Extenion.".format(
97+
) == 1, "Required only one Extension, but received {}. If you want to compile multi operators, you can include all necessary source files in one Extension.".format(
9698
len(ext_modules))
9799
# replace Extension.name with attr['name'] to keep it consistent with the package name.
98100
for ext_module in ext_modules:
@@ -233,12 +235,6 @@ def finalize_options(self):
233235

234236
def build_extensions(self):
235237
self._check_abi()
236-
for extension in self.extensions:
237-
# check settings of compiler
238-
if isinstance(extension.extra_compile_args, dict):
239-
for compiler in ['cxx', 'nvcc']:
240-
if compiler not in extension.extra_compile_args:
241-
extension.extra_compile_args[compiler] = []
242238

243239
# Consider .cu, .cu.cc as valid source extensions.
244240
self.compiler.src_extensions += ['.cu', '.cu.cc']
@@ -248,8 +244,6 @@ def build_extensions(self):
248244
original_compile = self.compiler.compile
249245
original_spawn = self.compiler.spawn
250246
else:
251-
# add determine compile flags
252-
add_compile_flag(extension, '-std=c++11')
253247
original_compile = self.compiler._compile
254248

255249
def unix_custom_single_compiler(obj, src, ext, cc_args, extra_postargs,
@@ -271,8 +265,8 @@ def unix_custom_single_compiler(obj, src, ext, cc_args, extra_postargs,
271265
# {'nvcc': {}, 'cxx': {}}
272266
if isinstance(cflags, dict):
273267
cflags = cflags['nvcc']
274-
else:
275-
cflags = prepare_unix_cudaflags(cflags)
268+
269+
cflags = prepare_unix_cudaflags(cflags)
276270
# cxx compile Cpp source
277271
elif isinstance(cflags, dict):
278272
cflags = cflags['cxx']
@@ -434,7 +428,7 @@ def _check_abi(self):
434428
compiler = os.environ.get('CXX', 'c++')
435429

436430
check_abi_compatibility(compiler)
437-
# Warn user if VC env is activated but `DISTUILS_USE_SDK` is not set.
431+
# Warn user if VC env is activated but `DISTUTILS_USE_SDK` is not set.
438432
if IS_WINDOWS and 'VSCMD_ARG_TGT_ARCH' in os.environ and 'DISTUTILS_USE_SDK' not in os.environ:
439433
msg = (
440434
'It seems that the VC environment is activated but DISTUTILS_USE_SDK is not set.'
@@ -444,7 +438,7 @@ def _check_abi(self):
444438

445439
def _record_op_info(self):
446440
"""
447-
Record custum op inforomation.
441+
Record custom op information.
448442
"""
449443
# parse shared library abs path
450444
outputs = self.get_outputs()
@@ -535,7 +529,7 @@ def initialize_options(self):
535529

536530
def load(name,
537531
sources,
538-
extra_cflags=None,
532+
extra_cxx_cflags=None,
539533
extra_cuda_cflags=None,
540534
extra_ldflags=None,
541535
extra_include_paths=None,
@@ -558,14 +552,14 @@ def load(name,
558552
Args:
559553
name(str): generated shared library file name.
560554
sources(list[str]): custom op source files name with .cc/.cu suffix.
561-
extra_cflag(list[str]): additional flags used to compile CPP files. By default
555+
extra_cxx_cflags(list[str]): additional flags used to compile CPP files. By default
562556
all basic and framework related flags have been included.
563557
If your pre-installed Paddle supports MKLDNN, please add
564558
'-DPADDLE_WITH_MKLDNN'. Default None.
565-
extra_cuda_cflags(list[str]): additonal flags used to compile CUDA files. See
559+
extra_cuda_cflags(list[str]): additional flags used to compile CUDA files. See
566560
https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html
567561
for details. Default None.
568-
extra_ldflags(list[str]): additonal flags used to link shared library. See
562+
extra_ldflags(list[str]): additional flags used to link shared library. See
569563
https://gcc.gnu.org/onlinedocs/gcc/Link-Options.html for details.
570564
Default None.
571565
extra_include_paths(list[str]): additional include path used to search header files.
@@ -578,7 +572,7 @@ def load(name,
578572
verbose(bool): whether to verbose compiled log information
579573
580574
Returns:
581-
custom api: A callable python function with same signature as CustomOp Kernel defination.
575+
custom api: A callable python function with same signature as CustomOp Kernel definition.
582576
583577
Example:
584578
@@ -603,18 +597,25 @@ def load(name,
603597
file_path = os.path.join(build_directory, "{}_setup.py".format(name))
604598
sources = [os.path.abspath(source) for source in sources]
605599

606-
# TODO(Aurelius84): split cflags and cuda_flags
607-
if extra_cflags is None: extra_cflags = []
600+
if extra_cxx_cflags is None: extra_cxx_cflags = []
608601
if extra_cuda_cflags is None: extra_cuda_cflags = []
609-
compile_flags = extra_cflags + extra_cuda_cflags
610-
log_v("additonal compile_flags: [{}]".format(' '.join(compile_flags)),
611-
verbose)
602+
assert isinstance(
603+
extra_cxx_cflags, list
604+
), "Required type(extra_cxx_cflags) == list[str], but received {}".format(
605+
extra_cxx_cflags)
606+
assert isinstance(
607+
extra_cuda_cflags, list
608+
), "Required type(extra_cuda_cflags) == list[str], but received {}".format(
609+
extra_cuda_cflags)
610+
611+
log_v("additional extra_cxx_cflags: [{}], extra_cuda_cflags: [{}]".format(
612+
' '.join(extra_cxx_cflags), ' '.join(extra_cuda_cflags)), verbose)
612613

613614
# write setup.py file and compile it
614615
build_base_dir = os.path.join(build_directory, name)
615616
_write_setup_file(name, sources, file_path, build_base_dir,
616-
extra_include_paths, compile_flags, extra_ldflags,
617-
verbose)
617+
extra_include_paths, extra_cxx_cflags, extra_cuda_cflags,
618+
extra_ldflags, verbose)
618619
_jit_compile(file_path, interpreter, verbose)
619620

620621
# import as callable python api

python/paddle/utils/cpp_extension/extension_utils.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import re
1717
import six
1818
import sys
19-
import copy
2019
import glob
2120
import logging
2221
import collections
@@ -271,6 +270,13 @@ def normalize_extension_kwargs(kwargs, use_cuda=False):
271270
library_dirs.extend(find_paddle_libraries(use_cuda))
272271
kwargs['library_dirs'] = library_dirs
273272

273+
# append compile flags and check settings of compiler
274+
extra_compile_args = kwargs.get('extra_compile_args', [])
275+
if isinstance(extra_compile_args, dict):
276+
for compiler in ['cxx', 'nvcc']:
277+
if compiler not in extra_compile_args:
278+
extra_compile_args[compiler] = []
279+
274280
if IS_WINDOWS:
275281
# TODO(zhouwei): may append compile flags in future
276282
pass
@@ -282,9 +288,7 @@ def normalize_extension_kwargs(kwargs, use_cuda=False):
282288
kwargs['extra_link_args'] = extra_link_args
283289
else:
284290
# append compile flags
285-
extra_compile_args = kwargs.get('extra_compile_args', [])
286-
extra_compile_args.extend(['-g', '-w']) # diable warnings
287-
kwargs['extra_compile_args'] = extra_compile_args
291+
add_compile_flag(extra_compile_args, ['-g', '-w']) # disable warnings
288292

289293
# append link flags
290294
extra_link_args = kwargs.get('extra_link_args', [])
@@ -302,6 +306,8 @@ def normalize_extension_kwargs(kwargs, use_cuda=False):
302306
runtime_library_dirs.extend(find_paddle_libraries(use_cuda))
303307
kwargs['runtime_library_dirs'] = runtime_library_dirs
304308

309+
kwargs['extra_compile_args'] = extra_compile_args
310+
305311
kwargs['language'] = 'c++'
306312
return kwargs
307313

@@ -407,15 +413,13 @@ def find_paddle_libraries(use_cuda=False):
407413
return paddle_lib_dirs
408414

409415

410-
def add_compile_flag(extension, flag):
411-
extra_compile_args = copy.deepcopy(extension.extra_compile_args)
416+
def add_compile_flag(extra_compile_args, flags):
417+
assert isinstance(flags, list)
412418
if isinstance(extra_compile_args, dict):
413419
for args in extra_compile_args.values():
414-
args.append(flag)
420+
args.extend(flags)
415421
else:
416-
extra_compile_args.append(flag)
417-
418-
extension.extra_compile_args = extra_compile_args
422+
extra_compile_args.extend(flags)
419423

420424

421425
def is_cuda_file(path):
@@ -520,7 +524,7 @@ def _custom_api_content(op_name):
520524
def {op_name}({inputs}):
521525
helper = LayerHelper("{op_name}", **locals())
522526
523-
# prepare inputs and output
527+
# prepare inputs and outputs
524528
ins = {ins}
525529
outs = {{}}
526530
out_names = {out_names}
@@ -585,7 +589,8 @@ def _write_setup_file(name,
585589
file_path,
586590
build_dir,
587591
include_dirs,
588-
compile_flags,
592+
extra_cxx_cflags,
593+
extra_cuda_cflags,
589594
link_args,
590595
verbose=False):
591596
"""
@@ -605,7 +610,7 @@ def _write_setup_file(name,
605610
{prefix}Extension(
606611
sources={sources},
607612
include_dirs={include_dirs},
608-
extra_compile_args={extra_compile_args},
613+
extra_compile_args={{'cxx':{extra_cxx_cflags}, 'nvcc':{extra_cuda_cflags}}},
609614
extra_link_args={extra_link_args})],
610615
cmdclass={{"build_ext" : BuildExtension.with_options(
611616
output_dir=r'{build_dir}',
@@ -622,7 +627,8 @@ def _write_setup_file(name,
622627
prefix='CUDA' if with_cuda else 'Cpp',
623628
sources=list2str(sources),
624629
include_dirs=list2str(include_dirs),
625-
extra_compile_args=list2str(compile_flags),
630+
extra_cxx_cflags=list2str(extra_cxx_cflags),
631+
extra_cuda_cflags=list2str(extra_cuda_cflags),
626632
extra_link_args=list2str(link_args),
627633
build_dir=build_dir,
628634
use_new_method=use_new_custom_op_load_method())

0 commit comments

Comments
 (0)