diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java
index 5f7851bba519..22e8abca35bc 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java
@@ -37,6 +37,8 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils;
 import org.apache.beam.sdk.io.gcp.bigquery.WriteResult;
 import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiSchemaTransformConfiguration;
+import org.apache.beam.sdk.metrics.Counter;
+import org.apache.beam.sdk.metrics.Metrics;
 import org.apache.beam.sdk.schemas.AutoValueSchema;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.schemas.Schema.Field;
@@ -45,8 +47,10 @@ import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
 import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
 import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
+import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollection.IsBounded;
 import org.apache.beam.sdk.values.PCollectionRowTuple;
@@ -222,6 +226,30 @@ public void setBigQueryServices(BigQueryServices testBigQueryServices) {
     this.testBigQueryServices = testBigQueryServices;
   }
 
+  // A generic element counter for a PCollection of Row, initialized with the given
+  // metric name. Counts the input PCollection element-wise, passing each element through.
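+  // Counting is buffered per bundle: each element increments a local count, and
+  // the running total is reported to the Counter once per bundle in @FinishBundle,
+  // which keeps metric updates off the per-element hot path.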
+  private static class ElementCounterFn extends DoFn<Row, Row> {
+    private Counter bqGenericElementCounter;
+    private Long elementsInBundle = 0L;
+
+    ElementCounterFn(String name) {
+      this.bqGenericElementCounter =
+          Metrics.counter(BigQueryStorageWriteApiPCollectionRowTupleTransform.class, name);
+    }
+
+    @ProcessElement
+    public void process(ProcessContext c) {
+      this.elementsInBundle += 1;
+      c.output(c.element());
+    }
+
+    @FinishBundle
+    public void finish(FinishBundleContext c) {
+      this.bqGenericElementCounter.inc(this.elementsInBundle);
+      this.elementsInBundle = 0L;
+    }
+  }
+
   @Override
   public PCollectionRowTuple expand(PCollectionRowTuple input) {
     // Check that the input exists
@@ -241,7 +269,12 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
           : Duration.standardSeconds(triggeringFrequency));
     }
 
-    WriteResult result = inputRows.apply(write);
+    Schema inputSchema = inputRows.getSchema();
+    WriteResult result =
+        inputRows
+            .apply("element-count", ParDo.of(new ElementCounterFn("element-counter")))
+            .setRowSchema(inputSchema)
+            .apply(write);
 
     Schema errorSchema =
         Schema.of(
@@ -263,7 +296,12 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
                         .build()))
             .setRowSchema(errorSchema);
 
-    return PCollectionRowTuple.of(OUTPUT_ERRORS_TAG, errorRows);
+    PCollection<Row> errorOutput =
+        errorRows
+            .apply("error-count", ParDo.of(new ElementCounterFn("error-counter")))
+            .setRowSchema(errorSchema);
+
+    return PCollectionRowTuple.of(OUTPUT_ERRORS_TAG, errorOutput);
   }
 
   BigQueryIO.Write createStorageWriteApiTransform() {
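Both counters are ordinary Beam metrics, so a caller of this SchemaTransform can read them from the PipelineResult after a run. A minimal sketch of that query, mirroring the metric queries in the test diff below; `result` is a hypothetical PipelineResult from a pipeline that applied this transform, and the imports are assumed to match the org.apache.beam.sdk.metrics ones added to the tests:

    // Hypothetical usage: `result` comes from pipeline.run() on a pipeline
    // that applied this transform.
    MetricQueryResults elementCounts =
        result
            .metrics()
            .queryMetrics(
                MetricsFilter.builder()
                    .addNameFilter(
                        MetricNameFilter.named(
                            BigQueryStorageWriteApiPCollectionRowTupleTransform.class,
                            "element-counter"))
                    .build());
    for (MetricResult<Long> counter : elementCounts.getCounters()) {
      System.out.println("rows entering the write: " + counter.getAttempted());
    }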
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java
index c8e733c8458f..c0a6df5f125d 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java
@@ -21,24 +21,36 @@ import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertThrows;
 
+import com.google.api.services.bigquery.model.TableRow;
+import java.io.Serializable;
 import java.time.LocalDateTime;
 import java.util.Arrays;
 import java.util.List;
+import java.util.function.Function;
+import org.apache.beam.sdk.PipelineResult;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers;
+import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils;
 import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiPCollectionRowTupleTransform;
 import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiSchemaTransformConfiguration;
 import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices;
 import org.apache.beam.sdk.io.gcp.testing.FakeDatasetService;
 import org.apache.beam.sdk.io.gcp.testing.FakeJobService;
+import org.apache.beam.sdk.metrics.MetricNameFilter;
+import org.apache.beam.sdk.metrics.MetricQueryResults;
+import org.apache.beam.sdk.metrics.MetricResult;
+import org.apache.beam.sdk.metrics.MetricResults;
+import org.apache.beam.sdk.metrics.MetricsFilter;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.schemas.Schema.Field;
 import org.apache.beam.sdk.schemas.Schema.FieldType;
 import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionRowTuple;
 import org.apache.beam.sdk.values.Row;
+import org.apache.beam.sdk.values.TypeDescriptor;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -60,24 +72,6 @@ public class BigQueryStorageWriteApiSchemaTransformProviderTest {
           Field.of("number", FieldType.INT64),
           Field.of("dt", FieldType.logicalType(SqlTypes.DATETIME)));
 
-  private static final List<Row> ROWS =
-      Arrays.asList(
-          Row.withSchema(SCHEMA)
-              .withFieldValue("name", "a")
-              .withFieldValue("number", 1L)
-              .withFieldValue("dt", LocalDateTime.parse("2000-01-01T00:00:00"))
-              .build(),
-          Row.withSchema(SCHEMA)
-              .withFieldValue("name", "b")
-              .withFieldValue("number", 2L)
-              .withFieldValue("dt", LocalDateTime.parse("2000-01-02T00:00:00"))
-              .build(),
-          Row.withSchema(SCHEMA)
-              .withFieldValue("name", "c")
-              .withFieldValue("number", 3L)
-              .withFieldValue("dt", LocalDateTime.parse("2000-01-03T00:00:00"))
-              .build());
-
   @Rule public final transient TestPipeline p = TestPipeline.create();
 
   @Before
@@ -115,10 +109,28 @@ public PCollectionRowTuple runWithConfig(
             (BigQueryStorageWriteApiPCollectionRowTupleTransform)
                 provider.from(config).buildTransform();
 
+    List<Row> testRows =
+        Arrays.asList(
+            Row.withSchema(SCHEMA)
+                .withFieldValue("name", "a")
+                .withFieldValue("number", 1L)
+                .withFieldValue("dt", LocalDateTime.parse("2000-01-01T00:00:00"))
+                .build(),
+            Row.withSchema(SCHEMA)
+                .withFieldValue("name", "b")
+                .withFieldValue("number", 2L)
+                .withFieldValue("dt", LocalDateTime.parse("2000-01-02T00:00:00"))
+                .build(),
+            Row.withSchema(SCHEMA)
+                .withFieldValue("name", "c")
+                .withFieldValue("number", 3L)
+                .withFieldValue("dt", LocalDateTime.parse("2000-01-03T00:00:00"))
+                .build());
+
     writeRowTupleTransform.setBigQueryServices(fakeBigQueryServices);
     String tag = provider.inputCollectionNames().get(0);
 
-    PCollection<Row> rows = p.apply(Create.of(ROWS).withRowSchema(SCHEMA));
+    PCollection<Row> rows = p.apply(Create.of(testRows).withRowSchema(SCHEMA));
 
     PCollectionRowTuple input = PCollectionRowTuple.of(tag, rows);
     PCollectionRowTuple result = input.apply(writeRowTupleTransform);
@@ -135,8 +147,125 @@ public void testSimpleWrite() throws Exception {
     runWithConfig(config);
     p.run().waitUntilFinish();
 
+    assertNotNull(fakeDatasetService.getTable(BigQueryHelpers.parseTableSpec(tableSpec)));
+    assertEquals(3, fakeDatasetService.getAllRows("project", "dataset", "simple_write").size());
+  }
+
+  @Test
+  public void testInputElementCount() throws Exception {
+    String tableSpec = "project:dataset.input_count";
+    BigQueryStorageWriteApiSchemaTransformConfiguration config =
+        BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(tableSpec).build();
+
+    runWithConfig(config);
+    PipelineResult result = p.run();
+
+    MetricResults metrics = result.metrics();
+    MetricQueryResults metricResults =
+        metrics.queryMetrics(
+            MetricsFilter.builder()
+                .addNameFilter(
+                    MetricNameFilter.named(
+                        BigQueryStorageWriteApiPCollectionRowTupleTransform.class,
+                        "element-counter"))
+                .build());
+
+    Iterable<MetricResult<Long>> counters = metricResults.getCounters();
+    if (!counters.iterator().hasNext()) {
+      throw new RuntimeException("no counters available for the input element count");
+    }
+
+    Long expectedCount = 3L;
+    for (MetricResult<Long> count : counters) {
+      assertEquals(expectedCount, count.getAttempted());
+    }
+  }
+
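+  // Mirrors runWithConfig, but configures FakeDatasetService to reject any row
+  // whose "name" field is "a", so exactly one element reaches the failed-rows
+  // output and is picked up by the "error-counter" metric.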
+  public PCollectionRowTuple runWithError(
+      BigQueryStorageWriteApiSchemaTransformConfiguration config) {
+    BigQueryStorageWriteApiSchemaTransformProvider provider =
+        new BigQueryStorageWriteApiSchemaTransformProvider();
+
+    BigQueryStorageWriteApiPCollectionRowTupleTransform writeRowTupleTransform =
+        (BigQueryStorageWriteApiPCollectionRowTupleTransform)
+            provider.from(config).buildTransform();
+
+    Function<TableRow, Boolean> shouldFailRow =
+        (Function<TableRow, Boolean> & Serializable) tr -> tr.get("name").equals("a");
+    fakeDatasetService.setShouldFailRow(shouldFailRow);
+
+    TableRow row1 =
+        new TableRow()
+            .set("name", "a")
+            .set("number", 1L)
+            .set("dt", LocalDateTime.parse("2000-01-01T00:00:00"));
+    TableRow row2 =
+        new TableRow()
+            .set("name", "b")
+            .set("number", 2L)
+            .set("dt", LocalDateTime.parse("2000-01-02T00:00:00"));
+    TableRow row3 =
+        new TableRow()
+            .set("name", "c")
+            .set("number", 3L)
+            .set("dt", LocalDateTime.parse("2000-01-03T00:00:00"));
+
+    writeRowTupleTransform.setBigQueryServices(fakeBigQueryServices);
+    String tag = provider.inputCollectionNames().get(0);
+
+    PCollection<Row> rows =
+        p.apply(Create.of(row1, row2, row3))
+            .apply(
+                MapElements.into(TypeDescriptor.of(Row.class))
+                    .via((tableRow) -> BigQueryUtils.toBeamRow(SCHEMA, tableRow)))
+            .setRowSchema(SCHEMA);
+
+    PCollectionRowTuple input = PCollectionRowTuple.of(tag, rows);
+    PCollectionRowTuple result = input.apply(writeRowTupleTransform);
+
+    return result;
+  }
+
+  @Test
+  public void testSimpleWriteWithFailure() throws Exception {
+    String tableSpec = "project:dataset.simple_write_with_failure";
+    BigQueryStorageWriteApiSchemaTransformConfiguration config =
+        BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(tableSpec).build();
+
+    runWithError(config);
+    p.run().waitUntilFinish();
+
     assertNotNull(fakeDatasetService.getTable(BigQueryHelpers.parseTableSpec(tableSpec)));
     assertEquals(
-        ROWS.size(), fakeDatasetService.getAllRows("project", "dataset", "simple_write").size());
+        2, fakeDatasetService.getAllRows("project", "dataset", "simple_write_with_failure").size());
+  }
+
+  @Test
+  public void testErrorCount() throws Exception {
+    String tableSpec = "project:dataset.error_count";
+    BigQueryStorageWriteApiSchemaTransformConfiguration config =
+        BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(tableSpec).build();
+
+    runWithError(config);
+    PipelineResult result = p.run();
+
+    MetricResults metrics = result.metrics();
+    MetricQueryResults metricResults =
+        metrics.queryMetrics(
+            MetricsFilter.builder()
+                .addNameFilter(
+                    MetricNameFilter.named(
+                        BigQueryStorageWriteApiPCollectionRowTupleTransform.class, "error-counter"))
+                .build());
+
+    Iterable<MetricResult<Long>> counters = metricResults.getCounters();
+    if (!counters.iterator().hasNext()) {
+      throw new RuntimeException("no counters available for the error count");
+    }
+
+    Long expectedCount = 1L;
+    for (MetricResult<Long> count : counters) {
+      assertEquals(expectedCount, count.getAttempted());
+    }
   }
 }
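The tests above query by metric name alone; when several transforms report metrics under the same namespace, the query can additionally be scoped to the step names used in expand() ("element-count" and "error-count") via MetricsFilter.addStep. A minimal sketch under the same assumptions as the earlier snippet; the tests read getAttempted() rather than getCommitted() because committed-metric support varies by runner:

    // Hypothetical: scope the error counter to the "error-count" ParDo step.
    MetricQueryResults errorCounts =
        result
            .metrics()
            .queryMetrics(
                MetricsFilter.builder()
                    .addStep("error-count")
                    .addNameFilter(
                        MetricNameFilter.named(
                            BigQueryStorageWriteApiPCollectionRowTupleTransform.class,
                            "error-counter"))
                    .build());
    for (MetricResult<Long> counter : errorCounts.getCounters()) {
      System.out.println(counter.getName() + " = " + counter.getAttempted());
    }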