Skip to content

Commit 5ede7b7

Browse files
committed
1. Build on Windows
2. Fix OverflowError When calling `key = random.PRNGKey(0)` OverflowError: Python int too large to convert to C long for casting value 4294967295 (0xFFFFFFFF) from python int to int32. 3. fix file path in regex of errors_test 4. handle ValueError of os.path.commonpath
1 parent c2b01f1 commit 5ede7b7

File tree

11 files changed

+250
-106
lines changed

11 files changed

+250
-106
lines changed

.bazelversion

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
3.1.0

WORKSPACE

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ http_archive(
2828
# and update the sha256 with the result.
2929
http_archive(
3030
name = "org_tensorflow",
31-
sha256 = "579a74ad171d8da7b7193ff863f28482c2e6050c4090650b001fb80bbc46bb0f",
32-
strip_prefix = "tensorflow-04f25b55e27be95ec340f414c2a1cabe16be5c2a",
31+
sha256 = "f59fc70b349373267a928f37d3856984781214c07ca53ef075c48389778d17cf",
32+
strip_prefix = "tensorflow-b7060f1c189a3b2d97c8eca13d5a13acc5553403",
3333
urls = [
34-
"https://github.com/tensorflow/tensorflow/archive/04f25b55e27be95ec340f414c2a1cabe16be5c2a.tar.gz",
34+
"https://github.com/tensorflow/tensorflow/archive/b7060f1c189a3b2d97c8eca13d5a13acc5553403.tar.gz",
3535
],
3636
)
3737

build/BUILD.bazel

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,27 @@
1515
# JAX is Autograd and XLA
1616

1717
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
18+
load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_not_windows")
1819

1920
licenses(["notice"]) # Apache 2
2021

2122
package(default_visibility = ["//visibility:public"])
2223

23-
sh_binary(
24+
py_binary(
2425
name = "install_xla_in_source_tree",
25-
srcs = ["install_xla_in_source_tree.sh"],
26+
srcs = ["install_xla_in_source_tree.py"],
2627
data = [
2728
"@org_tensorflow//tensorflow/compiler/xla/python:xla_client",
28-
"@org_tensorflow//tensorflow/compiler/xla/python/tpu_driver/client:py_tpu_client",
2929
"//jaxlib",
3030
"//jaxlib:lapack.so",
3131
"//jaxlib:_pocketfft.so",
3232
"//jaxlib:pocketfft_flatbuffers_py",
33-
] + if_cuda([
33+
] + if_not_windows([
34+
"@org_tensorflow//tensorflow/compiler/xla/python/tpu_driver/client:py_tpu_client",
35+
]) + if_cuda([
3436
"//jaxlib:cublas_kernels",
3537
"//jaxlib:cusolver_kernels",
3638
"//jaxlib:cuda_prng_kernels",
3739
]),
38-
deps = ["@bazel_tools//tools/bash/runfiles"],
40+
deps = ["@bazel_tools//tools/python/runfiles"],
3941
)

build/build.py

Lines changed: 77 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@
4343
# pylint: enable=g-import-not-at-top
4444

4545

46+
def is_windows():
47+
return sys.platform.startswith("win32")
48+
49+
4650
def shell(cmd):
4751
output = subprocess.check_output(cmd)
4852
return output.decode("UTF-8").strip()
@@ -52,7 +56,8 @@ def shell(cmd):
5256

5357
def get_python_bin_path(python_bin_path_flag):
5458
"""Returns the path to the Python interpreter to use."""
55-
return python_bin_path_flag or sys.executable
59+
path = python_bin_path_flag or sys.executable
60+
return path.replace(os.sep, "/")
5661

5762

5863
def get_python_version(python_bin_path):
@@ -189,7 +194,8 @@ def check_bazel_version(bazel_path, min_version, max_version):
189194
build --repo_env TF_NEED_CUDA="{tf_need_cuda}"
190195
build --action_env TF_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"
191196
build --distinct_host_configuration=false
192-
build --copt=-Wno-sign-compare
197+
build:linux --copt=-Wno-sign-compare
198+
build:macos --copt=-Wno-sign-compare
193199
build -c opt
194200
build:opt --copt=-march=native
195201
build:opt --host_copt=-march=native
@@ -205,9 +211,12 @@ def check_bazel_version(bazel_path, min_version, max_version):
205211
build --define open_source_build=true
206212
207213
# Disable enabled-by-default TensorFlow features that we don't care about.
208-
build --define=no_aws_support=true
209-
build --define=no_gcp_support=true
210-
build --define=no_hdfs_support=true
214+
build:linux --define=no_aws_support=true
215+
build:macos --define=no_aws_support=true
216+
build:linux --define=no_gcp_support=true
217+
build:macos --define=no_gcp_support=true
218+
build:linux --define=no_hdfs_support=true
219+
build:macos --define=no_hdfs_support=true
211220
build --define=no_kafka_support=true
212221
build --define=no_ignite_support=true
213222
build --define=grpc_no_ares=true
@@ -218,16 +227,49 @@ def check_bazel_version(bazel_path, min_version, max_version):
218227
build --spawn_strategy=standalone
219228
build --strategy=Genrule=standalone
220229
221-
build --cxxopt=-std=c++14
222-
build --host_cxxopt=-std=c++14
230+
build --enable_platform_specific_config
231+
232+
# Tensorflow uses M_* math constants that only get defined by MSVC headers if
233+
# _USE_MATH_DEFINES is defined.
234+
build:windows --copt=/D_USE_MATH_DEFINES
235+
build:windows --host_copt=/D_USE_MATH_DEFINES
236+
237+
# Make sure to include as little of windows.h as possible
238+
build:windows --copt=-DWIN32_LEAN_AND_MEAN
239+
build:windows --host_copt=-DWIN32_LEAN_AND_MEAN
240+
build:windows --copt=-DNOGDI
241+
build:windows --host_copt=-DNOGDI
242+
243+
# https://devblogs.microsoft.com/cppblog/announcing-full-support-for-a-c-c-conformant-preprocessor-in-msvc/
244+
# otherwise, there will be some compiling error due to preprocessing.
245+
build:windows --copt=/Zc:preprocessor
246+
247+
build:linux --cxxopt=-std=c++14
248+
build:linux --host_cxxopt=-std=c++14
249+
250+
build:macos --cxxopt=-std=c++14
251+
build:macos --host_cxxopt=-std=c++14
252+
253+
build:windows --cxxopt=/std:c++14
254+
build:windows --host_cxxopt=/std:c++14
255+
256+
# Generate PDB files, to generate useful PDBs, in opt compilation_mode
257+
# --copt /Z7 is needed.
258+
build:windows --linkopt=/DEBUG
259+
build:windows --host_linkopt=/DEBUG
260+
build:windows --linkopt=/OPT:REF
261+
build:windows --host_linkopt=/OPT:REF
262+
build:windows --linkopt=/OPT:ICF
263+
build:windows --host_linkopt=/OPT:ICF
223264
224265
# Suppress all warning messages.
225266
build:short_logs --output_filter=DONT_MATCH_ANYTHING
226267
"""
227268

228269

229270

230-
def write_bazelrc(cuda_toolkit_path=None, cudnn_install_path=None, **kwargs):
271+
def write_bazelrc(cuda_toolkit_path=None, cudnn_install_path=None,
272+
cuda_version=None, cudnn_version=None, **kwargs):
231273
with open("../.bazelrc", "w") as f:
232274
f.write(BAZELRC_TEMPLATE.format(**kwargs))
233275
if cuda_toolkit_path:
@@ -236,7 +278,12 @@ def write_bazelrc(cuda_toolkit_path=None, cudnn_install_path=None, **kwargs):
236278
if cudnn_install_path:
237279
f.write("build --action_env CUDNN_INSTALL_PATH=\"{cudnn_install_path}\"\n"
238280
.format(cudnn_install_path=cudnn_install_path))
239-
281+
if cuda_version:
282+
f.write("build --action_env TF_CUDA_VERSION=\"{cuda_version}\"\n"
283+
.format(cuda_version=cuda_version))
284+
if cudnn_version:
285+
f.write("build --action_env TF_CUDNN_VERSION=\"{cudnn_version}\"\n"
286+
.format(cudnn_version=cudnn_version))
240287

241288
BANNER = r"""
242289
_ _ __ __
@@ -317,6 +364,14 @@ def main():
317364
"--cudnn_path",
318365
default=None,
319366
help="Path to CUDNN libraries.")
367+
parser.add_argument(
368+
"--cuda_version",
369+
default=None,
370+
help="CUDA toolkit version, e.g., 11.1")
371+
parser.add_argument(
372+
"--cudnn_version",
373+
default=None,
374+
help="CUDNN version, e.g., 8")
320375
parser.add_argument(
321376
"--cuda_compute_capabilities",
322377
default="3.5,5.2,6.0,6.1,7.0",
@@ -331,6 +386,12 @@ def main():
331386
help="Additional options to pass to bazel.")
332387
args = parser.parse_args()
333388

389+
if is_windows() and args.enable_cuda:
390+
if args.cuda_version is None:
391+
parser.error("--cuda_version is needed for Windows CUDA build.")
392+
if args.cudnn_version is None:
393+
parser.error("--cudnn_version is needed for Windows CUDA build.")
394+
334395
print(BANNER)
335396
os.chdir(os.path.dirname(__file__ or args.prog) or '.')
336397

@@ -357,12 +418,18 @@ def main():
357418
if cudnn_install_path:
358419
print("CUDNN library path: {}".format(cudnn_install_path))
359420
print("CUDA compute capabilities: {}".format(args.cuda_compute_capabilities))
421+
if args.cuda_version:
422+
print("CUDA version: {}".format(args.cuda_version))
423+
if args.cudnn_version:
424+
print("CUDNN version: {}".format(args.cudnn_version))
360425
write_bazelrc(
361426
python_bin_path=python_bin_path,
362427
tf_need_cuda=1 if args.enable_cuda else 0,
363428
cuda_toolkit_path=cuda_toolkit_path,
364429
cudnn_install_path=cudnn_install_path,
365-
cuda_compute_capabilities=args.cuda_compute_capabilities)
430+
cuda_compute_capabilities=args.cuda_compute_capabilities,
431+
cuda_version=args.cuda_version,
432+
cudnn_version=args.cudnn_version)
366433

367434
print("\nBuilding XLA and installing it in the jaxlib source tree...")
368435
config_args = args.bazel_options
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import os
17+
import sys
18+
import shutil
19+
import argparse
20+
21+
from bazel_tools.tools.python.runfiles import runfiles
22+
23+
parser = argparse.ArgumentParser()
24+
parser.add_argument("target")
25+
args = parser.parse_args()
26+
27+
r = runfiles.Create()
28+
29+
30+
def _is_windows():
31+
return sys.platform.startswith("win32")
32+
33+
34+
def _copy_so(src_file, dst_dir):
35+
src_filename = os.path.basename(src_file)
36+
if _is_windows() and src_filename.endswith(".so"):
37+
dst_filename = src_filename[:-3] + ".pyd"
38+
else:
39+
dst_filename = src_filename
40+
dst_file = os.path.join(dst_dir, dst_filename)
41+
shutil.copyfile(src_file, dst_file)
42+
43+
44+
def _copy_normal(src_file, dst_dir):
45+
src_filename = os.path.basename(src_file)
46+
dst_file = os.path.join(dst_dir, src_filename)
47+
shutil.copyfile(src_file, dst_file)
48+
49+
50+
def copy(src_file, dst_dir=os.path.join(args.target, "jaxlib")):
51+
if src_file.endswith(".so"):
52+
_copy_so(src_file, dst_dir)
53+
else:
54+
_copy_normal(src_file, dst_dir)
55+
56+
57+
def patch_copy_xla_client_py(dst_dir=os.path.join(args.target, "jaxlib")):
58+
with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_client.py")) as f:
59+
src = f.read()
60+
src = src.replace("from tensorflow.compiler.xla.python import xla_extension as _xla",
61+
"from . import xla_extension as _xla")
62+
src = src.replace("from tensorflow.compiler.xla.python.xla_extension import ops",
63+
"from .xla_extension import ops")
64+
with open(os.path.join(dst_dir, "xla_client.py"), "w") as f:
65+
f.write(src)
66+
67+
68+
def patch_copy_tpu_client_py(dst_dir=os.path.join(args.target, "jaxlib")):
69+
with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py")) as f:
70+
src = f.read()
71+
src = src.replace("from tensorflow.compiler.xla.python import xla_extension as _xla",
72+
"from . import xla_extension as _xla")
73+
src = src.replace("from tensorflow.compiler.xla.python import xla_client",
74+
"from . import xla_client")
75+
src = src.replace(
76+
"from tensorflow.compiler.xla.python.tpu_driver.client import tpu_client_extension as _tpu_client",
77+
"from . import tpu_client_extension as _tpu_client")
78+
with open(os.path.join(dst_dir, "tpu_client.py"), "w") as f:
79+
f.write(src)
80+
81+
82+
copy(r.Rlocation("__main__/jaxlib/lapack.so"))
83+
copy(r.Rlocation("__main__/jaxlib/_pocketfft.so"))
84+
copy(r.Rlocation("__main__/jaxlib/pocketfft_flatbuffers_py_generated.py"))
85+
copy(r.Rlocation("__main__/jaxlib/pocketfft.py"))
86+
if r.Rlocation("__main__/jaxlib/cusolver_kernels.so") is not None:
87+
copy(r.Rlocation("__main__/jaxlib/cusolver_kernels.so"))
88+
copy(r.Rlocation("__main__/jaxlib/cublas_kernels.so"))
89+
copy(r.Rlocation("__main__/jaxlib/cusolver_kernels.so"))
90+
copy(r.Rlocation("__main__/jaxlib/cuda_prng_kernels.so"))
91+
if r.Rlocation("__main__/jaxlib/cusolver_kernels.pyd") is not None:
92+
copy(r.Rlocation("__main__/jaxlib/cusolver_kernels.pyd"))
93+
copy(r.Rlocation("__main__/jaxlib/cublas_kernels.pyd"))
94+
copy(r.Rlocation("__main__/jaxlib/cusolver_kernels.pyd"))
95+
copy(r.Rlocation("__main__/jaxlib/cuda_prng_kernels.pyd"))
96+
copy(r.Rlocation("__main__/jaxlib/version.py"))
97+
copy(r.Rlocation("__main__/jaxlib/cusolver.py"))
98+
copy(r.Rlocation("__main__/jaxlib/cuda_prng.py"))
99+
100+
if _is_windows():
101+
copy(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.pyd"))
102+
else:
103+
copy(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so"))
104+
patch_copy_xla_client_py()
105+
106+
if not _is_windows():
107+
copy(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so"))
108+
patch_copy_tpu_client_py()

0 commit comments

Comments
 (0)