5858
5959 # install test packages
6060 conda install -y pytest
61- if [ "${{ inputs.gpu }}" = "ON" ]; then
61+ if [ "${{ inputs.rocm }}" = "ON" ]; then
62+ : # skip torch install via conda, we need to install via pip to get
63+ # ROCm-enabled version until it's supported in conda by PyTorch
64+ elif [ "${{ inputs.gpu }}" = "ON" ]; then
6265 conda install -y -q pytorch pytorch-cuda=12.4 -c pytorch -c nvidia/label/cuda-12.4.0
6366 else
6467 conda install -y -q pytorch -c pytorch
@@ -138,14 +141,19 @@ runs:
138141 working-directory : build/faiss/python
139142 run : |
140143 $CONDA/bin/python setup.py install
144+ - name : ROCm - install ROCm-enabled torch via pip
145+ if : inputs.rocm == 'ON'
146+ shell : bash
147+ run : |
148+ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1
141149 - name : Python tests (CPU only)
142150 if : inputs.gpu == 'OFF'
143151 shell : bash
144152 run : |
145153 pytest --junitxml=test-results/pytest/results.xml tests/test_*.py
146154 pytest --junitxml=test-results/pytest/results-torch.xml tests/torch_*.py
147155 - name : Python tests (CPU + GPU)
148- if : inputs.gpu == 'ON' && inputs.rocm == 'OFF'
156+ if : inputs.gpu == 'ON'
149157 shell : bash
150158 run : |
151159 pytest --junitxml=test-results/pytest/results.xml tests/test_*.py
@@ -160,7 +168,6 @@ runs:
160168 FAISS_DISABLE_CPU_FEATURES=AVX2 LD_DEBUG=libs $CONDA/bin/python -c "import faiss" 2>&1 | grep faiss.so
161169 LD_DEBUG=libs $CONDA/bin/python -c "import faiss" 2>&1 | grep faiss_avx2.so
162170 - name : Upload test results
163- if : inputs.rocm == 'OFF'
164171 uses : actions/upload-artifact@v4
165172 with :
166173 name : test-results-arch=${{ runner.arch }}-opt=${{ inputs.opt_level }}-gpu=${{ inputs.gpu }}-raft=${{ inputs.raft }}-rocm=${{ inputs.rocm }}
0 commit comments