Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BEAM-14546] Fix errant pass for empty collections in Count #17813

Merged
merged 6 commits into from
Jun 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions sdks/go/pkg/beam/io/synthetic/step.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ func init() {
func Step(s beam.Scope, cfg StepConfig, col beam.PCollection) beam.PCollection {
s = s.Scope("synthetic.Step")
if cfg.Splittable {
return beam.ParDo(s, &sdfStepFn{cfg: cfg}, col)
return beam.ParDo(s, &sdfStepFn{Cfg: cfg}, col)
}
return beam.ParDo(s, &stepFn{cfg: cfg}, col)
return beam.ParDo(s, &stepFn{Cfg: cfg}, col)
}

// stepFn is a DoFn implementing behavior for synthetic steps. For usage
Expand All @@ -60,7 +60,7 @@ func Step(s beam.Scope, cfg StepConfig, col beam.PCollection) beam.PCollection {
// The stepFn is expected to be initialized with a cfg and will follow that
// config to determine its behavior when emitting elements.
type stepFn struct {
cfg StepConfig
Cfg StepConfig
rng randWrapper
}

Expand All @@ -73,9 +73,9 @@ func (fn *stepFn) Setup() {
// outputs identical to that input based on the outputs per input configuration
// in StepConfig.
func (fn *stepFn) ProcessElement(key, val []byte, emit func([]byte, []byte)) {
filtered := fn.cfg.FilterRatio > 0 && fn.rng.Float64() < fn.cfg.FilterRatio
filtered := fn.Cfg.FilterRatio > 0 && fn.rng.Float64() < fn.Cfg.FilterRatio

for i := 0; i < fn.cfg.OutputPerInput; i++ {
for i := 0; i < fn.Cfg.OutputPerInput; i++ {
if !filtered {
emit(key, val)
}
Expand All @@ -88,7 +88,7 @@ func (fn *stepFn) ProcessElement(key, val []byte, emit func([]byte, []byte)) {
// The sdfStepFn is expected to be initialized with a cfg and will follow
// that config to determine its behavior when splitting and emitting elements.
type sdfStepFn struct {
cfg StepConfig
Cfg StepConfig
rng randWrapper
}

Expand All @@ -98,7 +98,7 @@ type sdfStepFn struct {
func (fn *sdfStepFn) CreateInitialRestriction(_, _ []byte) offsetrange.Restriction {
return offsetrange.Restriction{
Start: 0,
End: int64(fn.cfg.OutputPerInput),
End: int64(fn.Cfg.OutputPerInput),
}
}

Expand All @@ -107,7 +107,7 @@ func (fn *sdfStepFn) CreateInitialRestriction(_, _ []byte) offsetrange.Restricti
// method will contain at least one element, so the number of splits will not
// exceed the number of elements.
func (fn *sdfStepFn) SplitRestriction(_, _ []byte, rest offsetrange.Restriction) (splits []offsetrange.Restriction) {
return rest.EvenSplits(int64(fn.cfg.InitialSplits))
return rest.EvenSplits(int64(fn.Cfg.InitialSplits))
}

// RestrictionSize outputs the size of the restriction as the number of elements
Expand All @@ -130,7 +130,7 @@ func (fn *sdfStepFn) Setup() {
// ProcessElement takes an input and either filters it or produces a number of
// outputs identical to that input based on the restriction size.
func (fn *sdfStepFn) ProcessElement(rt *sdf.LockRTracker, key, val []byte, emit func([]byte, []byte)) {
filtered := fn.cfg.FilterRatio > 0 && fn.rng.Float64() < fn.cfg.FilterRatio
filtered := fn.Cfg.FilterRatio > 0 && fn.rng.Float64() < fn.Cfg.FilterRatio

for i := rt.GetRestriction().(offsetrange.Restriction).Start; rt.TryClaim(i); i++ {
if !filtered {
Expand Down
12 changes: 6 additions & 6 deletions sdks/go/pkg/beam/io/synthetic/step_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func TestStepConfig_OutputPerInput(t *testing.T) {
cfg := DefaultStepConfig().OutputPerInput(test.outPer).Build()

// Non-splittable StepFn.
dfn := stepFn{cfg: cfg}
dfn := stepFn{Cfg: cfg}
var keys [][]byte
emitFn := func(key []byte, val []byte) {
keys = append(keys, key)
Expand All @@ -52,7 +52,7 @@ func TestStepConfig_OutputPerInput(t *testing.T) {

// SDF StepFn.
cfg = DefaultStepConfig().OutputPerInput(test.outPer).Splittable(true).Build()
sdf := sdfStepFn{cfg: cfg}
sdf := sdfStepFn{Cfg: cfg}
keys, _ = simulateSdfStepFn(t, &sdf)
if got := len(keys); got != test.outPer {
t.Errorf("sdfStepFn emitted wrong number of outputs: got: %v, want: %v",
Expand Down Expand Up @@ -97,7 +97,7 @@ func TestStepConfig_FilterRatio(t *testing.T) {

// Non-splittable StepFn.
cfg := DefaultStepConfig().FilterRatio(test.ratio).Build()
dfn := stepFn{cfg: cfg}
dfn := stepFn{Cfg: cfg}
dfn.Setup()
dfn.rng = &fakeRand{f64: test.rand}
dfn.ProcessElement(elm, elm, emitFn)
Expand All @@ -110,7 +110,7 @@ func TestStepConfig_FilterRatio(t *testing.T) {

// SDF StepFn.
cfg = DefaultStepConfig().FilterRatio(test.ratio).Splittable(true).Build()
sdf := sdfStepFn{cfg: cfg}
sdf := sdfStepFn{Cfg: cfg}
keys = nil
rest := sdf.CreateInitialRestriction(elm, elm)
splits := sdf.SplitRestriction(elm, elm, rest)
Expand Down Expand Up @@ -154,7 +154,7 @@ func TestStepConfig_InitialSplits(t *testing.T) {
Build()
elm := []byte{0, 0, 0, 0}

sdf := sdfStepFn{cfg: cfg}
sdf := sdfStepFn{Cfg: cfg}
rest := sdf.CreateInitialRestriction(elm, elm)
splits := sdf.SplitRestriction(elm, elm, rest)
if got := len(splits); got != test.want {
Expand Down Expand Up @@ -186,7 +186,7 @@ func TestStepConfig_InitialSplits(t *testing.T) {
InitialSplits(test.splits).
Build()

sdf := sdfStepFn{cfg: cfg}
sdf := sdfStepFn{Cfg: cfg}
keys, _ := simulateSdfStepFn(t, &sdf)
if got := len(keys); got != test.want {
t.Errorf("SourceFn emitted wrong number of outputs: got: %v, want: %v",
Expand Down
4 changes: 4 additions & 0 deletions sdks/go/pkg/beam/testing/passert/count.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ func Count(s beam.Scope, col beam.PCollection, name string, count int) {
if typex.IsKV(col.Type()) {
col = beam.DropKey(s, col)
}

if count > 0 {
NonEmpty(s, col)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to add the same thing to Hash and Sum? I think an empty pcollection would silently pass for both of those as well

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sum did need it, confirmed via unit test. Will add some short validation for Hash

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a small "fail on empty" check for Hash.

}
counted := beam.Combine(s, &elmCountCombineFn{}, col)
beam.ParDo0(s, &errFn{Name: name, Count: count}, counted)
}
Expand Down
64 changes: 51 additions & 13 deletions sdks/go/pkg/beam/testing/passert/count_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,62 @@ import (
"github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest"
)

func TestCount_Good(t *testing.T) {
p, s := beam.NewPipelineWithRoot()
col := beam.Create(s, "a", "b", "c", "d", "e")
count := 5
func TestCount(t *testing.T) {
var tests = []struct {
name string
elements []string
count int
}{
{
"full",
[]string{"a", "b", "c", "d", "e"},
5,
},
{
"empty",
[]string{},
0,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
p, s := beam.NewPipelineWithRoot()
col := beam.CreateList(s, test.elements)

Count(s, col, "TestCount_Good", count)
if err := ptest.Run(p); err != nil {
t.Errorf("Pipeline failed: %v", err)
Count(s, col, test.name, test.count)
if err := ptest.Run(p); err != nil {
t.Errorf("Pipeline failed: %v", err)
}
})
}
}

func TestCount_Bad(t *testing.T) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we simplify and combine these tests into 1 by adding a expectErr variable to the tests struct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could, although I don't see too much value in bundling them into one suite. The setup isn't particularly long or complicated to test the function, so deduplicating it doesn't add much value IMO. Totally cool doing it if you feel strongly though.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd argue its worth doing since more/duplicated code => more opportunities for bugs to slip in when updates are needed and more for a future developer (maybe us) to understand (code is a liability). I'm not going to block on it though, its not very important

p, s := beam.NewPipelineWithRoot()
col := beam.Create(s, "a", "b", "c", "d", "e")
count := 10
var tests = []struct {
name string
elements []string
count int
}{
{
"mismatch",
[]string{"a", "b", "c", "d", "e"},
10,
},
{
"empty pcollection",
[]string{},
5,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
p, s := beam.NewPipelineWithRoot()
col := beam.CreateList(s, test.elements)

Count(s, col, "TestCount_Bad", count)
if err := ptest.Run(p); err == nil {
t.Errorf("pipeline SUCCEEDED but should have failed")
Count(s, col, test.name, test.count)
if err := ptest.Run(p); err == nil {
t.Errorf("pipeline SUCCEEDED but should have failed")
}
})
}
}
1 change: 1 addition & 0 deletions sdks/go/pkg/beam/testing/passert/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
func Hash(s beam.Scope, col beam.PCollection, name, hash string, size int) {
s = s.Scope(fmt.Sprintf("passert.Hash(%v)", name))

NonEmpty(s, col)
keyed := beam.AddFixedKey(s, col)
grouped := beam.GroupByKey(s, keyed)
beam.ParDo0(s, &hashFn{Name: name, Size: size, Hash: hash}, grouped)
Expand Down
33 changes: 33 additions & 0 deletions sdks/go/pkg/beam/testing/passert/hash_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package passert

import (
"testing"

"github.com/apache/beam/sdks/v2/go/pkg/beam"
"github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest"
)

func TestHash_Bad(t *testing.T) {
p, s := beam.NewPipelineWithRoot()
col := beam.CreateList(s, []string{})

Hash(s, col, "empty collection", "", 0)
if err := ptest.Run(p); err == nil {
t.Errorf("pipeline SUCCEEDED but should have failed")
}
}
19 changes: 18 additions & 1 deletion sdks/go/pkg/beam/testing/passert/passert.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
)

//go:generate go install github.com/apache/beam/sdks/v2/go/cmd/starcgen
//go:generate starcgen --package=passert --identifiers=diffFn,failFn,failIfBadEntries,failKVFn,failGBKFn,hashFn,sumFn,errFn,elmCountCombineFn
//go:generate starcgen --package=passert --identifiers=diffFn,failFn,failIfBadEntries,failKVFn,failGBKFn,hashFn,sumFn,errFn,elmCountCombineFn,nonEmptyFn
//go:generate go fmt

// Diff splits 2 incoming PCollections into 3: left only, both, right only. Duplicates are
Expand Down Expand Up @@ -179,3 +179,20 @@ type failGBKFn struct {
func (f *failGBKFn) ProcessElement(x beam.X, _ func(*beam.Y) bool) error {
return errors.Errorf(f.Format, fmt.Sprintf("(%v,*)", x))
}

type nonEmptyFn struct{}

func (n *nonEmptyFn) ProcessElement(_ []byte, iter func(*beam.Z) bool) error {
var val beam.Z
for iter(&val) {
return nil
}
return errors.New("PCollection is empty, want non-empty collection")
}

// NonEmpty asserts that the given PCollection has at least one element.
func NonEmpty(s beam.Scope, col beam.PCollection) beam.PCollection {
s = s.Scope("passert.NonEmpty")
beam.ParDo0(s, &nonEmptyFn{}, beam.Impulse(s), beam.SideInput{Input: col})
return col
}
56 changes: 56 additions & 0 deletions sdks/go/pkg/beam/testing/passert/passert.shims.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading