Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4320,17 +4320,14 @@ def _push_parquet_shards_to_hub(
f"The identifier should be in the format <repo_id> or <namespace>/<repo_id>. It is {identifier}, "
"which doesn't conform to either format."
)
elif len(identifier) == 2:
organization_or_username, dataset_name = identifier
elif len(identifier) == 1:
dataset_name = identifier[0]
organization_or_username = api.whoami(token)["name"]
repo_id = f"{organization_or_username}/{dataset_name}"

create_repo(
hf_api=api,
name=dataset_name,
organization=organization_or_username,
api,
repo_id,
token=token,
repo_type="dataset",
private=private,
Expand Down
20 changes: 8 additions & 12 deletions src/datasets/utils/_hf_hub_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@

def create_repo(
hf_api: HfApi,
name: str,
repo_id: str,
token: Optional[str] = None,
organization: Optional[str] = None,
private: Optional[bool] = None,
repo_type: Optional[str] = None,
exist_ok: Optional[bool] = False,
Expand All @@ -22,10 +21,8 @@ def create_repo(

Args:
hf_api (`huggingface_hub.HfApi`): Hub client
name (`str`): name of the repository (without the namespace)
repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`.
token (`str`, *optional*): user or organization token. Defaults to None.
organization (`str`, *optional*): namespace for the repository: the username or organization name.
By default it uses the namespace associated to the token used.
private (`bool`, *optional*):
Whether the model repo should be private.
repo_type (`str`, *optional*):
Expand All @@ -42,6 +39,7 @@ def create_repo(
`str`: URL to the newly created repo.
"""
if version.parse(huggingface_hub.__version__) < version.parse("0.5.0"):
organization, name = repo_id.split("/")
return hf_api.create_repo(
name=name,
organization=organization,
Expand All @@ -53,7 +51,7 @@ def create_repo(
)
else: # the `organization` parameter is deprecated in huggingface_hub>=0.5.0
return hf_api.create_repo(
repo_id=f"{organization}/{name}",
repo_id=repo_id,
token=token,
private=private,
repo_type=repo_type,
Expand All @@ -64,9 +62,8 @@ def create_repo(

def delete_repo(
hf_api: HfApi,
name: str,
repo_id: str,
token: Optional[str] = None,
organization: Optional[str] = None,
repo_type: Optional[str] = None,
) -> str:
"""
Expand All @@ -75,10 +72,8 @@ def delete_repo(

Args:
hf_api (`huggingface_hub.HfApi`): Hub client
name (`str`): name of the repository (without the namespace)
repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`.
token (`str`, *optional*): user or organization token. Defaults to None.
organization (`str`, *optional*): namespace for the repository: the username or organization name.
By default it uses the namespace associated to the token used.
repo_type (`str`, *optional*):
Set to `"dataset"` or `"space"` if uploading to a dataset or
space, `None` or `"model"` if uploading to a model. Default is
Expand All @@ -88,6 +83,7 @@ def delete_repo(
`str`: URL to the newly created repo.
"""
if version.parse(huggingface_hub.__version__) < version.parse("0.5.0"):
organization, name = repo_id.split("/")
return hf_api.delete_repo(
name=name,
organization=organization,
Expand All @@ -96,7 +92,7 @@ def delete_repo(
)
else: # the `organization` parameter is deprecated in huggingface_hub>=0.5.0
return hf_api.delete_repo(
repo_id=f"{organization}/{name}",
repo_id=repo_id,
token=token,
repo_type=repo_type,
)
Expand Down
15 changes: 7 additions & 8 deletions tests/fixtures/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ def hf_token(hf_api: HfApi):
@pytest.fixture
def cleanup_repo(hf_api):
def _cleanup_repo(repo_id):
organization, name = repo_id.split("/")
delete_repo(hf_api=hf_api, name=name, organization=organization, token=CI_HUB_USER_TOKEN, repo_type="dataset")
delete_repo(hf_api, repo_id, token=CI_HUB_USER_TOKEN, repo_type="dataset")

return _cleanup_repo

Expand All @@ -81,8 +80,8 @@ def _temporary_repo(repo_id):
@pytest.fixture(scope="session")
def hf_private_dataset_repo_txt_data_(hf_api: HfApi, hf_token, text_file):
repo_name = f"repo_txt_data-{int(time.time() * 10e3)}"
create_repo(hf_api, repo_name, token=hf_token, organization=CI_HUB_USER, repo_type="dataset", private=True)
repo_id = f"{CI_HUB_USER}/{repo_name}"
create_repo(hf_api, repo_id, token=hf_token, repo_type="dataset", private=True)
hf_api.upload_file(
token=hf_token,
path_or_fileobj=str(text_file),
Expand All @@ -92,7 +91,7 @@ def hf_private_dataset_repo_txt_data_(hf_api: HfApi, hf_token, text_file):
)
yield repo_id
try:
delete_repo(hf_api, repo_name, token=hf_token, organization=CI_HUB_USER, repo_type="dataset")
delete_repo(hf_api, repo_id, token=hf_token, repo_type="dataset")
except (requests.exceptions.HTTPError, ValueError): # catch http error and token invalid error
pass

Expand All @@ -107,8 +106,8 @@ def hf_private_dataset_repo_txt_data(hf_private_dataset_repo_txt_data_):
@pytest.fixture(scope="session")
def hf_private_dataset_repo_zipped_txt_data_(hf_api: HfApi, hf_token, zip_csv_with_dir_path):
repo_name = f"repo_zipped_txt_data-{int(time.time() * 10e3)}"
create_repo(hf_api, repo_name, token=hf_token, organization=CI_HUB_USER, repo_type="dataset", private=True)
repo_id = f"{CI_HUB_USER}/{repo_name}"
create_repo(hf_api, repo_id, token=hf_token, repo_type="dataset", private=True)
hf_api.upload_file(
token=hf_token,
path_or_fileobj=str(zip_csv_with_dir_path),
Expand All @@ -118,7 +117,7 @@ def hf_private_dataset_repo_zipped_txt_data_(hf_api: HfApi, hf_token, zip_csv_wi
)
yield repo_id
try:
delete_repo(hf_api, repo_name, token=hf_token, organization=CI_HUB_USER, repo_type="dataset")
delete_repo(hf_api, repo_id, token=hf_token, repo_type="dataset")
except (requests.exceptions.HTTPError, ValueError): # catch http error and token invalid error
pass

Expand All @@ -133,8 +132,8 @@ def hf_private_dataset_repo_zipped_txt_data(hf_private_dataset_repo_zipped_txt_d
@pytest.fixture(scope="session")
def hf_private_dataset_repo_zipped_img_data_(hf_api: HfApi, hf_token, zip_image_path):
repo_name = f"repo_zipped_img_data-{int(time.time() * 10e3)}"
create_repo(hf_api, repo_name, token=hf_token, organization=CI_HUB_USER, repo_type="dataset", private=True)
repo_id = f"{CI_HUB_USER}/{repo_name}"
create_repo(hf_api, repo_id, token=hf_token, repo_type="dataset", private=True)
hf_api.upload_file(
token=hf_token,
path_or_fileobj=str(zip_image_path),
Expand All @@ -144,7 +143,7 @@ def hf_private_dataset_repo_zipped_img_data_(hf_api: HfApi, hf_token, zip_image_
)
yield repo_id
try:
delete_repo(hf_api, repo_name, token=hf_token, organization=CI_HUB_USER, repo_type="dataset")
delete_repo(hf_api, repo_id, token=hf_token, repo_type="dataset")
except (requests.exceptions.HTTPError, ValueError): # catch http error and token invalid error
pass

Expand Down