diff --git a/bigquery/tests/test_async_query.py b/bigquery/tests/test_async_query.py index d41c724da0f..52498d34544 100644 --- a/bigquery/tests/test_async_query.py +++ b/bigquery/tests/test_async_query.py @@ -12,11 +12,12 @@ # limitations under the License. # import json +import os import unittest -from bigquery.samples.async_query import run -from tests import CloudBaseTest - +from bigquery.samples.async_query import run, main +from tests import CloudBaseTest, mock_raw_input, BUCKET_NAME_ENV, \ + PROJECT_ID_ENV class TestAsyncQuery(CloudBaseTest): @@ -29,5 +30,16 @@ def test_async_query(self): self.assertIsNotNone(json.loads(result)) +class TestAsyncRunner(CloudBaseTest): + + def test_async_query_runner(self): + test_bucket_name = os.environ.get(BUCKET_NAME_ENV) + test_project_id = os.environ.get(PROJECT_ID_ENV) + answers = [test_bucket_name, test_project_id, 'n', + '1', '1'] + with mock_raw_input(answers): + main() + + if __name__ == '__main__': unittest.main() diff --git a/tests/__init__.py b/tests/__init__.py index 9f87cf1d518..e006de633f0 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -18,7 +18,7 @@ import json import os import unittest - +import __builtin__ BUCKET_NAME_ENV = 'TEST_BUCKET_NAME' PROJECT_ID_ENV = 'TEST_PROJECT_ID' @@ -26,6 +26,25 @@ os.path.abspath(os.path.dirname(__file__)), 'resources') +class mock_raw_input(object): + + def __init__(self, list_): + self.i = 0 + self.list_ = list_ + + def get_next_value(self, question): + ret = self.list_[self.i] + self.i += 1 + return ret + + def __enter__(self): + self.raw_input_cache = __builtin__.raw_input + __builtin__.raw_input = self.get_next_value + + def __exit__(self, exc_type, exc_value, traceback): + __builtin__.raw_input = self.raw_input_cache + + class CloudBaseTest(unittest.TestCase): def setUp(self):