diff --git a/google/cloud/bigquery/client.py b/google/cloud/bigquery/client.py index 023346ffa..8cca2b24f 100644 --- a/google/cloud/bigquery/client.py +++ b/google/cloud/bigquery/client.py @@ -131,6 +131,8 @@ # https://github.com/googleapis/python-bigquery/issues/781#issuecomment-883497414 _PYARROW_BAD_VERSIONS = frozenset([packaging.version.Version("2.0.0")]) +TIMEOUT_HEADER = "X-Server-Timeout" + class Project(object): """Wrapper for resource describing a BigQuery project. @@ -742,16 +744,34 @@ def create_table( return self.get_table(table.reference, retry=retry) def _call_api( - self, retry, span_name=None, span_attributes=None, job_ref=None, **kwargs + self, + retry, + span_name=None, + span_attributes=None, + job_ref=None, + headers: Optional[Dict[str, str]] = None, + **kwargs, ): + timeout = kwargs.get("timeout") + if timeout is not None: + if headers is None: + headers = {} + headers[TIMEOUT_HEADER] = str(timeout) + + if headers: + kwargs["headers"] = headers + call = functools.partial(self._connection.api_request, **kwargs) + if retry: call = retry(call) + if span_name is not None: with create_span( name=span_name, attributes=span_attributes, client=self, job_ref=job_ref ): return call() + return call() def get_dataset( diff --git a/tests/unit/helpers.py b/tests/unit/helpers.py index 67aeaca35..1de3eae57 100644 --- a/tests/unit/helpers.py +++ b/tests/unit/helpers.py @@ -18,6 +18,38 @@ import pytest +def add_header_assertion_to_kwargs(kwargs): + timeout = kwargs.get("timeout") + if timeout is not None: + headers = kwargs.setdefault("headers", {}) + if headers is None: + kwargs["headers"] = headers = {} + headers[google.cloud.bigquery.client.TIMEOUT_HEADER] = str(kwargs["timeout"]) + + return kwargs + + +def add_header_assertion(mock, name): + """ + Modify assert_called_with-ish assertions to add timeout headers + + if there's a timeout + """ + orig = getattr(mock, name) + + def repl(*args, **kw): + return orig(*args, **add_header_assertion_to_kwargs(kw)) + + setattr(mock, name, repl) + + +def api_call(*args, **kw): + """ + Replacement for mock.call that adds a timeout header, if necessary + """ + return mock.call(*args, **add_header_assertion_to_kwargs(kw)) + + def make_connection(*responses): import google.cloud.bigquery._http import mock @@ -26,6 +58,9 @@ def make_connection(*responses): mock_conn = mock.create_autospec(google.cloud.bigquery._http.Connection) mock_conn.user_agent = "testing 1.2.3" mock_conn.api_request.side_effect = list(responses) + [NotFound("miss")] + for name in "assert_called_with", "assert_called_once_with": + add_header_assertion(mock_conn.api_request, name) + mock_conn.API_BASE_URL = "https://bigquery.googleapis.com" mock_conn.get_api_base_url_for_mtls = mock.Mock(return_value=mock_conn.API_BASE_URL) return mock_conn diff --git a/tests/unit/job/helpers.py b/tests/unit/job/helpers.py index c792214e7..ca115c93e 100644 --- a/tests/unit/job/helpers.py +++ b/tests/unit/job/helpers.py @@ -17,6 +17,8 @@ import mock from google.api_core import exceptions +from ..helpers import make_connection as _make_connection + def _make_credentials(): import google.auth.credentials @@ -35,15 +37,6 @@ def _make_client(project="test-project", connection=None): return client -def _make_connection(*responses): - import google.cloud.bigquery._http - from google.cloud.exceptions import NotFound - - mock_conn = mock.create_autospec(google.cloud.bigquery._http.Connection) - mock_conn.api_request.side_effect = list(responses) + [NotFound("miss")] - return mock_conn - - def _make_retriable_exception(): return exceptions.TooManyRequests( "retriable exception", errors=[{"reason": "rateLimitExceeded"}] diff --git a/tests/unit/job/test_base.py b/tests/unit/job/test_base.py index c3f7854e3..c96dfed1a 100644 --- a/tests/unit/job/test_base.py +++ b/tests/unit/job/test_base.py @@ -21,6 +21,8 @@ import mock import pytest +from ..helpers import api_call + from .helpers import _make_client from .helpers import _make_connection from .helpers import _make_retriable_exception @@ -824,8 +826,8 @@ def test_cancel_w_custom_retry(self): self.assertEqual( fake_api_request.call_args_list, [ - mock.call(method="POST", path=api_path, query_params={}, timeout=7.5), - mock.call( + api_call(method="POST", path=api_path, query_params={}, timeout=7.5), + api_call( method="POST", path=api_path, query_params={}, timeout=7.5 ), # was retried once ], @@ -941,13 +943,13 @@ def test_result_default_wo_state(self): self.assertIs(job.result(), job) - begin_call = mock.call( + begin_call = api_call( method="POST", path=f"/projects/{self.PROJECT}/jobs", data={"jobReference": {"jobId": self.JOB_ID, "projectId": self.PROJECT}}, timeout=None, ) - reload_call = mock.call( + reload_call = api_call( method="GET", path=f"/projects/{self.PROJECT}/jobs/{self.JOB_ID}", query_params={"location": "US"}, @@ -985,7 +987,7 @@ def test_result_w_retry_wo_state(self): ) self.assertIs(job.result(retry=custom_retry), job) - begin_call = mock.call( + begin_call = api_call( method="POST", path=f"/projects/{self.PROJECT}/jobs", data={ @@ -997,7 +999,7 @@ def test_result_w_retry_wo_state(self): }, timeout=None, ) - reload_call = mock.call( + reload_call = api_call( method="GET", path=f"/projects/{self.PROJECT}/jobs/{self.JOB_ID}", query_params={"location": "EU"}, diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index e9204f1de..76b9c323f 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -56,6 +56,7 @@ import google.cloud._helpers from google.cloud import bigquery_v2 from google.cloud.bigquery.dataset import DatasetReference +from google.cloud.bigquery.client import TIMEOUT_HEADER from google.cloud.bigquery.retry import DEFAULT_TIMEOUT try: @@ -63,7 +64,7 @@ except (ImportError, AttributeError): # pragma: NO COVER bigquery_storage = None from test_utils.imports import maybe_fail_import -from tests.unit.helpers import make_connection +from tests.unit.helpers import api_call, make_connection PANDAS_MINIUM_VERSION = pkg_resources.parse_version("1.0.0") @@ -469,8 +470,8 @@ def test_get_service_account_email_w_custom_retry(self): self.assertEqual( fake_api_request.call_args_list, [ - mock.call(method="GET", path=api_path, timeout=7.5), - mock.call(method="GET", path=api_path, timeout=7.5), # was retried once + api_call(method="GET", path=api_path, timeout=7.5), + api_call(method="GET", path=api_path, timeout=7.5), # was retried once ], ) @@ -846,12 +847,13 @@ def test_create_routine_w_conflict_exists_ok(self): self.assertEqual(actual_routine.routine_id, "minimal_routine") conn.api_request.assert_has_calls( [ - mock.call( + api_call( method="POST", path=path, data=resource, timeout=DEFAULT_TIMEOUT, ), - mock.call( + api_call( method="GET", - path="/projects/test-routine-project/datasets/test_routines/routines/minimal_routine", + path="/projects/test-routine-project/datasets/" + "test_routines/routines/minimal_routine", timeout=DEFAULT_TIMEOUT, ), ] @@ -1313,7 +1315,7 @@ def test_create_table_alreadyexists_w_exists_ok_true(self): conn.api_request.assert_has_calls( [ - mock.call( + api_call( method="POST", path=post_path, data={ @@ -1326,7 +1328,7 @@ def test_create_table_alreadyexists_w_exists_ok_true(self): }, timeout=DEFAULT_TIMEOUT, ), - mock.call(method="GET", path=get_path, timeout=DEFAULT_TIMEOUT), + api_call(method="GET", path=get_path, timeout=DEFAULT_TIMEOUT), ] ) @@ -1506,6 +1508,7 @@ def test_get_table_sets_user_agent(self): "X-Goog-API-Client": expected_user_agent, "Accept-Encoding": "gzip", "User-Agent": expected_user_agent, + TIMEOUT_HEADER: str(DEFAULT_TIMEOUT), }, data=mock.ANY, timeout=DEFAULT_TIMEOUT, @@ -2855,7 +2858,7 @@ def test_create_job_query_config_w_rateLimitExceeded_error(self): self.assertEqual(len(fake_api_request.call_args_list), 2) # was retried once self.assertEqual( fake_api_request.call_args_list[1], - mock.call( + api_call( method="POST", path="/projects/PROJECT/jobs", data=data_without_destination, @@ -5373,7 +5376,7 @@ def test_insert_rows_from_dataframe(self): for call, expected_data in itertools.zip_longest( actual_calls, EXPECTED_SENT_DATA ): - expected_call = mock.call( + expected_call = api_call( method="POST", path=API_PATH, data=expected_data, timeout=7.5 ) assert call == expected_call @@ -5441,7 +5444,7 @@ def test_insert_rows_from_dataframe_nan(self): for call, expected_data in itertools.zip_longest( actual_calls, EXPECTED_SENT_DATA ): - expected_call = mock.call( + expected_call = api_call( method="POST", path=API_PATH, data=expected_data, timeout=7.5 ) assert call == expected_call @@ -5488,7 +5491,7 @@ def test_insert_rows_from_dataframe_many_columns(self): } ] } - expected_call = mock.call( + expected_call = api_call( method="POST", path=API_PATH, data=EXPECTED_SENT_DATA, @@ -5544,7 +5547,7 @@ def test_insert_rows_from_dataframe_w_explicit_none_insert_ids(self): actual_calls = conn.api_request.call_args_list assert len(actual_calls) == 1 - assert actual_calls[0] == mock.call( + assert actual_calls[0] == api_call( method="POST", path=API_PATH, data=EXPECTED_SENT_DATA, @@ -5964,7 +5967,7 @@ def test_list_rows_w_start_index_w_page_size(self): conn.api_request.assert_has_calls( [ - mock.call( + api_call( method="GET", path="/%s" % PATH, query_params={ @@ -5974,7 +5977,7 @@ def test_list_rows_w_start_index_w_page_size(self): }, timeout=DEFAULT_TIMEOUT, ), - mock.call( + api_call( method="GET", path="/%s" % PATH, query_params={ diff --git a/tests/unit/test_create_dataset.py b/tests/unit/test_create_dataset.py index 67b21225d..95b945c02 100644 --- a/tests/unit/test_create_dataset.py +++ b/tests/unit/test_create_dataset.py @@ -13,7 +13,7 @@ # limitations under the License. from google.cloud.bigquery.dataset import Dataset, DatasetReference -from .helpers import make_connection, dataset_polymorphic, make_client +from .helpers import api_call, dataset_polymorphic, make_client, make_connection import google.cloud.bigquery.dataset from google.cloud.bigquery.retry import DEFAULT_TIMEOUT import mock @@ -349,7 +349,7 @@ def test_create_dataset_alreadyexists_w_exists_ok_true(PROJECT, DS_ID, LOCATION) conn.api_request.assert_has_calls( [ - mock.call( + api_call( method="POST", path=post_path, data={ @@ -359,6 +359,6 @@ def test_create_dataset_alreadyexists_w_exists_ok_true(PROJECT, DS_ID, LOCATION) }, timeout=DEFAULT_TIMEOUT, ), - mock.call(method="GET", path=get_path, timeout=DEFAULT_TIMEOUT), + api_call(method="GET", path=get_path, timeout=DEFAULT_TIMEOUT), ] ) diff --git a/tests/unit/test_magics.py b/tests/unit/test_magics.py index 36cbf4993..ba993780d 100644 --- a/tests/unit/test_magics.py +++ b/tests/unit/test_magics.py @@ -33,7 +33,7 @@ from google.cloud.bigquery import table from google.cloud.bigquery.magics import magics from google.cloud.bigquery.retry import DEFAULT_TIMEOUT -from tests.unit.helpers import make_connection +from tests.unit.helpers import api_call, make_connection from test_utils.imports import maybe_fail_import @@ -182,17 +182,17 @@ def test_context_with_default_connection(): # Check that query actually starts the job. conn.assert_called() list_rows.assert_called() - begin_call = mock.call( + begin_call = api_call( method="POST", path="/projects/project-from-env/jobs", data=mock.ANY, timeout=DEFAULT_TIMEOUT, ) - query_results_call = mock.call( + query_results_call = api_call( method="GET", path=f"/projects/{PROJECT_ID}/queries/{JOB_ID}", query_params=mock.ANY, - timeout=mock.ANY, + timeout=120, ) default_conn.api_request.assert_has_calls([begin_call, query_results_call]) @@ -246,17 +246,17 @@ def test_context_with_custom_connection(): list_rows.assert_called() default_conn.api_request.assert_not_called() - begin_call = mock.call( + begin_call = api_call( method="POST", path="/projects/project-from-env/jobs", data=mock.ANY, timeout=DEFAULT_TIMEOUT, ) - query_results_call = mock.call( + query_results_call = api_call( method="GET", path=f"/projects/{PROJECT_ID}/queries/{JOB_ID}", query_params=mock.ANY, - timeout=mock.ANY, + timeout=120, ) context_conn.api_request.assert_has_calls([begin_call, query_results_call])