diff --git a/tests/mockserver_tests/test_basics.py b/tests/mockserver_tests/test_basics.py index 3706552d31..9db84b117f 100644 --- a/tests/mockserver_tests/test_basics.py +++ b/tests/mockserver_tests/test_basics.py @@ -93,6 +93,23 @@ def test_dbapi_partitioned_dml(self): TransactionOptions(dict(partitioned_dml={})), begin_request.options ) + def test_batch_create_sessions_unavailable(self): + add_select1_result() + add_error(SpannerServicer.BatchCreateSessions.__name__, unavailable_status()) + with self.database.snapshot() as snapshot: + results = snapshot.execute_sql("select 1") + result_list = [] + for row in results: + result_list.append(row) + self.assertEqual(1, row[0]) + self.assertEqual(1, len(result_list)) + requests = self.spanner_service.requests + self.assertEqual(3, len(requests), msg=requests) + # The BatchCreateSessions call should be retried. + self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) + self.assertTrue(isinstance(requests[1], BatchCreateSessionsRequest)) + self.assertTrue(isinstance(requests[2], ExecuteSqlRequest)) + def test_execute_streaming_sql_unavailable(self): add_select1_result() # Add an UNAVAILABLE error that is returned the first time the