diff --git a/sdks/go/pkg/beam/io/synthetic/step.go b/sdks/go/pkg/beam/io/synthetic/step.go index d800f5a054d38..d654a2bbe34f2 100644 --- a/sdks/go/pkg/beam/io/synthetic/step.go +++ b/sdks/go/pkg/beam/io/synthetic/step.go @@ -49,9 +49,9 @@ func init() { func Step(s beam.Scope, cfg StepConfig, col beam.PCollection) beam.PCollection { s = s.Scope("synthetic.Step") if cfg.Splittable { - return beam.ParDo(s, &sdfStepFn{cfg: cfg}, col) + return beam.ParDo(s, &sdfStepFn{Cfg: cfg}, col) } - return beam.ParDo(s, &stepFn{cfg: cfg}, col) + return beam.ParDo(s, &stepFn{Cfg: cfg}, col) } // stepFn is a DoFn implementing behavior for synthetic steps. For usage @@ -60,7 +60,7 @@ func Step(s beam.Scope, cfg StepConfig, col beam.PCollection) beam.PCollection { // The stepFn is expected to be initialized with a cfg and will follow that // config to determine its behavior when emitting elements. type stepFn struct { - cfg StepConfig + Cfg StepConfig rng randWrapper } @@ -73,9 +73,9 @@ func (fn *stepFn) Setup() { // outputs identical to that input based on the outputs per input configuration // in StepConfig. func (fn *stepFn) ProcessElement(key, val []byte, emit func([]byte, []byte)) { - filtered := fn.cfg.FilterRatio > 0 && fn.rng.Float64() < fn.cfg.FilterRatio + filtered := fn.Cfg.FilterRatio > 0 && fn.rng.Float64() < fn.Cfg.FilterRatio - for i := 0; i < fn.cfg.OutputPerInput; i++ { + for i := 0; i < fn.Cfg.OutputPerInput; i++ { if !filtered { emit(key, val) } @@ -88,7 +88,7 @@ func (fn *stepFn) ProcessElement(key, val []byte, emit func([]byte, []byte)) { // The sdfStepFn is expected to be initialized with a cfg and will follow // that config to determine its behavior when splitting and emitting elements. type sdfStepFn struct { - cfg StepConfig + Cfg StepConfig rng randWrapper } @@ -98,7 +98,7 @@ type sdfStepFn struct { func (fn *sdfStepFn) CreateInitialRestriction(_, _ []byte) offsetrange.Restriction { return offsetrange.Restriction{ Start: 0, - End: int64(fn.cfg.OutputPerInput), + End: int64(fn.Cfg.OutputPerInput), } } @@ -107,7 +107,7 @@ func (fn *sdfStepFn) CreateInitialRestriction(_, _ []byte) offsetrange.Restricti // method will contain at least one element, so the number of splits will not // exceed the number of elements. func (fn *sdfStepFn) SplitRestriction(_, _ []byte, rest offsetrange.Restriction) (splits []offsetrange.Restriction) { - return rest.EvenSplits(int64(fn.cfg.InitialSplits)) + return rest.EvenSplits(int64(fn.Cfg.InitialSplits)) } // RestrictionSize outputs the size of the restriction as the number of elements @@ -130,7 +130,7 @@ func (fn *sdfStepFn) Setup() { // ProcessElement takes an input and either filters it or produces a number of // outputs identical to that input based on the restriction size. func (fn *sdfStepFn) ProcessElement(rt *sdf.LockRTracker, key, val []byte, emit func([]byte, []byte)) { - filtered := fn.cfg.FilterRatio > 0 && fn.rng.Float64() < fn.cfg.FilterRatio + filtered := fn.Cfg.FilterRatio > 0 && fn.rng.Float64() < fn.Cfg.FilterRatio for i := rt.GetRestriction().(offsetrange.Restriction).Start; rt.TryClaim(i); i++ { if !filtered { diff --git a/sdks/go/pkg/beam/io/synthetic/step_test.go b/sdks/go/pkg/beam/io/synthetic/step_test.go index 809f084002012..f4393481567a8 100644 --- a/sdks/go/pkg/beam/io/synthetic/step_test.go +++ b/sdks/go/pkg/beam/io/synthetic/step_test.go @@ -37,7 +37,7 @@ func TestStepConfig_OutputPerInput(t *testing.T) { cfg := DefaultStepConfig().OutputPerInput(test.outPer).Build() // Non-splittable StepFn. - dfn := stepFn{cfg: cfg} + dfn := stepFn{Cfg: cfg} var keys [][]byte emitFn := func(key []byte, val []byte) { keys = append(keys, key) @@ -52,7 +52,7 @@ func TestStepConfig_OutputPerInput(t *testing.T) { // SDF StepFn. cfg = DefaultStepConfig().OutputPerInput(test.outPer).Splittable(true).Build() - sdf := sdfStepFn{cfg: cfg} + sdf := sdfStepFn{Cfg: cfg} keys, _ = simulateSdfStepFn(t, &sdf) if got := len(keys); got != test.outPer { t.Errorf("sdfStepFn emitted wrong number of outputs: got: %v, want: %v", @@ -97,7 +97,7 @@ func TestStepConfig_FilterRatio(t *testing.T) { // Non-splittable StepFn. cfg := DefaultStepConfig().FilterRatio(test.ratio).Build() - dfn := stepFn{cfg: cfg} + dfn := stepFn{Cfg: cfg} dfn.Setup() dfn.rng = &fakeRand{f64: test.rand} dfn.ProcessElement(elm, elm, emitFn) @@ -110,7 +110,7 @@ func TestStepConfig_FilterRatio(t *testing.T) { // SDF StepFn. cfg = DefaultStepConfig().FilterRatio(test.ratio).Splittable(true).Build() - sdf := sdfStepFn{cfg: cfg} + sdf := sdfStepFn{Cfg: cfg} keys = nil rest := sdf.CreateInitialRestriction(elm, elm) splits := sdf.SplitRestriction(elm, elm, rest) @@ -154,7 +154,7 @@ func TestStepConfig_InitialSplits(t *testing.T) { Build() elm := []byte{0, 0, 0, 0} - sdf := sdfStepFn{cfg: cfg} + sdf := sdfStepFn{Cfg: cfg} rest := sdf.CreateInitialRestriction(elm, elm) splits := sdf.SplitRestriction(elm, elm, rest) if got := len(splits); got != test.want { @@ -186,7 +186,7 @@ func TestStepConfig_InitialSplits(t *testing.T) { InitialSplits(test.splits). Build() - sdf := sdfStepFn{cfg: cfg} + sdf := sdfStepFn{Cfg: cfg} keys, _ := simulateSdfStepFn(t, &sdf) if got := len(keys); got != test.want { t.Errorf("SourceFn emitted wrong number of outputs: got: %v, want: %v", diff --git a/sdks/go/pkg/beam/testing/passert/count.go b/sdks/go/pkg/beam/testing/passert/count.go index 6bc5ad89e87a4..12e8981d04b6c 100644 --- a/sdks/go/pkg/beam/testing/passert/count.go +++ b/sdks/go/pkg/beam/testing/passert/count.go @@ -30,6 +30,10 @@ func Count(s beam.Scope, col beam.PCollection, name string, count int) { if typex.IsKV(col.Type()) { col = beam.DropKey(s, col) } + + if count > 0 { + NonEmpty(s, col) + } counted := beam.Combine(s, &elmCountCombineFn{}, col) beam.ParDo0(s, &errFn{Name: name, Count: count}, counted) } diff --git a/sdks/go/pkg/beam/testing/passert/count_test.go b/sdks/go/pkg/beam/testing/passert/count_test.go index 36f6c2aeeebe2..c34294998509a 100644 --- a/sdks/go/pkg/beam/testing/passert/count_test.go +++ b/sdks/go/pkg/beam/testing/passert/count_test.go @@ -22,24 +22,62 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest" ) -func TestCount_Good(t *testing.T) { - p, s := beam.NewPipelineWithRoot() - col := beam.Create(s, "a", "b", "c", "d", "e") - count := 5 +func TestCount(t *testing.T) { + var tests = []struct { + name string + elements []string + count int + }{ + { + "full", + []string{"a", "b", "c", "d", "e"}, + 5, + }, + { + "empty", + []string{}, + 0, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + p, s := beam.NewPipelineWithRoot() + col := beam.CreateList(s, test.elements) - Count(s, col, "TestCount_Good", count) - if err := ptest.Run(p); err != nil { - t.Errorf("Pipeline failed: %v", err) + Count(s, col, test.name, test.count) + if err := ptest.Run(p); err != nil { + t.Errorf("Pipeline failed: %v", err) + } + }) } } func TestCount_Bad(t *testing.T) { - p, s := beam.NewPipelineWithRoot() - col := beam.Create(s, "a", "b", "c", "d", "e") - count := 10 + var tests = []struct { + name string + elements []string + count int + }{ + { + "mismatch", + []string{"a", "b", "c", "d", "e"}, + 10, + }, + { + "empty pcollection", + []string{}, + 5, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + p, s := beam.NewPipelineWithRoot() + col := beam.CreateList(s, test.elements) - Count(s, col, "TestCount_Bad", count) - if err := ptest.Run(p); err == nil { - t.Errorf("pipeline SUCCEEDED but should have failed") + Count(s, col, test.name, test.count) + if err := ptest.Run(p); err == nil { + t.Errorf("pipeline SUCCEEDED but should have failed") + } + }) } } diff --git a/sdks/go/pkg/beam/testing/passert/hash.go b/sdks/go/pkg/beam/testing/passert/hash.go index 678c86dc59c0f..5a562a3356e10 100644 --- a/sdks/go/pkg/beam/testing/passert/hash.go +++ b/sdks/go/pkg/beam/testing/passert/hash.go @@ -31,6 +31,7 @@ import ( func Hash(s beam.Scope, col beam.PCollection, name, hash string, size int) { s = s.Scope(fmt.Sprintf("passert.Hash(%v)", name)) + NonEmpty(s, col) keyed := beam.AddFixedKey(s, col) grouped := beam.GroupByKey(s, keyed) beam.ParDo0(s, &hashFn{Name: name, Size: size, Hash: hash}, grouped) diff --git a/sdks/go/pkg/beam/testing/passert/hash_test.go b/sdks/go/pkg/beam/testing/passert/hash_test.go new file mode 100644 index 0000000000000..177fd37e3ad5f --- /dev/null +++ b/sdks/go/pkg/beam/testing/passert/hash_test.go @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package passert + +import ( + "testing" + + "github.com/apache/beam/sdks/v2/go/pkg/beam" + "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest" +) + +func TestHash_Bad(t *testing.T) { + p, s := beam.NewPipelineWithRoot() + col := beam.CreateList(s, []string{}) + + Hash(s, col, "empty collection", "", 0) + if err := ptest.Run(p); err == nil { + t.Errorf("pipeline SUCCEEDED but should have failed") + } +} diff --git a/sdks/go/pkg/beam/testing/passert/passert.go b/sdks/go/pkg/beam/testing/passert/passert.go index 6530c67e2d6ab..c3dceaadb763d 100644 --- a/sdks/go/pkg/beam/testing/passert/passert.go +++ b/sdks/go/pkg/beam/testing/passert/passert.go @@ -29,7 +29,7 @@ import ( ) //go:generate go install github.com/apache/beam/sdks/v2/go/cmd/starcgen -//go:generate starcgen --package=passert --identifiers=diffFn,failFn,failIfBadEntries,failKVFn,failGBKFn,hashFn,sumFn,errFn,elmCountCombineFn +//go:generate starcgen --package=passert --identifiers=diffFn,failFn,failIfBadEntries,failKVFn,failGBKFn,hashFn,sumFn,errFn,elmCountCombineFn,nonEmptyFn //go:generate go fmt // Diff splits 2 incoming PCollections into 3: left only, both, right only. Duplicates are @@ -179,3 +179,20 @@ type failGBKFn struct { func (f *failGBKFn) ProcessElement(x beam.X, _ func(*beam.Y) bool) error { return errors.Errorf(f.Format, fmt.Sprintf("(%v,*)", x)) } + +type nonEmptyFn struct{} + +func (n *nonEmptyFn) ProcessElement(_ []byte, iter func(*beam.Z) bool) error { + var val beam.Z + for iter(&val) { + return nil + } + return errors.New("PCollection is empty, want non-empty collection") +} + +// NonEmpty asserts that the given PCollection has at least one element. +func NonEmpty(s beam.Scope, col beam.PCollection) beam.PCollection { + s = s.Scope("passert.NonEmpty") + beam.ParDo0(s, &nonEmptyFn{}, beam.Impulse(s), beam.SideInput{Input: col}) + return col +} diff --git a/sdks/go/pkg/beam/testing/passert/passert.shims.go b/sdks/go/pkg/beam/testing/passert/passert.shims.go index 795268d275b45..c9650d61280b0 100644 --- a/sdks/go/pkg/beam/testing/passert/passert.shims.go +++ b/sdks/go/pkg/beam/testing/passert/passert.shims.go @@ -49,6 +49,8 @@ func init() { schema.RegisterType(reflect.TypeOf((*failKVFn)(nil)).Elem()) runtime.RegisterType(reflect.TypeOf((*hashFn)(nil)).Elem()) schema.RegisterType(reflect.TypeOf((*hashFn)(nil)).Elem()) + runtime.RegisterType(reflect.TypeOf((*nonEmptyFn)(nil)).Elem()) + schema.RegisterType(reflect.TypeOf((*nonEmptyFn)(nil)).Elem()) runtime.RegisterType(reflect.TypeOf((*sumFn)(nil)).Elem()) schema.RegisterType(reflect.TypeOf((*sumFn)(nil)).Elem()) reflectx.RegisterStructWrapper(reflect.TypeOf((*diffFn)(nil)).Elem(), wrapMakerDiffFn) @@ -58,6 +60,7 @@ func init() { reflectx.RegisterStructWrapper(reflect.TypeOf((*failGBKFn)(nil)).Elem(), wrapMakerFailGBKFn) reflectx.RegisterStructWrapper(reflect.TypeOf((*failKVFn)(nil)).Elem(), wrapMakerFailKVFn) reflectx.RegisterStructWrapper(reflect.TypeOf((*hashFn)(nil)).Elem(), wrapMakerHashFn) + reflectx.RegisterStructWrapper(reflect.TypeOf((*nonEmptyFn)(nil)).Elem(), wrapMakerNonEmptyFn) reflectx.RegisterStructWrapper(reflect.TypeOf((*sumFn)(nil)).Elem(), wrapMakerSumFn) reflectx.RegisterFunc(reflect.TypeOf((*func(int, int) int)(nil)).Elem(), funcMakerIntIntГInt) reflectx.RegisterFunc(reflect.TypeOf((*func(int, func(*int) bool) error)(nil)).Elem(), funcMakerIntIterIntГError) @@ -67,6 +70,7 @@ func init() { reflectx.RegisterFunc(reflect.TypeOf((*func(int) int)(nil)).Elem(), funcMakerIntГInt) reflectx.RegisterFunc(reflect.TypeOf((*func([]byte, func(*typex.T) bool, func(*typex.T) bool, func(t typex.T), func(t typex.T), func(t typex.T)) error)(nil)).Elem(), funcMakerSliceOfByteIterTypex۰TIterTypex۰TEmitTypex۰TEmitTypex۰TEmitTypex۰TГError) reflectx.RegisterFunc(reflect.TypeOf((*func([]byte, func(*typex.T) bool, func(*typex.T) bool, func(*typex.T) bool) error)(nil)).Elem(), funcMakerSliceOfByteIterTypex۰TIterTypex۰TIterTypex۰TГError) + reflectx.RegisterFunc(reflect.TypeOf((*func([]byte, func(*typex.Z) bool) error)(nil)).Elem(), funcMakerSliceOfByteIterTypex۰ZГError) reflectx.RegisterFunc(reflect.TypeOf((*func(typex.X, func(*typex.Y) bool) error)(nil)).Elem(), funcMakerTypex۰XIterTypex۰YГError) reflectx.RegisterFunc(reflect.TypeOf((*func(typex.X, typex.Y) error)(nil)).Elem(), funcMakerTypex۰XTypex۰YГError) reflectx.RegisterFunc(reflect.TypeOf((*func(typex.X) error)(nil)).Elem(), funcMakerTypex۰XГError) @@ -76,6 +80,7 @@ func init() { exec.RegisterInput(reflect.TypeOf((*func(*string) bool)(nil)).Elem(), iterMakerString) exec.RegisterInput(reflect.TypeOf((*func(*typex.T) bool)(nil)).Elem(), iterMakerTypex۰T) exec.RegisterInput(reflect.TypeOf((*func(*typex.Y) bool)(nil)).Elem(), iterMakerTypex۰Y) + exec.RegisterInput(reflect.TypeOf((*func(*typex.Z) bool)(nil)).Elem(), iterMakerTypex۰Z) } func wrapMakerDiffFn(fn interface{}) map[string]reflectx.Func { @@ -132,6 +137,13 @@ func wrapMakerHashFn(fn interface{}) map[string]reflectx.Func { } } +func wrapMakerNonEmptyFn(fn interface{}) map[string]reflectx.Func { + dfn := fn.(*nonEmptyFn) + return map[string]reflectx.Func{ + "ProcessElement": reflectx.MakeFunc(func(a0 []byte, a1 func(*typex.Z) bool) error { return dfn.ProcessElement(a0, a1) }), + } +} + func wrapMakerSumFn(fn interface{}) map[string]reflectx.Func { dfn := fn.(*sumFn) return map[string]reflectx.Func{ @@ -347,6 +359,32 @@ func (c *callerSliceOfByteIterTypex۰TIterTypex۰TIterTypex۰TГError) Call4x1(a return c.fn(arg0.([]byte), arg1.(func(*typex.T) bool), arg2.(func(*typex.T) bool), arg3.(func(*typex.T) bool)) } +type callerSliceOfByteIterTypex۰ZГError struct { + fn func([]byte, func(*typex.Z) bool) error +} + +func funcMakerSliceOfByteIterTypex۰ZГError(fn interface{}) reflectx.Func { + f := fn.(func([]byte, func(*typex.Z) bool) error) + return &callerSliceOfByteIterTypex۰ZГError{fn: f} +} + +func (c *callerSliceOfByteIterTypex۰ZГError) Name() string { + return reflectx.FunctionName(c.fn) +} + +func (c *callerSliceOfByteIterTypex۰ZГError) Type() reflect.Type { + return reflect.TypeOf(c.fn) +} + +func (c *callerSliceOfByteIterTypex۰ZГError) Call(args []interface{}) []interface{} { + out0 := c.fn(args[0].([]byte), args[1].(func(*typex.Z) bool)) + return []interface{}{out0} +} + +func (c *callerSliceOfByteIterTypex۰ZГError) Call2x1(arg0, arg1 interface{}) interface{} { + return c.fn(arg0.([]byte), arg1.(func(*typex.Z) bool)) +} + type callerTypex۰XIterTypex۰YГError struct { fn func(typex.X, func(*typex.Y) bool) error } @@ -594,4 +632,22 @@ func (v *iterNative) readTypex۰Y(value *typex.Y) bool { return true } +func iterMakerTypex۰Z(s exec.ReStream) exec.ReusableInput { + ret := &iterNative{s: s} + ret.fn = ret.readTypex۰Z + return ret +} + +func (v *iterNative) readTypex۰Z(value *typex.Z) bool { + elm, err := v.cur.Read() + if err != nil { + if err == io.EOF { + return false + } + panic(fmt.Sprintf("broken stream: %v", err)) + } + *value = elm.Elm.(typex.Z) + return true +} + // DO NOT MODIFY: GENERATED CODE diff --git a/sdks/go/pkg/beam/testing/passert/passert_test.go b/sdks/go/pkg/beam/testing/passert/passert_test.go index 221b2f399c6c1..9524bc868ebb9 100644 --- a/sdks/go/pkg/beam/testing/passert/passert_test.go +++ b/sdks/go/pkg/beam/testing/passert/passert_test.go @@ -118,3 +118,21 @@ func TestEmpty_bad(t *testing.T) { t.Errorf("Pipeline failed but did not produce the expected error, got %v", err) } } + +func TestNonEmpty(t *testing.T) { + p, s := beam.NewPipelineWithRoot() + col := beam.CreateList(s, []string{"a", "b", "c"}) + NonEmpty(s, col) + if err := ptest.Run(p); err != nil { + t.Errorf("Pipeline failed: %v", err) + } +} + +func TestNonEmpty_Bad(t *testing.T) { + p, s := beam.NewPipelineWithRoot() + col := beam.CreateList(s, []string{}) + NonEmpty(s, col) + if err := ptest.Run(p); err == nil { + t.Error("Pipeline succeeded when it should have failed.") + } +} diff --git a/sdks/go/pkg/beam/testing/passert/sum.go b/sdks/go/pkg/beam/testing/passert/sum.go index 5e89ba5ed4c54..8a48c9f4dc838 100644 --- a/sdks/go/pkg/beam/testing/passert/sum.go +++ b/sdks/go/pkg/beam/testing/passert/sum.go @@ -27,7 +27,9 @@ import ( // that avoids a lot of machinery for testing. func Sum(s beam.Scope, col beam.PCollection, name string, size, value int) { s = s.Scope(fmt.Sprintf("passert.Sum(%v)", name)) - + if size > 0 { + NonEmpty(s, col) + } keyed := beam.AddFixedKey(s, col) grouped := beam.GroupByKey(s, keyed) beam.ParDo0(s, &sumFn{Name: name, Size: size, Sum: value}, grouped) diff --git a/sdks/go/pkg/beam/testing/passert/sum_test.go b/sdks/go/pkg/beam/testing/passert/sum_test.go index 1b6f391642a98..e63616dccfe29 100644 --- a/sdks/go/pkg/beam/testing/passert/sum_test.go +++ b/sdks/go/pkg/beam/testing/passert/sum_test.go @@ -87,6 +87,13 @@ func TestSum_bad(t *testing.T) { 16, []string{"{15, size: 5}", "want {16, size:5}"}, }, + { + "empty", + []int{}, + 1, + 1, + []string{"PCollection is empty, want non-empty collection"}, + }, } for _, tc := range tests { p, s := beam.NewPipelineWithRoot()