diff --git a/src/datasets/search.py b/src/datasets/search.py index f9b54e89b35..08b77df48ce 100644 --- a/src/datasets/search.py +++ b/src/datasets/search.py @@ -319,18 +319,11 @@ def save(self, file: Union[str, PurePath]): """Serialize the FaissIndex on disk""" import faiss # noqa: F811 - if ( - hasattr(self.faiss_index, "device") - and self.faiss_index.device is not None - and self.faiss_index.device > -1 - ) or ( - hasattr(self.faiss_index, "getDevice") - and self.faiss_index.getDevice() is not None - and self.faiss_index.getDevice() > -1 - ): + if self.device is not None and self.device > -1: index = faiss.index_gpu_to_cpu(self.faiss_index) else: index = self.faiss_index + faiss.write_index(index, str(file)) @classmethod