diff --git a/tests/torchtune/_cli/test_download.py b/tests/torchtune/_cli/test_download.py index 5dbd695226..8a6d6ba0ab 100644 --- a/tests/torchtune/_cli/test_download.py +++ b/tests/torchtune/_cli/test_download.py @@ -65,3 +65,44 @@ def test_download_calls_snapshot(self, capsys, monkeypatch, snapshot_download): # Make sure it was called twice assert snapshot_download.call_count == 3 + + # GatedRepoError without --hf-token (expect prompt for token) + def test_gated_repo_error_no_token(self, capsys, monkeypatch, snapshot_download): + model = "meta-llama/Llama-2-7b" + testargs = f"tune download {model}".split() + monkeypatch.setattr(sys, "argv", testargs) + + # Expect GatedRepoError without --hf-token provided + with pytest.raises(SystemExit, match="2"): + runpy.run_path(TUNE_PATH, run_name="__main__") + + out_err = capsys.readouterr() + # Check that error message prompts for --hf-token + assert ( + "It looks like you are trying to access a gated repository." in out_err.err + ) + assert ( + "Please ensure you have access to the repository and have provided the proper Hugging Face API token" + in out_err.err + ) + + # GatedRepoError with --hf-token (should not ask for token) + def test_gated_repo_error_with_token(self, capsys, monkeypatch, snapshot_download): + model = "meta-llama/Llama-2-7b" + testargs = f"tune download {model} --hf-token valid_token".split() + monkeypatch.setattr(sys, "argv", testargs) + + # Expect GatedRepoError with --hf-token provided + with pytest.raises(SystemExit, match="2"): + runpy.run_path(TUNE_PATH, run_name="__main__") + + out_err = capsys.readouterr() + # Check that error message does not prompt for --hf-token again + assert ( + "It looks like you are trying to access a gated repository." in out_err.err + ) + assert "Please ensure you have access to the repository." in out_err.err + assert ( + "Please ensure you have access to the repository and have provided the proper Hugging Face API token" + not in out_err.err + ) diff --git a/torchtune/_cli/download.py b/torchtune/_cli/download.py index b35b81cca0..82b4935c01 100644 --- a/torchtune/_cli/download.py +++ b/torchtune/_cli/download.py @@ -131,12 +131,18 @@ def _download_cmd(self, args: argparse.Namespace) -> None: token=args.hf_token, ) except GatedRepoError: - self._parser.error( - "It looks like you are trying to access a gated repository. Please ensure you " - "have access to the repository and have provided the proper Hugging Face API token " - "using the option `--hf-token` or by running `huggingface-cli login`." - "You can find your token by visiting https://huggingface.co/settings/tokens" - ) + if args.hf_token: + self._parser.error( + "It looks like you are trying to access a gated repository. Please ensure you " + "have access to the repository." + ) + else: + self._parser.error( + "It looks like you are trying to access a gated repository. Please ensure you " + "have access to the repository and have provided the proper Hugging Face API token " + "using the option `--hf-token` or by running `huggingface-cli login`." + "You can find your token by visiting https://huggingface.co/settings/tokens" + ) except RepositoryNotFoundError: self._parser.error( f"Repository '{args.repo_id}' not found on the Hugging Face Hub."