Skip to content

Commit 65b9499

Browse files
Fix push_to_hub by not calling create_branch if branch exists (#7069)
* Fix push_to_hub by not calling create_branch if branch exists * Fix push_to_hub by not calling create_branch if branch exists * Reword comment * Fix push_to_hub by not calling create_branch if PR ref * Update test
1 parent 27ea8e8 commit 65b9499

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

src/datasets/arrow_dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5677,7 +5677,8 @@ def push_to_hub(
56775677
)
56785678
repo_id = repo_url.repo_id
56795679

5680-
if revision is not None:
5680+
if revision is not None and not revision.startswith("refs/pr/"):
5681+
# We do not call create_branch for a PR reference: 400 Bad Request
56815682
api.create_branch(repo_id, branch=revision, token=token, repo_type="dataset", exist_ok=True)
56825683

56835684
if not data_dir:

src/datasets/dataset_dict.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1708,7 +1708,8 @@ def push_to_hub(
17081708
)
17091709
repo_id = repo_url.repo_id
17101710

1711-
if revision is not None:
1711+
if revision is not None and not revision.startswith("refs/pr/"):
1712+
# We do not call create_branch for a PR reference: 400 Bad Request
17121713
api.create_branch(repo_id, branch=revision, token=token, repo_type="dataset", exist_ok=True)
17131714

17141715
if not data_dir:

tests/test_hub.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,8 @@ def test_convert_to_parquet(temporary_repo, hf_api, hf_token, ci_hub_config, ci_
6666
_ = convert_to_parquet(repo_id, token=hf_token, trust_remote_code=True)
6767
# mock_create_branch
6868
assert mock_create_branch.called
69-
assert mock_create_branch.call_count == 2
70-
for call_args, expected_branch in zip(mock_create_branch.call_args_list, ["refs/pr/1", "script"]):
71-
assert call_args.kwargs.get("branch") == expected_branch
69+
assert mock_create_branch.call_count == 1
70+
assert mock_create_branch.call_args.kwargs.get("branch") == "script"
7271
# mock_create_commit
7372
assert mock_create_commit.called
7473
assert mock_create_commit.call_count == 2

0 commit comments

Comments
 (0)