
jaxlib pip package 0.1.58 regression: ImportError: cannot import name 'cusolver' from 'jaxlib' #5374

@nfelt

Description

The most recent jaxlib CPU release (0.1.58), new to PyPI as of earlier today, makes import jax (with jax 0.2.7) fail in a fresh virtualenv on my machine with ImportError: cannot import name 'cusolver' from 'jaxlib'. Reverting to jaxlib 0.1.57 eliminates the error.
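The traceback in the repro shows jax/lib/__init__.py running an unconditional from jaxlib import cusolver, which fails when the installed jaxlib does not provide that module. Below is a minimal sketch of a guarded import that treats a missing GPU helper as "feature unavailable" instead of crashing. It assumes that a CPU-only jaxlib simply lacks the cusolver submodule. The helper name optional_submodule is hypothetical and is not necessarily how jax itself resolved this issue.

```python
import importlib
from types import ModuleType
from typing import Optional


def optional_submodule(package: str, name: str) -> Optional[ModuleType]:
    """Import `package.name`, returning None only if that module is absent.

    Hypothetical helper: a CPU-only jaxlib wheel may not ship GPU modules
    such as `cusolver`, so a missing module is reported as None.
    """
    try:
        return importlib.import_module(f"{package}.{name}")
    except ModuleNotFoundError as err:
        # Swallow the error only when the missing module is the one we asked
        # for (or its parent package); re-raise if some *other* dependency is
        # missing, so genuine breakage is not hidden.
        if err.name in (package, f"{package}.{name}"):
            return None
        raise


if __name__ == "__main__":
    # Hypothetical usage, mirroring the failing line in jax/lib/__init__.py.
    cusolver = optional_submodule("jaxlib", "cusolver")
    print("cusolver available:", cusolver is not None)
```

Catching only ModuleNotFoundError, and only for the requested name, is a deliberate choice: a cusolver module that exists but fails to load (for example, because of a missing shared library) still raises, rather than being silently treated as absent.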

Machine is macOS Catalina 10.15.7; the virtualenv uses Python 3.8.6.
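For reports like this one, the relevant versions can be collected programmatically rather than by hand. The following is a small sketch using only the standard library (importlib.metadata is available from Python 3.8). The function names and the chosen package list are assumptions made for illustration.

```python
import platform
from importlib import metadata  # stdlib since Python 3.8
from typing import Dict, Iterable, Optional


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def environment_report(
    dists: Iterable[str] = ("jax", "jaxlib", "numpy", "scipy"),
) -> Dict[str, Optional[str]]:
    """Collect interpreter, OS and package versions relevant to this issue."""
    report: Dict[str, Optional[str]] = {
        "python": platform.python_version(),
        "platform": platform.platform(),
        "machine": platform.machine(),
    }
    for dist in dists:
        report[dist] = installed_version(dist)
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

Recording the platform string is useful when cross-checking a report: for instance, the wheels pip selected in the log (manylinux2010_x86_64) carry Linux platform tags.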

Repro script:

~$ virtualenv jaxtest
created virtual environment CPython3.8.6.final.0-64 in 142ms
  creator CPython3Posix(dest=~/jaxtest, clear=False, global=False)
  seeder FromAppData(download=False, pip=bundle, setuptools=bundle, wheel=bundle, via=copy, app_data_dir=~/.local/share/virtualenv)
    added seed packages: pip==20.1.1, pkg_resources==0.0.0, setuptools==44.0.0, wheel==0.34.2
  activators BashActivator,CShellActivator,FishActivator,PowerShellActivator,PythonActivator,XonshActivator
~$ source jaxtest/bin/activate
(jaxtest) ~$ pip install --upgrade pip
Collecting pip
  Downloading pip-20.3.3-py2.py3-none-any.whl (1.5 MB)
     |████████████████████████████████| 1.5 MB 15.2 MB/s
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 20.1.1
    Uninstalling pip-20.1.1:
      Successfully uninstalled pip-20.1.1
Successfully installed pip-20.3.3
(jaxtest) ~$ pip install --upgrade jax jaxlib
Collecting jax
  Using cached jax-0.2.7-py3-none-any.whl
Collecting numpy>=1.12
  Downloading numpy-1.19.5-cp38-cp38-manylinux2010_x86_64.whl (14.9 MB)
     |████████████████████████████████| 14.9 MB 13.2 MB/s
Collecting jaxlib
  Using cached jaxlib-0.1.58-cp38-none-manylinux2010_x86_64.whl (34.1 MB)
Collecting absl-py
  Using cached absl_py-0.11.0-py3-none-any.whl (127 kB)
Collecting flatbuffers
  Using cached flatbuffers-1.12-py2.py3-none-any.whl (15 kB)
Collecting opt-einsum
  Using cached opt_einsum-3.3.0-py3-none-any.whl (65 kB)
Collecting scipy
  Using cached scipy-1.6.0-cp38-cp38-manylinux1_x86_64.whl (27.2 MB)
Collecting six
  Using cached six-1.15.0-py2.py3-none-any.whl (10 kB)
Installing collected packages: six, numpy, scipy, opt-einsum, flatbuffers, absl-py, jaxlib, jax
Successfully installed absl-py-0.11.0 flatbuffers-1.12 jax-0.2.7 jaxlib-0.1.58 numpy-1.19.5 opt-einsum-3.3.0 scipy-1.6.0 six-1.15.0
(jaxtest) ~$ python -c 'import jax'
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "~/jaxtest/lib/python3.8/site-packages/jax/__init__.py", line 22, in <module>
    from .api import (
  File "~/jaxtest/lib/python3.8/site-packages/jax/api.py", line 37, in <module>
    from . import core
  File "~/jaxtest/lib/python3.8/site-packages/jax/core.py", line 31, in <module>
    from . import dtypes
  File "~/jaxtest/lib/python3.8/site-packages/jax/dtypes.py", line 31, in <module>
    from .lib import xla_client
  File "~/jaxtest/lib/python3.8/site-packages/jax/lib/__init__.py", line 60, in <module>
    from jaxlib import cusolver
ImportError: cannot import name 'cusolver' from 'jaxlib' (~/jaxtest/lib/python3.8/site-packages/jaxlib/__init__.py)
(jaxtest) ~$ pip install jaxlib==0.1.57
Collecting jaxlib==0.1.57
  Using cached jaxlib-0.1.57-cp38-none-manylinux2010_x86_64.whl (33.3 MB)
Requirement already satisfied: flatbuffers in ./jaxtest/lib/python3.8/site-packages (from jaxlib==0.1.57) (1.12)
Requirement already satisfied: scipy in ./jaxtest/lib/python3.8/site-packages (from jaxlib==0.1.57) (1.6.0)
Requirement already satisfied: numpy>=1.12 in ./jaxtest/lib/python3.8/site-packages (from jaxlib==0.1.57) (1.19.5)
Requirement already satisfied: absl-py in ./jaxtest/lib/python3.8/site-packages (from jaxlib==0.1.57) (0.11.0)
Requirement already satisfied: six in ./jaxtest/lib/python3.8/site-packages (from absl-py->jaxlib==0.1.57) (1.15.0)
Installing collected packages: jaxlib
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.1.58
    Uninstalling jaxlib-0.1.58:
      Successfully uninstalled jaxlib-0.1.58
Successfully installed jaxlib-0.1.57
(jaxtest) ~$ python -c 'import jax'
(jaxtest) ~$
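The final check in the transcript (python -c 'import jax') can be scripted, for example to bisect jaxlib versions or to guard a CI job. This sketch runs the import in a fresh interpreter, as the transcript does, so that a failed import cannot leave half-initialized modules in the calling process. The helper name check_import is hypothetical.

```python
import subprocess
import sys
from typing import Tuple


def check_import(module: str) -> Tuple[bool, str]:
    """Run `import <module>` in a fresh interpreter.

    Returns (ok, message), where message is the last line of stderr
    (typically the exception line of a traceback) or "" on success.
    """
    proc = subprocess.run(
        [sys.executable, "-c", f"import {module}"],
        capture_output=True,
        text=True,
    )
    stderr_lines = proc.stderr.strip().splitlines()
    return proc.returncode == 0, (stderr_lines[-1] if stderr_lines else "")


if __name__ == "__main__":
    ok, message = check_import("jax")
    print("import jax OK" if ok else f"import jax failed: {message}")
```

With jaxlib 0.1.58 installed as above, the message would be expected to be the ImportError line from the traceback; after pinning jaxlib 0.1.57 the check should pass.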
