Skip to content

Commit 09db77b

Browse files
committed
polish code
1 parent a3beefe commit 09db77b

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

python/paddle/utils/cpp_extension/cpp_extension.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,8 @@ def __init__(self, *args, **kwargs):
355355
super(BuildExtension, self).__init__(*args, **kwargs)
356356
self.no_python_abi_suffix = kwargs.get("no_python_abi_suffix", True)
357357
self.output_dir = kwargs.get("output_dir", None)
358+
# whether containing cuda source file in Extensions
359+
self.contain_cuda_file = False
358360

359361
def initialize_options(self):
360362
super(BuildExtension, self).initialize_options()
@@ -432,8 +434,8 @@ def unix_custom_single_compiler(obj, src, ext, cc_args, extra_postargs,
432434
# shared library have same ABI suffix with core_(no)avx.so.
433435
# See https://stackoverflow.com/questions/34571583/understanding-gcc-5s-glibcxx-use-cxx11-abi-or-the-new-abi
434436
add_compile_flag(['-D_GLIBCXX_USE_CXX11_ABI=1'], cflags)
435-
436-
if not is_cuda_file(src):
437+
# Append this macor only when jointly compiling .cc with .cu
438+
if not is_cuda_file(src) and self.contain_cuda_file:
437439
cflags.append('-DPADDLE_WITH_CUDA')
438440

439441
add_std_without_repeat(
@@ -509,6 +511,9 @@ def win_custom_spawn(cmd):
509511
elif isinstance(self.cflags, list):
510512
cflags = MSVC_COMPILE_FLAGS + self.cflags
511513
cmd += cflags
514+
# Append this macor only when jointly compiling .cc with .cu
515+
if not is_cuda_file(src) and self.contain_cuda_file:
516+
cmd.append('-DPADDLE_WITH_CUDA')
512517

513518
return original_spawn(cmd)
514519

@@ -636,6 +641,8 @@ def _record_op_info(self):
636641

637642
for i, extension in enumerate(self.extensions):
638643
sources = [os.path.abspath(s) for s in extension.sources]
644+
if not self.contain_cuda_file:
645+
self.contain_cuda_file = any([is_cuda_file(s) for s in sources])
639646
op_names = parse_op_name_from(sources)
640647

641648
for op_name in op_names:

0 commit comments

Comments
 (0)