Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion storage/google/cloud/storage/acl.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ class ACL(object):
# as properties).
reload_path = None
save_path = None
user_project = None

def __init__(self):
self.entities = {}
Expand Down Expand Up @@ -405,10 +406,18 @@ def reload(self, client=None):
"""
path = self.reload_path
client = self._require_client(client)
query_params = {}

if self.user_project is not None:
query_params['userProject'] = self.user_project

self.entities.clear()

found = client._connection.api_request(method='GET', path=path)
found = client._connection.api_request(
method='GET',
path=path,
query_params=query_params,
)
self.loaded = True
for entry in found.get('items', ()):
self.add_entity(self.entity_from_dict(entry))
Expand All @@ -435,8 +444,12 @@ def _save(self, acl, predefined, client):
acl = []
query_params[self._PREDEFINED_QUERY_PARAM] = predefined

if self.user_project is not None:
query_params['userProject'] = self.user_project

path = self.save_path
client = self._require_client(client)

result = client._connection.api_request(
method='PATCH',
path=path,
Expand Down Expand Up @@ -532,6 +545,11 @@ def save_path(self):
"""Compute the path for PATCH API requests for this ACL."""
return self.bucket.path

@property
def user_project(self):
"""Compute the user project charged for API requests for this ACL."""
return self.bucket.user_project


class DefaultObjectACL(BucketACL):
"""A class representing the default object ACL for a bucket."""
Expand Down Expand Up @@ -565,3 +583,8 @@ def reload_path(self):
def save_path(self):
"""Compute the path for PATCH API requests for this ACL."""
return self.blob.path

@property
def user_project(self):
"""Compute the user project charged for API requests for this ACL."""
return self.blob.user_project
135 changes: 100 additions & 35 deletions storage/tests/unit/test_acl.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,11 @@ def test_reload_missing(self):
self.assertEqual(list(acl), [])
kw = connection._requested
self.assertEqual(len(kw), 1)
self.assertEqual(kw[0]['method'], 'GET')
self.assertEqual(kw[0]['path'], '/testing/acl')
self.assertEqual(kw[0], {
'method': 'GET',
'path': '/testing/acl',
'query_params': {},
})

def test_reload_empty_result_clears_local(self):
ROLE = 'role'
Expand All @@ -543,29 +546,41 @@ def test_reload_empty_result_clears_local(self):
acl.reload_path = '/testing/acl'
acl.loaded = True
acl.entity('allUsers', ROLE)

acl.reload(client=client)

self.assertTrue(acl.loaded)
self.assertEqual(list(acl), [])
kw = connection._requested
self.assertEqual(len(kw), 1)
self.assertEqual(kw[0]['method'], 'GET')
self.assertEqual(kw[0]['path'], '/testing/acl')
self.assertEqual(kw[0], {
'method': 'GET',
'path': '/testing/acl',
'query_params': {},
})

def test_reload_nonempty_result(self):
def test_reload_nonempty_result_w_user_project(self):
ROLE = 'role'
USER_PROJECT = 'user-project-123'
connection = _Connection(
{'items': [{'entity': 'allUsers', 'role': ROLE}]})
client = _Client(connection)
acl = self._make_one()
acl.reload_path = '/testing/acl'
acl.loaded = True
acl.user_project = USER_PROJECT

acl.reload(client=client)

self.assertTrue(acl.loaded)
self.assertEqual(list(acl), [{'entity': 'allUsers', 'role': ROLE}])
kw = connection._requested
self.assertEqual(len(kw), 1)
self.assertEqual(kw[0]['method'], 'GET')
self.assertEqual(kw[0]['path'], '/testing/acl')
self.assertEqual(kw[0], {
'method': 'GET',
'path': '/testing/acl',
'query_params': {'userProject': USER_PROJECT},
})

def test_save_none_set_none_passed(self):
connection = _Connection()
Expand Down Expand Up @@ -606,30 +621,43 @@ def test_save_no_acl(self):
self.assertEqual(len(kw), 1)
self.assertEqual(kw[0]['method'], 'PATCH')
self.assertEqual(kw[0]['path'], '/testing')
self.assertEqual(kw[0]['data'], {'acl': AFTER})
self.assertEqual(kw[0]['query_params'], {'projection': 'full'})

def test_save_w_acl(self):
self.assertEqual(kw[0], {
'method': 'PATCH',
'path': '/testing',
'query_params': {'projection': 'full'},
'data': {'acl': AFTER},
})

def test_save_w_acl_w_user_project(self):
ROLE1 = 'role1'
ROLE2 = 'role2'
STICKY = {'entity': 'allUsers', 'role': ROLE2}
USER_PROJECT = 'user-project-123'
new_acl = [{'entity': 'allUsers', 'role': ROLE1}]
connection = _Connection({'acl': [STICKY] + new_acl})
client = _Client(connection)
acl = self._make_one()
acl.save_path = '/testing'
acl.loaded = True
acl.user_project = USER_PROJECT

acl.save(new_acl, client=client)

entries = list(acl)
self.assertEqual(len(entries), 2)
self.assertTrue(STICKY in entries)
self.assertTrue(new_acl[0] in entries)
kw = connection._requested
self.assertEqual(len(kw), 1)
self.assertEqual(kw[0]['method'], 'PATCH')
self.assertEqual(kw[0]['path'], '/testing')
self.assertEqual(kw[0]['data'], {'acl': new_acl})
self.assertEqual(kw[0]['query_params'], {'projection': 'full'})
self.assertEqual(kw[0], {
'method': 'PATCH',
'path': '/testing',
'query_params': {
'projection': 'full',
'userProject': USER_PROJECT,
},
'data': {'acl': new_acl},
})

def test_save_prefefined_invalid(self):
connection = _Connection()
Expand All @@ -652,11 +680,15 @@ def test_save_predefined_valid(self):
self.assertEqual(len(entries), 0)
kw = connection._requested
self.assertEqual(len(kw), 1)
self.assertEqual(kw[0]['method'], 'PATCH')
self.assertEqual(kw[0]['path'], '/testing')
self.assertEqual(kw[0]['data'], {'acl': []})
self.assertEqual(kw[0]['query_params'],
{'projection': 'full', 'predefinedAcl': PREDEFINED})
self.assertEqual(kw[0], {
'method': 'PATCH',
'path': '/testing',
'query_params': {
'projection': 'full',
'predefinedAcl': PREDEFINED,
},
'data': {'acl': []},
})

def test_save_predefined_w_XML_alias(self):
PREDEFINED_XML = 'project-private'
Expand All @@ -671,12 +703,15 @@ def test_save_predefined_w_XML_alias(self):
self.assertEqual(len(entries), 0)
kw = connection._requested
self.assertEqual(len(kw), 1)
self.assertEqual(kw[0]['method'], 'PATCH')
self.assertEqual(kw[0]['path'], '/testing')
self.assertEqual(kw[0]['data'], {'acl': []})
self.assertEqual(kw[0]['query_params'],
{'projection': 'full',
'predefinedAcl': PREDEFINED_JSON})
self.assertEqual(kw[0], {
'method': 'PATCH',
'path': '/testing',
'query_params': {
'projection': 'full',
'predefinedAcl': PREDEFINED_JSON,
},
'data': {'acl': []},
})

def test_save_predefined_valid_w_alternate_query_param(self):
# Cover case where subclass overrides _PREDEFINED_QUERY_PARAM
Expand All @@ -692,11 +727,15 @@ def test_save_predefined_valid_w_alternate_query_param(self):
self.assertEqual(len(entries), 0)
kw = connection._requested
self.assertEqual(len(kw), 1)
self.assertEqual(kw[0]['method'], 'PATCH')
self.assertEqual(kw[0]['path'], '/testing')
self.assertEqual(kw[0]['data'], {'acl': []})
self.assertEqual(kw[0]['query_params'],
{'projection': 'full', 'alternate': PREDEFINED})
self.assertEqual(kw[0], {
'method': 'PATCH',
'path': '/testing',
'query_params': {
'projection': 'full',
'alternate': PREDEFINED,
},
'data': {'acl': []},
})

def test_clear(self):
ROLE1 = 'role1'
Expand All @@ -712,10 +751,12 @@ def test_clear(self):
self.assertEqual(list(acl), [STICKY])
kw = connection._requested
self.assertEqual(len(kw), 1)
self.assertEqual(kw[0]['method'], 'PATCH')
self.assertEqual(kw[0]['path'], '/testing')
self.assertEqual(kw[0]['data'], {'acl': []})
self.assertEqual(kw[0]['query_params'], {'projection': 'full'})
self.assertEqual(kw[0], {
'method': 'PATCH',
'path': '/testing',
'query_params': {'projection': 'full'},
'data': {'acl': []},
})


class Test_BucketACL(unittest.TestCase):
Expand All @@ -739,6 +780,15 @@ def test_ctor(self):
self.assertEqual(acl.reload_path, '/b/%s/acl' % NAME)
self.assertEqual(acl.save_path, '/b/%s' % NAME)

def test_user_project(self):
NAME = 'name'
USER_PROJECT = 'user-project-123'
bucket = _Bucket(NAME)
acl = self._make_one(bucket)
self.assertIsNone(acl.user_project)
bucket.user_project = USER_PROJECT
self.assertEqual(acl.user_project, USER_PROJECT)


class Test_DefaultObjectACL(unittest.TestCase):

Expand Down Expand Up @@ -785,9 +835,22 @@ def test_ctor(self):
self.assertEqual(acl.reload_path, '/b/%s/o/%s/acl' % (NAME, BLOB_NAME))
self.assertEqual(acl.save_path, '/b/%s/o/%s' % (NAME, BLOB_NAME))

def test_user_project(self):
NAME = 'name'
BLOB_NAME = 'blob-name'
USER_PROJECT = 'user-project-123'
bucket = _Bucket(NAME)
blob = _Blob(bucket, BLOB_NAME)
acl = self._make_one(blob)
self.assertIsNone(acl.user_project)
blob.user_project = USER_PROJECT
self.assertEqual(acl.user_project, USER_PROJECT)


class _Blob(object):

user_project = None

def __init__(self, bucket, blob):
self.bucket = bucket
self.blob = blob
Expand All @@ -799,6 +862,8 @@ def path(self):

class _Bucket(object):

user_project = None

def __init__(self, name):
self.name = name

Expand Down