diff --git a/sdks/go/examples/snippets/04transforms.go b/sdks/go/examples/snippets/04transforms.go index e0ff23351135..bb3abe3cf83a 100644 --- a/sdks/go/examples/snippets/04transforms.go +++ b/sdks/go/examples/snippets/04transforms.go @@ -65,17 +65,17 @@ func applyWordLen(s beam.Scope, words beam.PCollection) beam.PCollection { return wordLengths } +// [START model_pardo_apply_anon] + +func wordLengths(word string) int { return len(word) } +func init() { register.Function1x1(wordLengths) } + func applyWordLenAnon(s beam.Scope, words beam.PCollection) beam.PCollection { - // [START model_pardo_apply_anon] - // Apply an anonymous function as a DoFn PCollection words. - // Save the result as the PCollection wordLengths. - wordLengths := beam.ParDo(s, func(word string) int { - return len(word) - }, words) - // [END model_pardo_apply_anon] - return wordLengths + return beam.ParDo(s, wordLengths, words) } +// [END model_pardo_apply_anon] + func applyGbk(s beam.Scope, input []stringPair) beam.PCollection { // [START groupbykey] // CreateAndSplit creates and returns a PCollection with @@ -345,26 +345,29 @@ func globallyAverage(s beam.Scope, ints beam.PCollection) beam.PCollection { return average } +// [START combine_global_with_default] + +func returnSideOrDefault(d float64, iter func(*float64) bool) float64 { + var c float64 + if iter(&c) { + // Side input has a value, so return it. + return c + } + // Otherwise, return the default + return d +} +func init() { register.Function2x1(returnSideOrDefault) } + func globallyAverageWithDefault(s beam.Scope, ints beam.PCollection) beam.PCollection { - // [START combine_global_with_default] // Setting combine defaults has requires no helper function in the Go SDK. average := beam.Combine(s, &averageFn{}, ints) // To add a default value: defaultValue := beam.Create(s, float64(0)) - avgWithDefault := beam.ParDo(s, func(d float64, iter func(*float64) bool) float64 { - var c float64 - if iter(&c) { - // Side input has a value, so return it. - return c - } - // Otherwise, return the default - return d - }, defaultValue, beam.SideInput{Input: average}) - // [END combine_global_with_default] - return avgWithDefault + return beam.ParDo(s, returnSideOrDefault, defaultValue, beam.SideInput{Input: average}) } +// [END combine_global_with_default] func perKeyAverage(s beam.Scope, playerAccuracies beam.PCollection) beam.PCollection { // [START combine_per_key] avgAccuracyPerPlayer := stats.MeanPerKey(s, playerAccuracies) diff --git a/sdks/go/examples/snippets/04transforms_test.go b/sdks/go/examples/snippets/04transforms_test.go index 8d888e028562..509da6d5065a 100644 --- a/sdks/go/examples/snippets/04transforms_test.go +++ b/sdks/go/examples/snippets/04transforms_test.go @@ -19,6 +19,7 @@ import ( "testing" "github.com/apache/beam/sdks/v2/go/pkg/beam" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest" ) @@ -205,6 +206,14 @@ func TestSideInputs(t *testing.T) { ptest.RunAndValidate(t, p) } +func emitOnTestKey(k string, v int, emit func(int)) { + if k == "test" { + emit(v) + } +} + +func init() { register.Function3x0(emitOnTestKey) } + func TestComposite(t *testing.T) { p, s, lines := ptest.CreateList([]string{ "this test dataset has the word test", @@ -215,11 +224,7 @@ func TestComposite(t *testing.T) { // A Composite PTransform function is called like any other function. wordCounts := CountWords(s, lines) // returns a PCollection> // [END countwords_composite_call] - testCount := beam.ParDo(s, func(k string, v int, emit func(int)) { - if k == "test" { - emit(v) - } - }, wordCounts) + testCount := beam.ParDo(s, emitOnTestKey, wordCounts) passert.Equals(s, testCount, 4) ptest.RunAndValidate(t, p) } diff --git a/sdks/go/pkg/beam/testing/ptest/ptest.go b/sdks/go/pkg/beam/testing/ptest/ptest.go index 8e8412fa3d88..0bca8c48ceb6 100644 --- a/sdks/go/pkg/beam/testing/ptest/ptest.go +++ b/sdks/go/pkg/beam/testing/ptest/ptest.go @@ -60,27 +60,31 @@ func CreateList2(a, b any) (*beam.Pipeline, beam.Scope, beam.PCollection, beam.P return p, s, beam.CreateList(s, a), beam.CreateList(s, b) } +const ( + defaultRunner = "prism" +) + // Runner is a flag that sets which runner pipelines under test will use. // // The test file must have a TestMain that calls Main or MainWithDefault // to function. var ( - Runner = runners.Runner - defaultRunner = "prism" - mainCalled = false + Runner = runners.Runner + defaultRunnerOverride = defaultRunner + mainCalled = false ) func getRunner() string { r := *Runner if r == "" { - r = defaultRunner + r = defaultRunnerOverride } return r } // DefaultRunner returns the default runner name for the test file. func DefaultRunner() string { - return defaultRunner + return defaultRunnerOverride } // MainCalled returns true iff Main or MainRet has been called. @@ -133,7 +137,7 @@ func BuildAndRun(t *testing.T, build func(s beam.Scope)) beam.PipelineResult { // ptest.Main(m) // } func Main(m *testing.M) { - MainWithDefault(m, "direct") + MainWithDefault(m, defaultRunner) } // MainWithDefault is an implementation of testing's TestMain to permit testing @@ -141,7 +145,7 @@ func Main(m *testing.M) { // runner to use. func MainWithDefault(m *testing.M, runner string) { mainCalled = true - defaultRunner = runner + defaultRunnerOverride = runner if !flag.Parsed() { flag.Parse() } @@ -149,7 +153,7 @@ func MainWithDefault(m *testing.M, runner string) { os.Exit(m.Run()) } -// MainRet is equivelant to Main, but returns an exit code to pass to os.Exit(). +// MainRet is equivalent to Main, but returns an exit code to pass to os.Exit(). // // Example: // @@ -157,14 +161,14 @@ func MainWithDefault(m *testing.M, runner string) { // os.Exit(ptest.Main(m)) // } func MainRet(m *testing.M) int { - return MainRetWithDefault(m, "direct") + return MainRetWithDefault(m, defaultRunner) } -// MainRetWithDefault is equivelant to MainWithDefault but returns an exit code +// MainRetWithDefault is equivalent to MainWithDefault but returns an exit code // to pass to os.Exit(). func MainRetWithDefault(m *testing.M, runner string) int { mainCalled = true - defaultRunner = runner + defaultRunnerOverride = runner if !flag.Parsed() { flag.Parse() } diff --git a/sdks/go/pkg/beam/transforms/filter/filter_test.go b/sdks/go/pkg/beam/transforms/filter/filter_test.go index 96b4cbe12d79..ffa138e099a6 100644 --- a/sdks/go/pkg/beam/transforms/filter/filter_test.go +++ b/sdks/go/pkg/beam/transforms/filter/filter_test.go @@ -48,17 +48,17 @@ func TestInclude(t *testing.T) { }{ { []int{1, 2, 3}, - func(a int) bool { return true }, + alwaysTrue, []int{1, 2, 3}, }, { []int{1, 2, 3}, - func(a int) bool { return a == 1 }, + isOne, []int{1}, }, { []int{1, 2, 3}, - func(a int) bool { return a > 1 }, + greaterThanOne, []int{2, 3}, }, } @@ -81,17 +81,17 @@ func TestExclude(t *testing.T) { }{ { []int{1, 2, 3}, - func(a int) bool { return false }, + alwaysFalse, []int{1, 2, 3}, }, { []int{1, 2, 3}, - func(a int) bool { return a == 1 }, + isOne, []int{2, 3}, }, { []int{1, 2, 3}, - func(a int) bool { return a > 1 }, + greaterThanOne, []int{1}, }, }