diff --git a/pygmt/src/which.py b/pygmt/src/which.py index db1eb00c667..9b28990271a 100644 --- a/pygmt/src/which.py +++ b/pygmt/src/which.py @@ -2,11 +2,18 @@ which - Find the full path to specified files. """ from pygmt.clib import Session -from pygmt.helpers import GMTTempFile, build_arg_string, fmt_docstring, use_alias +from pygmt.helpers import ( + GMTTempFile, + build_arg_string, + fmt_docstring, + kwargs_to_strings, + use_alias, +) @fmt_docstring @use_alias(G="download", V="verbose") +@kwargs_to_strings(fname="sequence_space") def which(fname, **kwargs): """ Find the full path to specified files. @@ -27,8 +34,8 @@ def which(fname, **kwargs): Parameters ---------- - fname : str - The file name that you want to check. + fname : str or list + One or more file names of any data type (grids, tables, etc.). download : bool or str If the file is downloadable and not found, we will try to download the it. Use True or 'l' (default) to download to the current directory. Use @@ -38,8 +45,8 @@ def which(fname, **kwargs): Returns ------- - path : str - The path of the file, depending on the options used. + path : str or list + The path(s) to the file(s), depending on the options used. Raises ------ @@ -52,5 +59,6 @@ def which(fname, **kwargs): lib.call_module("which", arg_str) path = tmpfile.read().strip() if not path: - raise FileNotFoundError("File '{}' not found.".format(fname)) - return path + _fname = fname.replace(" ", "', '") + raise FileNotFoundError(f"File(s) '{_fname}' not found.") + return path.split("\n") if "\n" in path else path diff --git a/pygmt/tests/test_which.py b/pygmt/tests/test_which.py index caac58fe36d..bb9ae086dfb 100644 --- a/pygmt/tests/test_which.py +++ b/pygmt/tests/test_which.py @@ -10,7 +10,7 @@ def test_which(): """ - Make sure which returns file paths for @files correctly without errors. + Make sure `which` returns file paths for @files correctly without errors. """ for fname in ["tut_quakes.ngdc", "tut_bathy.nc"]: cached_file = which(f"@{fname}", download="c") @@ -18,10 +18,23 @@ def test_which(): assert os.path.basename(cached_file) == fname +def test_which_multiple(): + """ + Make sure `which` returns file paths for multiple @files correctly. + """ + filenames = ["ridge.txt", "tut_ship.xyz"] + cached_files = which(fname=[f"@{fname}" for fname in filenames], download="c") + for cached_file in cached_files: + assert os.path.exists(cached_file) + assert os.path.basename(cached_file) in filenames + + def test_which_fails(): """ - which should fail with a FileNotFoundError. + Make sure `which` will fail with a FileNotFoundError. """ bogus_file = unique_name() with pytest.raises(FileNotFoundError): which(bogus_file) + with pytest.raises(FileNotFoundError): + which(fname=[f"{bogus_file}.nc", f"{bogus_file}.txt"])