4343# pylint: enable=g-import-not-at-top
4444
4545
46+ def is_windows ():
47+ return sys .platform .startswith ("win32" )
48+
49+
4650def shell (cmd ):
4751 output = subprocess .check_output (cmd )
4852 return output .decode ("UTF-8" ).strip ()
@@ -52,7 +56,8 @@ def shell(cmd):
5256
5357def get_python_bin_path (python_bin_path_flag ):
5458 """Returns the path to the Python interpreter to use."""
55- return python_bin_path_flag or sys .executable
59+ path = python_bin_path_flag or sys .executable
60+ return path .replace (os .sep , "/" )
5661
5762
5863def get_python_version (python_bin_path ):
@@ -189,7 +194,8 @@ def check_bazel_version(bazel_path, min_version, max_version):
189194build --repo_env TF_NEED_CUDA="{tf_need_cuda}"
190195build --action_env TF_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"
191196build --distinct_host_configuration=false
192- build --copt=-Wno-sign-compare
197+ build:linux --copt=-Wno-sign-compare
198+ build:macos --copt=-Wno-sign-compare
193199build -c opt
194200build:opt --copt=-march=native
195201build:opt --host_copt=-march=native
@@ -205,9 +211,12 @@ def check_bazel_version(bazel_path, min_version, max_version):
205211build --define open_source_build=true
206212
207213# Disable enabled-by-default TensorFlow features that we don't care about.
208- build --define=no_aws_support=true
209- build --define=no_gcp_support=true
210- build --define=no_hdfs_support=true
214+ build:linux --define=no_aws_support=true
215+ build:macos --define=no_aws_support=true
216+ build:linux --define=no_gcp_support=true
217+ build:macos --define=no_gcp_support=true
218+ build:linux --define=no_hdfs_support=true
219+ build:macos --define=no_hdfs_support=true
211220build --define=no_kafka_support=true
212221build --define=no_ignite_support=true
213222build --define=grpc_no_ares=true
@@ -218,16 +227,49 @@ def check_bazel_version(bazel_path, min_version, max_version):
218227build --spawn_strategy=standalone
219228build --strategy=Genrule=standalone
220229
221- build --cxxopt=-std=c++14
222- build --host_cxxopt=-std=c++14
230+ build --enable_platform_specific_config
231+
232+ # Tensorflow uses M_* math constants that only get defined by MSVC headers if
233+ # _USE_MATH_DEFINES is defined.
234+ build:windows --copt=/D_USE_MATH_DEFINES
235+ build:windows --host_copt=/D_USE_MATH_DEFINES
236+
237+ # Make sure to include as little of windows.h as possible
238+ build:windows --copt=-DWIN32_LEAN_AND_MEAN
239+ build:windows --host_copt=-DWIN32_LEAN_AND_MEAN
240+ build:windows --copt=-DNOGDI
241+ build:windows --host_copt=-DNOGDI
242+
243+ # https://devblogs.microsoft.com/cppblog/announcing-full-support-for-a-c-c-conformant-preprocessor-in-msvc/
244+ # otherwise, there will be some compiling error due to preprocessing.
245+ build:windows --copt=/Zc:preprocessor
246+
247+ build:linux --cxxopt=-std=c++14
248+ build:linux --host_cxxopt=-std=c++14
249+
250+ build:macos --cxxopt=-std=c++14
251+ build:macos --host_cxxopt=-std=c++14
252+
253+ build:windows --cxxopt=/std:c++14
254+ build:windows --host_cxxopt=/std:c++14
255+
256+ # Generate PDB files, to generate useful PDBs, in opt compilation_mode
257+ # --copt /Z7 is needed.
258+ build:windows --linkopt=/DEBUG
259+ build:windows --host_linkopt=/DEBUG
260+ build:windows --linkopt=/OPT:REF
261+ build:windows --host_linkopt=/OPT:REF
262+ build:windows --linkopt=/OPT:ICF
263+ build:windows --host_linkopt=/OPT:ICF
223264
224265# Suppress all warning messages.
225266build:short_logs --output_filter=DONT_MATCH_ANYTHING
226267"""
227268
228269
229270
230- def write_bazelrc (cuda_toolkit_path = None , cudnn_install_path = None , ** kwargs ):
271+ def write_bazelrc (cuda_toolkit_path = None , cudnn_install_path = None ,
272+ cuda_version = None , cudnn_version = None , ** kwargs ):
231273 with open ("../.bazelrc" , "w" ) as f :
232274 f .write (BAZELRC_TEMPLATE .format (** kwargs ))
233275 if cuda_toolkit_path :
@@ -236,7 +278,12 @@ def write_bazelrc(cuda_toolkit_path=None, cudnn_install_path=None, **kwargs):
236278 if cudnn_install_path :
237279 f .write ("build --action_env CUDNN_INSTALL_PATH=\" {cudnn_install_path}\" \n "
238280 .format (cudnn_install_path = cudnn_install_path ))
239-
281+ if cuda_version :
282+ f .write ("build --action_env TF_CUDA_VERSION=\" {cuda_version}\" \n "
283+ .format (cuda_version = cuda_version ))
284+ if cudnn_version :
285+ f .write ("build --action_env TF_CUDNN_VERSION=\" {cudnn_version}\" \n "
286+ .format (cudnn_version = cudnn_version ))
240287
241288BANNER = r"""
242289 _ _ __ __
@@ -317,6 +364,14 @@ def main():
317364 "--cudnn_path" ,
318365 default = None ,
319366 help = "Path to CUDNN libraries." )
367+ parser .add_argument (
368+ "--cuda_version" ,
369+ default = None ,
370+ help = "CUDA toolkit version, e.g., 11.1" )
371+ parser .add_argument (
372+ "--cudnn_version" ,
373+ default = None ,
374+ help = "CUDNN version, e.g., 8" )
320375 parser .add_argument (
321376 "--cuda_compute_capabilities" ,
322377 default = "3.5,5.2,6.0,6.1,7.0" ,
@@ -331,6 +386,12 @@ def main():
331386 help = "Additional options to pass to bazel." )
332387 args = parser .parse_args ()
333388
389+ if is_windows () and args .enable_cuda :
390+ if args .cuda_version is None :
391+ parser .error ("--cuda_version is needed for Windows CUDA build." )
392+ if args .cudnn_version is None :
393+ parser .error ("--cudnn_version is needed for Windows CUDA build." )
394+
334395 print (BANNER )
335396 os .chdir (os .path .dirname (__file__ or args .prog ) or '.' )
336397
@@ -357,12 +418,18 @@ def main():
357418 if cudnn_install_path :
358419 print ("CUDNN library path: {}" .format (cudnn_install_path ))
359420 print ("CUDA compute capabilities: {}" .format (args .cuda_compute_capabilities ))
421+ if args .cuda_version :
422+ print ("CUDA version: {}" .format (args .cuda_version ))
423+ if args .cudnn_version :
424+ print ("CUDNN version: {}" .format (args .cudnn_version ))
360425 write_bazelrc (
361426 python_bin_path = python_bin_path ,
362427 tf_need_cuda = 1 if args .enable_cuda else 0 ,
363428 cuda_toolkit_path = cuda_toolkit_path ,
364429 cudnn_install_path = cudnn_install_path ,
365- cuda_compute_capabilities = args .cuda_compute_capabilities )
430+ cuda_compute_capabilities = args .cuda_compute_capabilities ,
431+ cuda_version = args .cuda_version ,
432+ cudnn_version = args .cudnn_version )
366433
367434 print ("\n Building XLA and installing it in the jaxlib source tree..." )
368435 config_args = args .bazel_options
0 commit comments