diff --git a/dask_bigquery/core.py b/dask_bigquery/core.py
index 1fd396d..b596a7a 100644
--- a/dask_bigquery/core.py
+++ b/dask_bigquery/core.py
@@ -88,7 +88,8 @@ def read_gbq(
     project_id: str,
     dataset_id: str,
     table_id: str,
-    row_filter="",
+    row_filter: str = "",
+    columns: list[str] = None,
     read_kwargs: dict = None,
 ):
     """Read table as dask dataframe using BigQuery Storage API via Arrow format.
@@ -104,6 +105,8 @@
       BigQuery table within dataset
     row_filter: str
       SQL text filtering statement to pass to `row_restriction`
+    columns: list[str]
+      list of columns to load from the table
     read_kwargs: dict
       kwargs to pass to read_rows()
 
@@ -124,7 +127,7 @@ def make_create_read_session_request(row_filter=""):
         read_session=bigquery_storage.types.ReadSession(
            data_format=bigquery_storage.types.DataFormat.ARROW,
            read_options=bigquery_storage.types.ReadSession.TableReadOptions(
-                row_restriction=row_filter,
+                row_restriction=row_filter, selected_fields=columns
            ),
            table=table_ref.to_bqstorage(),
        ),
diff --git a/dask_bigquery/tests/test_core.py b/dask_bigquery/tests/test_core.py
index b8ecf90..17415e8 100644
--- a/dask_bigquery/tests/test_core.py
+++ b/dask_bigquery/tests/test_core.py
@@ -82,3 +82,17 @@ def test_read_kwargs(dataset, client):
 
     with pytest.raises(Exception, match="504 Deadline Exceeded"):
         ddf.compute()
+
+
+def test_read_columns(df, dataset, client):
+    project_id, dataset_id, table_id = dataset
+    assert df.shape[1] > 1, "Test data should have multiple columns"
+
+    columns = ["name"]
+    ddf = read_gbq(
+        project_id=project_id,
+        dataset_id=dataset_id,
+        table_id=table_id,
+        columns=columns,
+    )
+    assert list(ddf.columns) == columns
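
With this change, column selection is pushed down to the BigQuery Storage API via the read session's `selected_fields`, so unrequested columns are never transferred, rather than being dropped client-side after the read. A minimal usage sketch follows; the project, dataset, table, and column names are illustrative placeholders, not values from this PR:

```python
import dask_bigquery

# Read only the listed columns; the selection is applied server-side
# through the ReadSession's TableReadOptions, combined here with an
# optional row_filter. All identifiers below are hypothetical.
ddf = dask_bigquery.read_gbq(
    project_id="my-project",
    dataset_id="my_dataset",
    table_id="my_table",
    row_filter="number > 10",
    columns=["name", "number"],
)

print(list(ddf.columns))  # ['name', 'number']
```

Leaving `columns` as `None` keeps the previous behavior of reading every field, since the underlying protobuf constructor treats a `None` field value as unset.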