Add credentials parameter to read_gbq #78

Merged
merged 3 commits on Sep 6, 2023

Changes from all commits
25 changes: 21 additions & 4 deletions dask_bigquery/core.py
@@ -22,7 +22,7 @@
 
 
 @contextmanager
-def bigquery_clients(project_id):
+def bigquery_clients(project_id, credentials: dict = None):
     """This context manager is a temporary solution until there is an
     upstream solution to handle this.
     See googleapis/google-cloud-python#9457
@@ -35,7 +35,15 @@ def bigquery_clients(project_id):
         user_agent=f"dask-bigquery/{dask_bigquery.__version__}"
     )
 
-    with bigquery.Client(project_id, client_info=bq_client_info) as bq_client:
+    # Google library client needs an instance of google.auth.credentials.Credentials
+    if isinstance(credentials, dict):
+        credentials = service_account.Credentials.from_service_account_info(
+            info=credentials
+        )
+
+    with bigquery.Client(
+        project_id, credentials=credentials, client_info=bq_client_info
+    ) as bq_client:
         bq_storage_client = bigquery_storage.BigQueryReadClient(
             credentials=bq_client._credentials,
             client_info=bqstorage_client_info,
@@ -88,6 +96,7 @@ def bigquery_read(
     project_id: str,
     read_kwargs: dict,
     arrow_options: dict,
+    credentials: dict = None,
 ) -> pd.DataFrame:
     """Read a single batch of rows via BQ Storage API, in Arrow binary format.
 
@@ -108,7 +117,7 @@
         NOTE: Please set if reading from Storage API without any `row_restriction`.
         https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1beta1#stream
     """
-    with bigquery_clients(project_id) as (_, bqs_client):
+    with bigquery_clients(project_id, credentials=credentials) as (_, bqs_client):
         session = bqs_client.create_read_session(make_create_read_session_request())
         schema = pyarrow.ipc.read_schema(
             pyarrow.py_buffer(session.arrow_schema.serialized_schema)
@@ -132,6 +141,7 @@ def read_gbq(
     max_stream_count: int = 0,
     read_kwargs: dict = None,
     arrow_options: dict = None,
+    credentials: dict = None,
 ):
     """Read table as dask dataframe using BigQuery Storage API via Arrow format.
     Partitions will be approximately balanced according to BigQuery stream allocation logic.
@@ -157,14 +167,20 @@
         kwargs to pass to record_batch.to_pandas() when converting from pyarrow to pandas. See
         https://arrow.apache.org/docs/python/generated/pyarrow.RecordBatch.html#pyarrow.RecordBatch.to_pandas
         for possible values
+    credentials : dict, optional
+        Credentials for accessing Google APIs. Use this parameter to override
+        default credentials. The dict should contain service account credentials in JSON format.
 
     Returns
     -------
     Dask DataFrame
     """
     read_kwargs = read_kwargs or {}
     arrow_options = arrow_options or {}
-    with bigquery_clients(project_id) as (bq_client, bqs_client):
+    with bigquery_clients(project_id, credentials=credentials) as (
+        bq_client,
+        bqs_client,
+    ):
         table_ref = bq_client.get_table(f"{dataset_id}.{table_id}")
         if table_ref.table_type == "VIEW":
             raise TypeError("Table type VIEW not supported")
@@ -209,6 +225,7 @@ def make_create_read_session_request():
                 project_id=project_id,
                 read_kwargs=read_kwargs,
                 arrow_options=arrow_options,
+                credentials=credentials,
             ),
             label=label,
         )
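
As context for the bigquery_clients change above, here is a minimal sketch of the dict-to-credentials conversion it performs; the key-file path and project name below are placeholder assumptions, not part of this PR:

    import json

    from google.cloud import bigquery
    from google.oauth2 import service_account

    # Load a service-account key file into a plain dict (placeholder path).
    with open("/path/to/service-account-key.json") as f:
        credentials_info = json.load(f)

    # Same conversion bigquery_clients() applies when it receives a dict.
    credentials = service_account.Credentials.from_service_account_info(info=credentials_info)

    # The resulting object is what the Google client libraries actually consume.
    client = bigquery.Client(project="my-project", credentials=credentials)
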
29 changes: 29 additions & 0 deletions dask_bigquery/tests/test_core.py
@@ -336,6 +336,35 @@ def test_read_columns(df, table, client):
     assert list(ddf.columns) == columns
 
 
+@pytest.mark.parametrize("dataset_fixture", ["write_dataset", "write_existing_dataset"])
+def test_read_gbq_credentials(df, dataset_fixture, request, monkeypatch):
+    dataset = request.getfixturevalue(dataset_fixture)
+    credentials, project_id, dataset_id, table_id = dataset
+    ddf = dd.from_pandas(df, npartitions=2)
+
+    monkeypatch.delenv("GOOGLE_DEFAULT_CREDENTIALS", raising=False)
+    # write with explicit credentials
+    result = to_gbq(
+        ddf,
+        project_id=project_id,
+        dataset_id=dataset_id,
+        table_id=table_id or "table_to_write",
+        credentials=credentials,
+    )
+    assert result.state == "DONE"
+
+    # read back with explicit credentials
+    ddf = read_gbq(
+        project_id=project_id,
+        dataset_id=dataset_id,
+        table_id=table_id or "table_to_write",
+        credentials=credentials,
+    )
+
+    assert list(ddf.columns) == ["name", "number", "timestamp", "idx"]
+    assert assert_eq(ddf.set_index("idx"), df.set_index("idx"))
+
+
 def test_max_streams(df, table, client):
     project_id, dataset_id, table_id = table
     ddf = read_gbq(
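
And a hedged end-to-end sketch of how a caller could use the new parameter once this change is released; the project, dataset, table, and key path are placeholders:

    import json

    import dask_bigquery

    # Placeholder key path; read_gbq accepts the service-account info as a plain dict
    # and forwards it to bigquery_clients() both locally and in the worker-side reads.
    with open("/path/to/service-account-key.json") as f:
        sa_info = json.load(f)

    ddf = dask_bigquery.read_gbq(
        project_id="my-project",   # placeholder project
        dataset_id="my_dataset",   # placeholder dataset
        table_id="my_table",       # placeholder table
        credentials=sa_info,       # new in this PR: overrides default credentials
    )
    print(ddf.head())
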