diff --git a/dask_bigquery/core.py b/dask_bigquery/core.py
index 5b39b53..2eab9de 100644
--- a/dask_bigquery/core.py
+++ b/dask_bigquery/core.py
@@ -22,7 +22,7 @@
 
 
 @contextmanager
-def bigquery_clients(project_id):
+def bigquery_clients(project_id, credentials: dict = None):
     """This context manager is a temporary solution until there is an
     upstream solution to handle this.
     See googleapis/google-cloud-python#9457
@@ -35,7 +35,15 @@ def bigquery_clients(project_id):
         user_agent=f"dask-bigquery/{dask_bigquery.__version__}"
     )
 
-    with bigquery.Client(project_id, client_info=bq_client_info) as bq_client:
+    # Google library client needs an instance of google.auth.credentials.Credentials
+    if isinstance(credentials, dict):
+        credentials = service_account.Credentials.from_service_account_info(
+            info=credentials
+        )
+
+    with bigquery.Client(
+        project_id, credentials=credentials, client_info=bq_client_info
+    ) as bq_client:
         bq_storage_client = bigquery_storage.BigQueryReadClient(
             credentials=bq_client._credentials,
             client_info=bqstorage_client_info,
@@ -88,6 +96,7 @@ def bigquery_read(
     project_id: str,
     read_kwargs: dict,
     arrow_options: dict,
+    credentials: dict = None,
 ) -> pd.DataFrame:
     """Read a single batch of rows via BQ Storage API, in Arrow binary format.
 
@@ -108,7 +117,7 @@ def bigquery_read(
         NOTE: Please set if reading from Storage API without any `row_restriction`.
         https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1beta1#stream
     """
-    with bigquery_clients(project_id) as (_, bqs_client):
+    with bigquery_clients(project_id, credentials=credentials) as (_, bqs_client):
         session = bqs_client.create_read_session(make_create_read_session_request())
         schema = pyarrow.ipc.read_schema(
             pyarrow.py_buffer(session.arrow_schema.serialized_schema)
@@ -132,6 +141,7 @@ def read_gbq(
     max_stream_count: int = 0,
     read_kwargs: dict = None,
     arrow_options: dict = None,
+    credentials: dict = None,
 ):
     """Read table as dask dataframe using BigQuery Storage API via Arrow format.
     Partitions will be approximately balanced according to BigQuery stream allocation logic.
@@ -157,6 +167,9 @@ def read_gbq(
         kwargs to pass to record_batch.to_pandas() when converting from pyarrow to pandas.
        See https://arrow.apache.org/docs/python/generated/pyarrow.RecordBatch.html#pyarrow.RecordBatch.to_pandas
         for possible values
+    credentials : dict, optional
+        Credentials for accessing Google APIs. Use this parameter to override
+        default credentials. The dict should contain service account credentials in JSON format.
 
     Returns
     -------
@@ -164,7 +177,10 @@ def read_gbq(
     """
     read_kwargs = read_kwargs or {}
     arrow_options = arrow_options or {}
-    with bigquery_clients(project_id) as (bq_client, bqs_client):
+    with bigquery_clients(project_id, credentials=credentials) as (
+        bq_client,
+        bqs_client,
+    ):
         table_ref = bq_client.get_table(f"{dataset_id}.{table_id}")
         if table_ref.table_type == "VIEW":
             raise TypeError("Table type VIEW not supported")
@@ -209,6 +225,7 @@ def make_create_read_session_request():
                 project_id=project_id,
                 read_kwargs=read_kwargs,
                 arrow_options=arrow_options,
+                credentials=credentials,
             ),
             label=label,
         )
diff --git a/dask_bigquery/tests/test_core.py b/dask_bigquery/tests/test_core.py
index 67eec62..83b9416 100644
--- a/dask_bigquery/tests/test_core.py
+++ b/dask_bigquery/tests/test_core.py
@@ -336,6 +336,35 @@ def test_read_columns(df, table, client):
     assert list(ddf.columns) == columns
 
 
+@pytest.mark.parametrize("dataset_fixture", ["write_dataset", "write_existing_dataset"])
+def test_read_gbq_credentials(df, dataset_fixture, request, monkeypatch):
+    dataset = request.getfixturevalue(dataset_fixture)
+    credentials, project_id, dataset_id, table_id = dataset
+    ddf = dd.from_pandas(df, npartitions=2)
+
+    monkeypatch.delenv("GOOGLE_APPLICATION_CREDENTIALS", raising=False)
+    # write with explicit credentials
+    result = to_gbq(
+        ddf,
+        project_id=project_id,
+        dataset_id=dataset_id,
+        table_id=table_id or "table_to_write",
+        credentials=credentials,
+    )
+    assert result.state == "DONE"
+
+    # read back with explicit credentials
+    ddf = read_gbq(
+        project_id=project_id,
+        dataset_id=dataset_id,
+        table_id=table_id or "table_to_write",
+        credentials=credentials,
+    )
+
+    assert list(ddf.columns) == ["name", "number", "timestamp", "idx"]
+    assert assert_eq(ddf.set_index("idx"), df.set_index("idx"))
+
+
 def test_max_streams(df, table, client):
     project_id, dataset_id, table_id = table
     ddf = read_gbq(
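For context, here is a minimal usage sketch of the new `credentials` keyword. The key file path and the project/dataset/table identifiers below are hypothetical placeholders; the point is simply that the service account JSON key is loaded into a dict and passed to `read_gbq`, which converts it to a `google.auth.credentials.Credentials` instance internally.

```python
import json

from dask_bigquery import read_gbq

# Hypothetical service account key file, for illustration only.
with open("service-account-key.json") as f:
    credentials = json.load(f)  # dict with the service account info

# Explicit credentials override the application-default credentials.
ddf = read_gbq(
    project_id="my-project",
    dataset_id="my_dataset",
    table_id="my_table",
    credentials=credentials,
)
```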