diff --git a/easybuild/easyconfigs/j/jax/jax-0.3.23-foss-2021b-CUDA-11.4.1.eb b/easybuild/easyconfigs/j/jax/jax-0.3.23-foss-2021b-CUDA-11.4.1.eb new file mode 100644 index 00000000000..cc8ee4c2d2e --- /dev/null +++ b/easybuild/easyconfigs/j/jax/jax-0.3.23-foss-2021b-CUDA-11.4.1.eb @@ -0,0 +1,123 @@ +# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild +# Author: Denis Kristak +# Updated by: Alex Domingo (Vrije Universiteit Brussel) +# Downgrade to foss/2021b: T. Hoffmann (EMBL) +easyblock = 'PythonBundle' + +name = 'jax' +version = '0.3.23' +versionsuffix = '-CUDA-%(cudaver)s' + +homepage = 'https://pypi.python.org/pypi/jax' +description = """Composable transformations of Python+NumPy programs: +differentiate, vectorize, JIT to GPU/TPU, and more""" + +toolchain = {'name': 'foss', 'version': '2021b'} + +builddependencies = [ + ('Bazel', '3.7.2'), + ('pytest-xdist', '2.5.0'), + # git 2.x required to fetch repository 'io_bazel_rules_docker' + ('git', '2.33.1', '-nodocs'), + ('matplotlib', '3.5.2'), # required by some jax tests; also loads Pillow/9.1.1 +] + +dependencies = [ + ('CUDA', '11.4.1', '', SYSTEM), + ('cuDNN', '8.2.2.26', versionsuffix, SYSTEM), + ('NCCL', '2.10.3', versionsuffix), + ('Python', '3.9.6'), + ('SciPy-bundle', '2021.10'), + ('flatbuffers-python', '2.0'), +] + +# downloading TensorFlow tarball to avoid that Bazel downloads it during the build +# note: this *must* be the exact same commit as used in WORKSPACE +local_tf_commit = 'cb946f223b9b3fa04efdbb7a0e6a9dabb22a7057' +local_tf_dir = 'tensorflow-%s' % local_tf_commit +local_tf_builddir = '%(builddir)s/' + local_tf_dir + +# replace remote TensorFlow repository with the local one from EB +local_jax_prebuildopts = "sed -i -f jaxlib_local-tensorflow-repo.sed WORKSPACE && " +local_jax_prebuildopts += "sed -i 's|EB_TF_REPOPATH|%s|' WORKSPACE && " % local_tf_builddir + +use_pip = True + +default_easyblock = 'PythonPackage' +default_component_specs = { + 'sources': [SOURCE_TAR_GZ], + 'source_urls': [PYPI_SOURCE], + 'start_dir': '%(name)s-%(version)s', + 'use_pip': True, + 'sanity_pip_check': True, + 'download_dep_fail': True, +} + +components = [ + ('absl-py', '1.3.0', { + 'options': {'modulename': 'absl'}, + 'checksums': ['463c38a08d2e4cef6c498b76ba5bd4858e4c6ef51da1a5a1f27139a022e20248'], + }), + ('jaxlib', '0.3.22', { + 'sources': [ + '%(name)s-v%(version)s.tar.gz', + { + 'download_filename': '%s.tar.gz' % local_tf_commit, + 'filename': 'tensorflow-%s.tar.gz' % local_tf_commit, + } + ], + 'source_urls': [ + 'https://github.com/google/jax/archive/', + 'https://github.com/tensorflow/tensorflow/archive/' + ], + 'patches': [ + ('jaxlib_local-tensorflow-repo.sed', '.'), + ('TensorFlow-2.7.0_cuda-noncanonical-include-paths.patch', '../' + local_tf_dir), + ], + 'checksums': [ + # jaxlib-v0.3.22.tar.gz + '680a6f5265ba26d5515617a95ae47244005366f879a5c321782fde60f34e6d0d', + # tensorflow-cb946f223b9b3fa04efdbb7a0e6a9dabb22a7057.tar.gz + '9a7a7a87356bdeef5874fae135de380466482b593469035be3609a9cd2c153c4', + # jaxlib_local-tensorflow-repo.sed + 'abb5c3b97f4e317bce9f22ed3eeea3b9715365818d8b50720d937e2d41d5c4e5', + # TensorFlow-2.7.0_cuda-noncanonical-include-paths.patch + '0a759010c253d49755955cd5f028e75de4a4c447dcc8f5a0d9f47cce6881a9db', + ], + 'start_dir': 'jax-jaxlib-v%(version)s', + 'prebuildopts': local_jax_prebuildopts, + }), +] + +exts_list = [ + ('opt_einsum', '3.3.0', { + 'checksums': ['59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549'], + }), + ('etils', '0.8.0', { + 'checksums': ['d1d5af7bd9c784a273c4e1eccfaa8feaca5e0481a08717b5313fa231da22a903'], + }), + (name, version, { + 'patches': [ + 'jax-0.3.9_relax-test-tolerance.patch', + 'jax-0.3.23_correctly-skip-from_dlpack-tests.patch', + 'jax-0.3.23_relax-testPoly5-tolerance.patch', + ], + 'runtest': "NVIDIA_TF32_OVERRIDE=0 CUDA_VISIBLE_DEVICES=0 XLA_PYTHON_CLIENT_ALLOCATOR=platform " + + "JAX_ENABLE_X64=true pytest -vv tests", + 'source_tmpl': '%(name)s-v%(version)s.tar.gz', + 'source_urls': ['https://github.com/google/jax/archive/'], + 'checksums': [ + {'jax-v0.3.23.tar.gz': 'fa8c68a82fa2fcf3d272bf239c77e7028bb6077466a53349ce85f6e85ed623db'}, + {'jax-0.3.9_relax-test-tolerance.patch': + '3da3c8b4d9ff3449b51a4f39d6bbadd348ea3bd4ca493a6f1292743f86fa7b3d'}, + {'jax-0.3.23_correctly-skip-from_dlpack-tests.patch': + 'a69ce7280ca8bb42e671217f00d9147f8c64b4b7ba65dea7f05f2c6de757b279'}, + {'jax-0.3.23_relax-testPoly5-tolerance.patch': + 'be64bf36dde4884a97b6c8bb22c6b14ab5b24033cd40bfe7ce18363c55c30e87'}, + ], + }), +] + +sanity_pip_check = True + +moduleclass = 'tools'