diff --git a/packages/google-cloud-bigquery-storage/samples/pyarrow/append_rows_with_arrow.py b/packages/google-cloud-bigquery-storage/samples/pyarrow/append_rows_with_arrow.py
index 1d4adad52f01..cac46f98fc15 100644
--- a/packages/google-cloud-bigquery-storage/samples/pyarrow/append_rows_with_arrow.py
+++ b/packages/google-cloud-bigquery-storage/samples/pyarrow/append_rows_with_arrow.py
@@ -16,14 +16,14 @@
 import datetime
 import decimal
 
-from google.cloud import bigquery
 from google.cloud.bigquery import enums
-from google.cloud.bigquery_storage_v1 import types as gapic_types
-from google.cloud.bigquery_storage_v1.writer import AppendRowsStream
 import pandas as pd
-
 import pyarrow as pa
+from google.cloud import bigquery
+from google.cloud.bigquery_storage_v1 import types as gapic_types
+from google.cloud.bigquery_storage_v1.writer import AppendRowsStream
+
 
 TABLE_LENGTH = 100_000
 
 BQ_SCHEMA = [
@@ -100,7 +100,10 @@ def make_table(project_id, dataset_id, bq_client):
 
 
 def create_stream(bqstorage_write_client, table):
-    stream_name = f"projects/{table.project}/datasets/{table.dataset_id}/tables/{table.table_id}/_default"
+    stream_name = (
+        f"projects/{table.project}/datasets/{table.dataset_id}/"
+        f"tables/{table.table_id}/_default"
+    )
 
     request_template = gapic_types.AppendRowsRequest()
     request_template.write_stream = stream_name
@@ -160,18 +163,64 @@ def generate_pyarrow_table(num_rows=TABLE_LENGTH):
 
 
 def generate_write_requests(pyarrow_table):
-    # Determine max_chunksize of the record batches. Because max size of
-    # AppendRowsRequest is 10 MB, we need to split the table if it's too big.
-    # See: https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1#appendrowsrequest
-    max_request_bytes = 10 * 2**20  # 10 MB
-    chunk_num = int(pyarrow_table.nbytes / max_request_bytes) + 1
-    chunk_size = int(pyarrow_table.num_rows / chunk_num)
-
-    # Construct request(s).
-    for batch in pyarrow_table.to_batches(max_chunksize=chunk_size):
+    # The maximum size for a single AppendRowsRequest is 10 MB.
+    # To be safe, we'll aim for a soft limit of 7 MB.
+    max_request_bytes = 7 * 1024 * 1024  # 7 MB
+
+    def _create_request(batches):
+        """Helper to create an AppendRowsRequest from a list of batches."""
+        combined_table = pa.Table.from_batches(batches)
         request = gapic_types.AppendRowsRequest()
-        request.arrow_rows.rows.serialized_record_batch = batch.serialize().to_pybytes()
-        yield request
+        request.arrow_rows.rows.serialized_record_batch = (
+            combined_table.combine_chunks().to_batches()[0].serialize().to_pybytes()
+        )
+        return request
+
+    batches = pyarrow_table.to_batches()
+
+    current_batches = []
+    current_size = 0
+
+    while batches:
+        batch = batches.pop()
+        batch_size = batch.nbytes
+
+        if current_size + batch_size > max_request_bytes:
+            if batch.num_rows > 1:
+                # Split the batch into two sub-batches of roughly equal size.
+                mid = batch.num_rows // 2
+                batch_left = batch.slice(offset=0, length=mid)
+                batch_right = batch.slice(offset=mid)
+
+                # Push the new batches onto the stack and continue popping.
+                batches.append(batch_right)
+                batches.append(batch_left)
+                continue
+
+            # If the batch is a single row and still larger than max_request_bytes
+            else:
+                # If current_batches is empty, raise an error
+                if len(current_batches) == 0:
+                    raise ValueError(
+                        f"A single PyArrow batch of one row is larger than the maximum request size "
+                        f"(batch size: {batch_size} > max request size: {max_request_bytes}). Cannot proceed."
+                    )
+                # Otherwise, generate a request and reset current_size and current_batches
+                else:
+                    yield _create_request(current_batches)
+
+                    current_batches = []
+                    current_size = 0
+                    batches.append(batch)
+
+        # Otherwise, add the batch into current_batches
+        else:
+            current_batches.append(batch)
+            current_size += batch_size
+
+    # Flush remaining batches
+    if current_batches:
+        yield _create_request(current_batches)
 
 
 def verify_result(client, table, futures):
@@ -181,14 +230,13 @@ def verify_result(client, table, futures):
     assert bq_table.schema == BQ_SCHEMA
 
     # Verify table size.
-    query = client.query(f"SELECT COUNT(1) FROM `{bq_table}`;")
+    query = client.query(f"SELECT DISTINCT int64_col FROM `{bq_table}`;")
     query_result = query.result().to_dataframe()
 
-    # There might be extra rows due to retries.
-    assert query_result.iloc[0, 0] >= TABLE_LENGTH
+    assert len(query_result) == TABLE_LENGTH
 
     # Verify that table was split into multiple requests.
-    assert len(futures) == 2
+    assert len(futures) == 3
 
 
 def main(project_id, dataset):
diff --git a/packages/google-cloud-bigquery-storage/samples/pyarrow/test_generate_write_requests.py b/packages/google-cloud-bigquery-storage/samples/pyarrow/test_generate_write_requests.py
new file mode 100644
index 000000000000..85b070ba2361
--- /dev/null
+++ b/packages/google-cloud-bigquery-storage/samples/pyarrow/test_generate_write_requests.py
@@ -0,0 +1,82 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+
+import pyarrow as pa
+import pytest
+
+from . import append_rows_with_arrow
+
+
+def create_table_with_batches(num_batches, rows_per_batch):
+    # Generate a small table to get a valid batch
+    small_table = append_rows_with_arrow.generate_pyarrow_table(rows_per_batch)
+    # Ensure we get exactly one batch for the small table
+    batches = small_table.to_batches()
+    assert len(batches) == 1
+    batch = batches[0]
+
+    # Replicate the batch
+    all_batches = [batch] * num_batches
+    return pa.Table.from_batches(all_batches)
+
+
+# Test generate_write_requests with different numbers of batches in the input table.
+# The total number of rows in the generated table is always 1,000,000.
+@pytest.mark.parametrize(
+    "num_batches, rows_per_batch",
+    [
+        (1, 1000000),
+        (10, 100000),
+        (100, 10000),
+        (1000, 1000),
+        (10000, 100),
+        (100000, 10),
+        (1000000, 1),
+    ],
+)
+def test_generate_write_requests_varying_batches(num_batches, rows_per_batch):
+    """Test generate_write_requests with different numbers of batches in the input table."""
+    # Create a table that yields `num_batches` batches when to_batches() is called.
+    table = create_table_with_batches(num_batches, rows_per_batch)
+
+    # Verify our setup is correct
+    assert len(table.to_batches()) == num_batches
+
+    # Generate requests
+    start_time = time.perf_counter()
+    requests = list(append_rows_with_arrow.generate_write_requests(table))
+    end_time = time.perf_counter()
+    print(
+        f"\nTime used to generate requests for {num_batches} batches: {end_time - start_time:.4f} seconds"
+    )
+
+    # We expect batches to be aggregated into requests of up to 7 MB.
+    # Since the total row count is constant, the number of requests should be deterministic.
+    assert len(requests) == 26
+
+    # Verify that the total rows in the requests match the total rows in the table
+    total_rows_processed = 0
+    for request in requests:
+        # Deserialize the batch from the request to count rows
+        serialized_batch = request.arrow_rows.rows.serialized_record_batch
+        # We need a schema to read the batch. The schema is PYARROW_SCHEMA.
+        batch = pa.ipc.read_record_batch(
+            serialized_batch, append_rows_with_arrow.PYARROW_SCHEMA
+        )
+        total_rows_processed += batch.num_rows
+
+    expected_rows = num_batches * rows_per_batch
+    assert total_rows_processed == expected_rows
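
Note: the snippet below is a standalone sketch, not part of the patch. It restates the greedy, stack-based split-or-flush idea that the new generate_write_requests() uses, but applied to a plain pyarrow.Table so it can be run in isolation. The helper name chunk_table, the 7 MB constant, the 2,000,000-row int64_col toy table, and the final asserts are illustrative assumptions only.

# Illustrative sketch only -- chunk_table and the toy table below are hypothetical.
import pyarrow as pa

MAX_CHUNK_BYTES = 7 * 1024 * 1024  # mirrors the sample's 7 MB soft limit


def chunk_table(table, max_bytes=MAX_CHUNK_BYTES):
    """Yield sub-tables whose batches sum to at most max_bytes each."""
    stack = table.to_batches()
    current, size = [], 0
    while stack:
        batch = stack.pop()
        if size + batch.nbytes > max_bytes:
            if batch.num_rows > 1:
                # Too large to fit: halve the batch and push both halves back.
                mid = batch.num_rows // 2
                stack.append(batch.slice(offset=mid))
                stack.append(batch.slice(offset=0, length=mid))
                continue
            if not current:
                raise ValueError("a single-row batch exceeds max_bytes")
            # Flush the accumulated batches, then retry this batch.
            yield pa.Table.from_batches(current)
            current, size = [], 0
            stack.append(batch)
        else:
            current.append(batch)
            size += batch.nbytes
    if current:
        yield pa.Table.from_batches(current)


# A ~16 MB int64 table splits into several chunks, none larger than the limit.
table = pa.table({"int64_col": list(range(2_000_000))})
chunks = list(chunk_table(table))
assert len(chunks) > 1
assert sum(chunk.num_rows for chunk in chunks) == table.num_rows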