diff --git a/bigquery/storage_client.go b/bigquery/storage_client.go index 12b7702250f7..f63f649ceba0 100644 --- a/bigquery/storage_client.go +++ b/bigquery/storage_client.go @@ -95,7 +95,7 @@ func (c *readClient) close() error { } // sessionForTable establishes a new session to fetch from a table using the Storage API -func (c *readClient) sessionForTable(ctx context.Context, table *Table, ordered bool) (*readSession, error) { +func (c *readClient) sessionForTable(ctx context.Context, table *Table, rsProjectID string, ordered bool) (*readSession, error) { tableID, err := table.Identifier(StorageAPIResourceID) if err != nil { return nil, err @@ -111,6 +111,7 @@ func (c *readClient) sessionForTable(ctx context.Context, table *Table, ordered ctx: ctx, table: table, tableID: tableID, + projectID: rsProjectID, settings: settings, readRowsFunc: c.rawClient.ReadRows, createReadSessionFunc: c.rawClient.CreateReadSession, @@ -122,9 +123,10 @@ func (c *readClient) sessionForTable(ctx context.Context, table *Table, ordered type readSession struct { settings readClientSettings - ctx context.Context - table *Table - tableID string + ctx context.Context + table *Table + tableID string + projectID string bqSession *storagepb.ReadSession @@ -141,7 +143,7 @@ func (rs *readSession) start() error { preferredMinStreamCount = int32(rs.settings.maxWorkerCount) } createReadSessionRequest := &storagepb.CreateReadSessionRequest{ - Parent: fmt.Sprintf("projects/%s", rs.table.ProjectID), + Parent: fmt.Sprintf("projects/%s", rs.projectID), ReadSession: &storagepb.ReadSession{ Table: rs.tableID, DataFormat: storagepb.DataFormat_ARROW, diff --git a/bigquery/storage_integration_test.go b/bigquery/storage_integration_test.go index cbc9b5afd51b..492561c8a488 100644 --- a/bigquery/storage_integration_test.go +++ b/bigquery/storage_integration_test.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "strings" "testing" "time" @@ -87,6 +88,32 @@ func TestIntegration_StorageReadEmptyResultSet(t *testing.T) { } } +func TestIntegration_StorageReadClientProject(t *testing.T) { + if client == nil { + t.Skip("Integration tests skipped") + } + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + table := storageOptimizedClient.Dataset("usa_names").Table("usa_1910_current") + table.ProjectID = "bigquery-public-data" + + it := table.Read(ctx) + _, err := countIteratorRows(it) + if err != nil { + t.Fatal(err) + } + if !it.IsAccelerated() { + t.Fatal("expected storage api to be used") + } + + session := it.arrowIterator.(*storageArrowIterator).rs + expectedPrefix := fmt.Sprintf("projects/%s", storageOptimizedClient.projectID) + if !strings.HasPrefix(session.bqSession.Name, expectedPrefix) { + t.Fatalf("expected read session to have prefix %q: but found %s:", expectedPrefix, session.bqSession.Name) + } +} + func TestIntegration_StorageReadFromSources(t *testing.T) { if client == nil { t.Skip("Integration tests skipped") diff --git a/bigquery/storage_iterator.go b/bigquery/storage_iterator.go index a7a8cdd38ce9..1ed9917ddfb9 100644 --- a/bigquery/storage_iterator.go +++ b/bigquery/storage_iterator.go @@ -47,12 +47,12 @@ type storageArrowIterator struct { var _ ArrowIterator = &storageArrowIterator{} -func newStorageRowIteratorFromTable(ctx context.Context, table *Table, ordered bool) (*RowIterator, error) { +func newStorageRowIteratorFromTable(ctx context.Context, table *Table, rsProjectID string, ordered bool) (*RowIterator, error) { md, err := table.Metadata(ctx) if err != nil { return nil, err } - rs, err := table.c.rc.sessionForTable(ctx, table, ordered) + rs, err := table.c.rc.sessionForTable(ctx, table, rsProjectID, ordered) if err != nil { return nil, err } @@ -95,7 +95,7 @@ func newStorageRowIteratorFromJob(ctx context.Context, j *Job) (*RowIterator, er return newStorageRowIteratorFromJob(ctx, lastJob) } ordered := query.HasOrderedResults(qcfg.Q) - return newStorageRowIteratorFromTable(ctx, qcfg.Dst, ordered) + return newStorageRowIteratorFromTable(ctx, qcfg.Dst, job.projectID, ordered) } func resolveLastChildSelectJob(ctx context.Context, job *Job) (*Job, error) { diff --git a/bigquery/table.go b/bigquery/table.go index 25068e558682..944a836d8dd6 100644 --- a/bigquery/table.go +++ b/bigquery/table.go @@ -974,7 +974,7 @@ func (t *Table) Read(ctx context.Context) *RowIterator { func (t *Table) read(ctx context.Context, pf pageFetcher) *RowIterator { if t.c.isStorageReadAvailable() { - it, err := newStorageRowIteratorFromTable(ctx, t, false) + it, err := newStorageRowIteratorFromTable(ctx, t, t.c.projectID, false) if err == nil { return it }