diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go index 6fc192ac83be6..6679f484aa2d3 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go @@ -15,10 +15,30 @@ package engine +import ( + "bytes" + "fmt" + + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" + "golang.org/x/exp/slog" +) + +// StateData is a "union" between Bag state and MultiMap state to increase common code. +type StateData struct { + Bag [][]byte + Multimap map[string][][]byte +} + // TentativeData is where data for in progress bundles is put // until the bundle executes successfully. type TentativeData struct { Raw map[string][][]byte + + // state is a map from transformID + UserStateID, to window, to userKey, to datavalues. + state map[LinkID]map[typex.Window]map[string]StateData } // WriteData adds data to a given global collectionID. @@ -28,3 +48,159 @@ func (d *TentativeData) WriteData(colID string, data []byte) { } d.Raw[colID] = append(d.Raw[colID], data) } + +func (d *TentativeData) toWindow(wKey []byte) typex.Window { + if len(wKey) == 0 { + return window.GlobalWindow{} + } + // TODO: Custom Window handling. + w, err := exec.MakeWindowDecoder(coder.NewIntervalWindow()).DecodeSingle(bytes.NewBuffer(wKey)) + if err != nil { + panic(fmt.Sprintf("error decoding append bag user state window key %v: %v", wKey, err)) + } + return w +} + +// GetBagState retrieves available state from the tentative bundle data. +// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively. 
+func (d *TentativeData) GetBagState(stateID LinkID, wKey, uKey []byte) [][]byte { + winMap := d.state[stateID] + w := d.toWindow(wKey) + data := winMap[w][string(uKey)] + slog.Debug("State() Bag.Get", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", w), slog.Any("Data", data)) + return data.Bag +} + +func (d *TentativeData) appendState(stateID LinkID, wKey []byte) map[string]StateData { + if d.state == nil { + d.state = map[LinkID]map[typex.Window]map[string]StateData{} + } + winMap, ok := d.state[stateID] + if !ok { + winMap = map[typex.Window]map[string]StateData{} + d.state[stateID] = winMap + } + w := d.toWindow(wKey) + kmap, ok := winMap[w] + if !ok { + kmap = map[string]StateData{} + winMap[w] = kmap + } + return kmap +} + +// AppendBagState appends the incoming data to the existing tentative data bundle. +// +// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively. +func (d *TentativeData) AppendBagState(stateID LinkID, wKey, uKey, data []byte) { + kmap := d.appendState(stateID, wKey) + kmap[string(uKey)] = StateData{Bag: append(kmap[string(uKey)].Bag, data)} + slog.Debug("State() Bag.Append", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", wKey), slog.Any("NewData", data)) +} + +func (d *TentativeData) clearState(stateID LinkID, wKey []byte) map[string]StateData { + if d.state == nil { + return nil + } + winMap, ok := d.state[stateID] + if !ok { + return nil + } + w := d.toWindow(wKey) + return winMap[w] +} + +// ClearBagState clears any tentative data for the state. Since state data is only initialized if any exists, +// Clear takes the approach to not create state that doesn't already exist. Existing state is zeroed +// to allow that to be committed post bundle completion. +// +// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively.
+func (d *TentativeData) ClearBagState(stateID LinkID, wKey, uKey []byte) { + kmap := d.clearState(stateID, wKey) + if kmap == nil { + return + } + // Zero the current entry to clear. + // Delete makes it difficult to delete the persisted stage state for the key. + kmap[string(uKey)] = StateData{} + slog.Debug("State() Bag.Clear", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("WindowKey", wKey)) +} + +// GetMultimapState retrieves available state from the tentative bundle data. +// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively. +func (d *TentativeData) GetMultimapState(stateID LinkID, wKey, uKey, mapKey []byte) [][]byte { + winMap := d.state[stateID] + w := d.toWindow(wKey) + data := winMap[w][string(uKey)].Multimap[string(mapKey)] + slog.Debug("State() Multimap.Get", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", w), slog.Any("Data", data)) + return data +} + +// AppendMultimapState appends the incoming data to the existing tentative data bundle. +// +// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively. +func (d *TentativeData) AppendMultimapState(stateID LinkID, wKey, uKey, mapKey, data []byte) { + kmap := d.appendState(stateID, wKey) + stateData, ok := kmap[string(uKey)] + if !ok || stateData.Multimap == nil { // In case of All Key Clear tombstones, we may have a nil map. + stateData = StateData{Multimap: map[string][][]byte{}} + kmap[string(uKey)] = stateData + } + stateData.Multimap[string(mapKey)] = append(stateData.Multimap[string(mapKey)], data) + // The Multimap field is aliased to the instance we stored in kmap, + // so we don't need to re-assign back to kmap after appending the data to mapKey.
+ slog.Debug("State() Multimap.Append", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("MapKey", mapKey), slog.Any("Window", wKey), slog.Any("NewData", data)) +} + +// ClearMultimapState clears any tentative data for the state. Since state data is only initialized if any exists, +// Clear takes the approach to not create state that doesn't already exist. Existing state is zeroed +// to allow that to be committed post bundle completion. +// +// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively. +func (d *TentativeData) ClearMultimapState(stateID LinkID, wKey, uKey, mapKey []byte) { + kmap := d.clearState(stateID, wKey) + if kmap == nil { + return + } + // Nil the current entry to clear. + // Delete makes it difficult to delete the persisted stage state for the key. + userMap, ok := kmap[string(uKey)] + if !ok || userMap.Multimap == nil { + return + } + userMap.Multimap[string(mapKey)] = nil + // The Multimap field is aliased to the instance we stored in kmap, + // so we don't need to re-assign back to kmap after clearing the data from mapKey. + slog.Debug("State() Multimap.Clear", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", wKey)) +} + +// GetMultimapKeysState retrieves all available user map keys. +// +// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively. +func (d *TentativeData) GetMultimapKeysState(stateID LinkID, wKey, uKey []byte) [][]byte { + winMap := d.state[stateID] + w := d.toWindow(wKey) + userMap := winMap[w][string(uKey)] + var keys [][]byte + for k := range userMap.Multimap { + keys = append(keys, []byte(k)) + } + slog.Debug("State() MultimapKeys.Get", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", w), slog.Any("Keys", keys)) + return keys +} + +// ClearMultimapKeysState clears tentative data for all user map keys.
Since state data is only initialized if any exists, +// Clear takes the approach to not create state that doesn't already exist. Existing state is zeroed +// to allow that to be committed post bundle completion. +// +// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively. +func (d *TentativeData) ClearMultimapKeysState(stateID LinkID, wKey, uKey []byte) { + kmap := d.clearState(stateID, wKey) + if kmap == nil { + return + } + // Zero the current entry to clear. + // Delete makes it difficult to delete the persisted stage state for the key. + kmap[string(uKey)] = StateData{} + slog.Debug("State() MultimapKeys.Clear", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("WindowKey", wKey)) +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go index 656525c670475..6cb5523541863 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go @@ -23,6 +23,7 @@ import ( "context" "fmt" "io" + "strings" "sync" "sync/atomic" @@ -39,6 +40,7 @@ type element struct { pane typex.PaneInfo elmBytes []byte + keyBytes []byte } type elements struct { @@ -51,6 +53,7 @@ type PColInfo struct { WDec exec.WindowDecoder WEnc exec.WindowEncoder EDec func(io.Reader) []byte + KeyDec func(io.Reader) []byte } // ToData recodes the elements with their approprate windowed value header. @@ -182,6 +185,12 @@ func (em *ElementManager) StageAggregates(ID string) { em.stages[ID].aggregate = true } +// StageStateful marks the given stage as stateful, which means elements are +// processed by key. +func (em *ElementManager) StageStateful(ID string) { + em.stages[ID].stateful = true +} + // Impulse marks and initializes the given stage as an impulse which // is a root transform that starts processing.
func (em *ElementManager) Impulse(stageID string) { @@ -257,10 +266,13 @@ func (em *ElementManager) Bundles(ctx context.Context, nextBundID func() string) ss := em.stages[stageID] watermark, ready := ss.bundleReady(em) if ready { - bundleID, ok := ss.startBundle(watermark, nextBundID) + bundleID, ok, reschedule := ss.startBundle(watermark, nextBundID) if !ok { continue } + if reschedule { + em.watermarkRefreshes.insert(stageID) + } rb := RunBundle{StageID: stageID, BundleID: bundleID, Watermark: watermark} em.inprogressBundles.insert(rb.BundleID) @@ -278,7 +290,11 @@ func (em *ElementManager) Bundles(ctx context.Context, nextBundID func() string) v := em.livePending.Load() slog.Debug("Bundles: nothing in progress and no refreshes", slog.Int64("pendingElementCount", v)) if v > 0 { - panic(fmt.Sprintf("nothing in progress and no refreshes with non zero pending elements: %v", v)) + var stageState []string + for id, ss := range em.stages { + stageState = append(stageState, fmt.Sprintln(id, ss.pending, ss.pendingByKeys, ss.inprogressKeys, ss.inprogressKeysByBundle)) + } + panic(fmt.Sprintf("nothing in progress and no refreshes with non zero pending elements: %v\n%v", v, strings.Join(stageState, ""))) } } else if len(em.inprogressBundles) == 0 { v := em.livePending.Load() @@ -304,6 +320,56 @@ func (em *ElementManager) InputForBundle(rb RunBundle, info PColInfo) [][]byte { return es.ToData(info) } +// StateForBundle retrieves relevant state for the given bundle, WRT the data in the bundle. +// +// TODO(lostluck): Consider unifying with InputForBundle, to reduce lock contention. +func (em *ElementManager) StateForBundle(rb RunBundle) TentativeData { + ss := em.stages[rb.StageID] + ss.mu.Lock() + defer ss.mu.Unlock() + var ret TentativeData + keys := ss.inprogressKeysByBundle[rb.BundleID] + // TODO(lostluck): Also track windows per bundle, to reduce copying.
+ if len(ss.state) > 0 { + ret.state = map[LinkID]map[typex.Window]map[string]StateData{} + } + for link, winMap := range ss.state { + for w, keyMap := range winMap { + for key := range keys { + data, ok := keyMap[key] + if !ok { + continue + } + linkMap, ok := ret.state[link] + if !ok { + linkMap = map[typex.Window]map[string]StateData{} + ret.state[link] = linkMap + } + wlinkMap, ok := linkMap[w] + if !ok { + wlinkMap = map[string]StateData{} + linkMap[w] = wlinkMap + } + var mm map[string][][]byte + if len(data.Multimap) > 0 { + mm = map[string][][]byte{} + for uk, v := range data.Multimap { + // Clone the "holding" slice, but refer to the existing data bytes. + mm[uk] = append([][]byte(nil), v...) + } + } + // Clone the "holding" slice, but refer to the existing data bytes. + wlinkMap[key] = StateData{ + Bag: append([][]byte(nil), data.Bag...), + Multimap: mm, + } + } + } + } + + return ret +} + // reElementResiduals extracts the windowed value header from residual bytes, and explodes them // back out to their windows. func reElementResiduals(residuals [][]byte, inputInfo PColInfo, rb RunBundle) []element { @@ -322,6 +388,15 @@ func reElementResiduals(residuals [][]byte, inputInfo PColInfo, rb RunBundle) [] slog.Error("reElementResiduals: sdk provided a windowed value header 0 windows", "bundle", rb) panic("error decoding residual header: sdk provided a windowed value header 0 windows") } + // POSSIBLY BAD PATTERN: The buffer is invalidated on the next call, which doesn't always happen. + // But the decoder won't be mutating the buffer bytes, just reading the data. So the elmBytes + // should remain pointing to the whole element, and we should have a copy of the key bytes. + // Ideally, we're simply referring to the key part of the existing buffer.
+ elmBytes := buf.Bytes() + var keyBytes []byte + if inputInfo.KeyDec != nil { + keyBytes = inputInfo.KeyDec(buf) + } for _, w := range ws { unprocessedElements = append(unprocessedElements, @@ -329,7 +404,8 @@ func reElementResiduals(residuals [][]byte, inputInfo PColInfo, rb RunBundle) [] window: w, timestamp: et, pane: pn, - elmBytes: buf.Bytes(), + elmBytes: elmBytes, + keyBytes: keyBytes, }) } } @@ -373,6 +449,11 @@ func (em *ElementManager) PersistBundle(rb RunBundle, col2Coders map[string]PCol } // TODO: Optimize unnecessary copies. This is doubleteeing. elmBytes := info.EDec(tee) + var keyBytes []byte + if info.KeyDec != nil { + kbuf := bytes.NewBuffer(elmBytes) + keyBytes = info.KeyDec(kbuf) // TODO: Optimize unnecessary copies. This is tripleteeing? + } for _, w := range ws { newPending = append(newPending, element{ @@ -380,6 +461,7 @@ func (em *ElementManager) PersistBundle(rb RunBundle, col2Coders map[string]PCol timestamp: et, pane: pn, elmBytes: elmBytes, + keyBytes: keyBytes, }) } } @@ -412,6 +494,10 @@ func (em *ElementManager) PersistBundle(rb RunBundle, col2Coders map[string]PCol completed := stage.inprogress[rb.BundleID] em.addPending(-len(completed.es)) delete(stage.inprogress, rb.BundleID) + for k := range stage.inprogressKeysByBundle[rb.BundleID] { + delete(stage.inprogressKeys, k) + } + delete(stage.inprogressKeysByBundle, rb.BundleID) // If there are estimated output watermarks, set the estimated // output watermark for the stage. if len(estimatedOWM) > 0 { @@ -421,6 +507,25 @@ func (em *ElementManager) PersistBundle(rb RunBundle, col2Coders map[string]PCol } stage.estimatedOutput = estimate } + + // Handle persisting. 
+ for link, winMap := range d.state { + linkMap, ok := stage.state[link] + if !ok { + linkMap = map[typex.Window]map[string]StateData{} + stage.state[link] = linkMap + } + for w, keyMap := range winMap { + wlinkMap, ok := linkMap[w] + if !ok { + wlinkMap = map[string]StateData{} + linkMap[w] = wlinkMap + } + for key, data := range keyMap { + wlinkMap[key] = data + } + } + } stage.mu.Unlock() // TODO support state/timer watermark holds. @@ -499,6 +604,11 @@ func (em *ElementManager) refreshWatermarks() set[string] { type set[K comparable] map[K]struct{} +func (s set[K]) present(k K) bool { + _, ok := s[k] + return ok +} + func (s set[K]) remove(k K) { delete(s, k) } @@ -525,7 +635,8 @@ type stageState struct { sides []LinkID // PCollection IDs of side inputs that can block execution. // Special handling bits - aggregate bool // whether this state needs to block for aggregation. + stateful bool // whether this stage uses state or timers, and needs keyed processing. + aggregate bool // whether this stage needs to block for aggregation. strat winStrat // Windowing Strategy for aggregation fireings. mu sync.Mutex @@ -537,6 +648,12 @@ type stageState struct { pending elementHeap // pending input elements for this stage that are to be processesd inprogress map[string]elements // inprogress elements by active bundles, keyed by bundle sideInputs map[LinkID]map[typex.Window][][]byte // side input data for this stage, from {tid, inputID} -> window + + // Fields for stateful stages which need to be per key. + pendingByKeys map[string]elementHeap // pending input elements by Key, if stateful. + inprogressKeys set[string] // all keys that are assigned to bundles. + inprogressKeysByBundle map[string]set[string] // bundle to key assignments. + state map[LinkID]map[typex.Window]map[string]StateData // state data for this stage, from {tid, stateID} -> window -> userKey } // makeStageState produces an initialized stageState. 
@@ -546,6 +663,7 @@ func makeStageState(ID string, inputIDs, outputIDs []string, sides []LinkID) *st outputIDs: outputIDs, sides: sides, strat: defaultStrat{}, + state: map[LinkID]map[typex.Window]map[string]StateData{}, input: mtime.MinTimestamp, output: mtime.MinTimestamp, @@ -566,8 +684,22 @@ func makeStageState(ID string, inputIDs, outputIDs []string, sides []LinkID) *st func (ss *stageState) AddPending(newPending []element) { ss.mu.Lock() defer ss.mu.Unlock() - ss.pending = append(ss.pending, newPending...) - heap.Init(&ss.pending) + if ss.stateful { + if ss.pendingByKeys == nil { + ss.pendingByKeys = map[string]elementHeap{} + } + for _, e := range newPending { + if len(e.keyBytes) == 0 { + panic(fmt.Sprintf("zero length key: %v %v", ss.ID, ss.inputID)) + } + h := ss.pendingByKeys[string(e.keyBytes)] + heap.Push(&h, e) // heap.Push (not h.Push) sifts the element up, so h[0] stays the minimum timestamp. + ss.pendingByKeys[string(e.keyBytes)] = h // Store back: h is a copy of the slice header, and Push may have grown or reallocated it. + } + } else { + ss.pending = append(ss.pending, newPending...) + heap.Init(&ss.pending) + } } // AddPendingSide adds elements to be consumed as side inputs. @@ -647,10 +779,16 @@ func (ss *stageState) OutputWatermark() mtime.Time { return ss.output } +// TODO: Move to better place for configuration +var ( + OneKeyPerBundle bool // OneKeyPerBundle sets if a bundle is restricted to a single key. + OneElementPerKey bool // OneElementPerKey sets if a key in a bundle is restricted to one element. +) + // startBundle initializes a bundle with elements if possible. // A bundle only starts if there are elements at all, and if it's // an aggregation stage, if the windowing stratgy allows it.
-func (ss *stageState) startBundle(watermark mtime.Time, genBundID func() string) (string, bool) { +func (ss *stageState) startBundle(watermark mtime.Time, genBundID func() string) (string, bool, bool) { defer func() { if e := recover(); e != nil { panic(fmt.Sprintf("generating bundle for stage %v at %v panicked\n%v", ss.ID, watermark, e)) @@ -669,21 +807,73 @@ func (ss *stageState) startBundle(watermark mtime.Time, genBundID func() string) } ss.pending = notYet heap.Init(&ss.pending) + if ss.inprogressKeys == nil { + ss.inprogressKeys = set[string]{} + } + minTs := mtime.MaxTimestamp + // TODO: Allow configurable limit of keys per bundle, and elements per key to improve parallelism. + // TODO: when we do, we need to ensure that the stage remains schedulable for bundle execution, for remaining pending elements and keys. + // With the greedy approach, we don't need to since "new data" triggers a refresh, and so should completing processing of a bundle. + newKeys := set[string]{} + stillSchedulable := true + +keysPerBundle: + for k, h := range ss.pendingByKeys { + if ss.inprogressKeys.present(k) { + continue + } + newKeys.insert(k) + // Track the min-timestamp for later watermark handling. + if h[0].timestamp < minTs { + minTs = h[0].timestamp + } + + if OneElementPerKey { + hp := &h + toProcess = append(toProcess, heap.Pop(hp).(element)) + if hp.Len() == 0 { + // Once we've taken all the elements for a key, + // we must delete them from pending as well. + delete(ss.pendingByKeys, k) + } else { + ss.pendingByKeys[k] = *hp + } + } else { + toProcess = append(toProcess, h...) + delete(ss.pendingByKeys, k) + } + if OneKeyPerBundle { + break keysPerBundle + } + } + if len(ss.pendingByKeys) == 0 { + stillSchedulable = false + } if len(toProcess) == 0 { - return "", false + return "", false, false + } + + if toProcess[0].timestamp < minTs { + // Catch the ordinary case.
+ minTs = toProcess[0].timestamp } - // Is THIS is where basic splits should happen/per element processing? + es := elements{ es: toProcess, - minTimestamp: toProcess[0].timestamp, + minTimestamp: minTs, } if ss.inprogress == nil { ss.inprogress = make(map[string]elements) } + if ss.inprogressKeysByBundle == nil { + ss.inprogressKeysByBundle = make(map[string]set[string]) + } bundID := genBundID() ss.inprogress[bundID] = es - return bundID, true + ss.inprogressKeysByBundle[bundID] = newKeys + ss.inprogressKeys.merge(newKeys) + return bundID, true, stillSchedulable } func (ss *stageState) splitBundle(rb RunBundle, firstResidual int) { @@ -713,6 +903,12 @@ func (ss *stageState) minPendingTimestamp() mtime.Time { if len(ss.pending) != 0 { minPending = ss.pending[0].timestamp } + if len(ss.pendingByKeys) != 0 { + // TODO(lostluck): Can we figure out how to avoid checking every key on every watermark refresh? + for _, h := range ss.pendingByKeys { + minPending = mtime.Min(minPending, h[0].timestamp) + } + } for _, es := range ss.inprogress { minPending = mtime.Min(minPending, es.minTimestamp) } @@ -785,6 +981,14 @@ func (ss *stageState) updateWatermarks(minPending, minStateHold mtime.Time, em * } } } + for _, wins := range ss.state { + for win := range wins { + // Clear out anything we've already used. + if win.MaxTimestamp() < newOut { + delete(wins, win) + } + } + } } return refreshes } diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/engine_test.go b/sdks/go/pkg/beam/runners/prism/internal/engine/engine_test.go new file mode 100644 index 0000000000000..af41e089a2e91 --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/engine_test.go @@ -0,0 +1,159 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package engine_test ensures coverage of the element manager via pipeline actuation. +package engine_test + +import ( + "context" + "fmt" + "math/rand" + "os" + "strings" + "testing" + "time" + + "github.com/apache/beam/sdks/v2/go/pkg/beam" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx" + "github.com/apache/beam/sdks/v2/go/pkg/beam/options/jobopts" + "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal" + "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" + "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices" + "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/universal" + "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest" + "github.com/apache/beam/sdks/v2/go/test/integration/primitives" +) + +func init() { + // Not actually being used, but explicitly registering + // will avoid accidentally using a different runner for + // the tests if I change things later. 
+ beam.RegisterRunner("testlocal", execute) +} + +func TestMain(m *testing.M) { + ptest.MainWithDefault(m, "testlocal") +} + +func initRunner(t testing.TB) { + t.Helper() + if *jobopts.Endpoint == "" { + s := jobservices.NewServer(0, internal.RunPipeline) + *jobopts.Endpoint = s.Endpoint() + go s.Serve() + t.Cleanup(func() { + *jobopts.Endpoint = "" + s.Stop() + }) + } + if !jobopts.IsLoopback() { + *jobopts.EnvironmentType = "loopback" + } + // Since we force loopback, avoid cross-compilation. + f, err := os.CreateTemp("", "dummy") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { os.Remove(f.Name()) }) + *jobopts.WorkerBinary = f.Name() +} + +func execute(ctx context.Context, p *beam.Pipeline) (beam.PipelineResult, error) { + return universal.Execute(ctx, p) +} + +func executeWithT(ctx context.Context, t testing.TB, p *beam.Pipeline) (beam.PipelineResult, error) { + t.Helper() + t.Log("startingTest - ", t.Name()) + s1 := rand.NewSource(time.Now().UnixNano()) + r1 := rand.New(s1) + *jobopts.JobName = fmt.Sprintf("%v-%v", strings.ToLower(t.Name()), r1.Intn(1000)) + return execute(ctx, p) +} + +func initTestName(fn any) string { + name := reflectx.FunctionName(fn) + n := strings.LastIndex(name, "/") + return name[n+1:] +} + +func TestStateAPI(t *testing.T) { + initRunner(t) + + tests := []struct { + pipeline func(s beam.Scope) + }{ + {pipeline: primitives.BagStateParDo}, + {pipeline: primitives.BagStateParDoClear}, + {pipeline: primitives.CombiningStateParDo}, + {pipeline: primitives.ValueStateParDo}, + {pipeline: primitives.ValueStateParDoClear}, + {pipeline: primitives.ValueStateParDoWindowed}, + {pipeline: primitives.MapStateParDo}, + {pipeline: primitives.MapStateParDoClear}, + {pipeline: primitives.SetStateParDo}, + {pipeline: primitives.SetStateParDoClear}, + } + + configs := []struct { + name string + OneElementPerKey, OneKeyPerBundle bool + }{ + {"Greedy", false, false}, + {"AllElementsPerKey", false, true}, + {"OneElementPerKey", true, false}, + 
{"OneElementPerBundle", true, true}, + } + for _, config := range configs { + for _, test := range tests { + t.Run(initTestName(test.pipeline)+"_"+config.name, func(t *testing.T) { + t.Cleanup(func() { + engine.OneElementPerKey = false + engine.OneKeyPerBundle = false + }) + engine.OneElementPerKey = config.OneElementPerKey + engine.OneKeyPerBundle = config.OneKeyPerBundle + p, s := beam.NewPipelineWithRoot() + test.pipeline(s) + _, err := executeWithT(context.Background(), t, p) + if err != nil { + t.Fatalf("pipeline failed, but feature should be implemented in Prism: %v", err) + } + }) + } + } +} + +func TestElementManagerCoverage(t *testing.T) { + initRunner(t) + + tests := []struct { + pipeline func(s beam.Scope) + }{ + {pipeline: primitives.Checkpoints}, // (Doesn't run long enough to split.) + {pipeline: primitives.WindowSums_Lifted}, + } + + for _, test := range tests { + t.Run(initTestName(test.pipeline), func(t *testing.T) { + p, s := beam.NewPipelineWithRoot() + test.pipeline(s) + _, err := executeWithT(context.Background(), t, p) + if err != nil { + t.Fatalf("pipeline failed, but feature should be implemented in Prism: %v", err) + } + }) + } +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go b/sdks/go/pkg/beam/runners/prism/internal/execute.go index 895122383857e..b8bc68dcd1b77 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/execute.go +++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go @@ -179,11 +179,16 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic ed := collectionPullDecoder(col.GetCoderId(), coders, comps) wDec, wEnc := getWindowValueCoders(comps, col, coders) + var kd func(io.Reader) []byte + if kcid, ok := extractKVCoderID(col.GetCoderId(), coders); ok { + kd = collectionPullDecoder(kcid, coders, comps) + } stage.OutputsToCoders[onlyOut] = engine.PColInfo{ GlobalID: onlyOut, WDec: wDec, WEnc: wEnc, EDec: ed, + KeyDec: kd, } // There's either 0, 1 or many inputs, but they should be all 
the same @@ -208,11 +213,17 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic col := comps.GetPcollections()[global] ed := collectionPullDecoder(col.GetCoderId(), coders, comps) wDec, wEnc := getWindowValueCoders(comps, col, coders) + + var kd func(io.Reader) []byte + if kcid, ok := extractKVCoderID(col.GetCoderId(), coders); ok { + kd = collectionPullDecoder(kcid, coders, comps) + } stage.inputInfo = engine.PColInfo{ GlobalID: global, WDec: wDec, WEnc: wEnc, EDec: ed, + KeyDec: kd, } } em.StageAggregates(stage.ID) @@ -234,6 +245,9 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic outputs := maps.Keys(stage.OutputsToCoders) sort.Strings(outputs) em.AddStage(stage.ID, []string{stage.primaryInput}, outputs, stage.sideInputs) + if stage.stateful { + em.StageStateful(stage.ID) + } default: err := fmt.Errorf("unknown environment[%v]", t.GetEnvironmentId()) slog.Error("Execute", err) @@ -286,6 +300,14 @@ func collectionPullDecoder(coldCId string, coders map[string]*pipepb.Coder, comp return pullDecoder(coders[cID], coders) } +func extractKVCoderID(coldCId string, coders map[string]*pipepb.Coder) (string, bool) { + c := coders[coldCId] + if c.GetSpec().GetUrn() == urns.CoderKV { + return c.GetComponentCoderIds()[0], true + } + return "", false +} + func getWindowValueCoders(comps *pipepb.Components, col *pipepb.PCollection, coders map[string]*pipepb.Coder) (exec.WindowDecoder, exec.WindowEncoder) { ws := comps.GetWindowingStrategies()[col.GetWindowingStrategyId()] wcID, err := lpUnknownCoders(ws.GetWindowCoderId(), coders, comps.GetCoders()) diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute_test.go b/sdks/go/pkg/beam/runners/prism/internal/execute_test.go index fe3da83c67e25..29fccaeb238ee 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/execute_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/execute_test.go @@ -37,6 +37,10 @@ import ( 
"github.com/apache/beam/sdks/v2/go/test/integration/primitives" ) +func TestMain(m *testing.M) { + ptest.MainWithDefault(m, "testlocal") +} + func initRunner(t testing.TB) { t.Helper() if *jobopts.Endpoint == "" { @@ -585,10 +589,6 @@ func init() { // There's a doubling bug since we re-use the same pcollection IDs for the source & sink, and // don't do any re-writing. -func TestMain(m *testing.M) { - ptest.MainWithDefault(m, "testlocal") -} - func init() { // Basic Registration // beam.RegisterFunction(identity) diff --git a/sdks/go/pkg/beam/runners/prism/internal/handlepardo.go b/sdks/go/pkg/beam/runners/prism/internal/handlepardo.go index 45223c1b2bcb9..38e7e9454df55 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/handlepardo.go +++ b/sdks/go/pkg/beam/runners/prism/internal/handlepardo.go @@ -82,19 +82,26 @@ func (h *pardo) PrepareTransform(tid string, t *pipepb.PTransform, comps *pipepb !pdo.RequestsFinalization && !pdo.RequiresStableInput && !pdo.RequiresTimeSortedInput && - len(pdo.StateSpecs) == 0 && len(pdo.TimerFamilySpecs) == 0 && pdo.RestrictionCoderId == "" { // Which inputs are Side inputs don't change the graph further, // so they're not included here. Any nearly any ParDo can have them. // At their simplest, we don't need to do anything special at pre-processing time, and simply pass through as normal. + + // StatefulDoFns need to be marked as being roots. 
+ var forcedRoots []string + if len(pdo.StateSpecs)+len(pdo.TimerFamilySpecs) > 0 { + forcedRoots = append(forcedRoots, tid) + } + return prepareResult{ SubbedComps: &pipepb.Components{ Transforms: map[string]*pipepb.PTransform{ tid: t, }, }, + ForcedRoots: forcedRoots, } } diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go index cd302a70fcc07..d6e906bee59f0 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go @@ -42,7 +42,8 @@ import ( ) var supportedRequirements = map[string]struct{}{ - urns.RequirementSplittableDoFn: {}, + urns.RequirementSplittableDoFn: {}, + urns.RequirementStatefulProcessing: {}, } // TODO, move back to main package, and key off of executor handlers? diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go index 0fd7381e17f4b..d3727b6508601 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go @@ -26,6 +26,7 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "golang.org/x/exp/maps" "golang.org/x/exp/slog" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -110,11 +111,10 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (*jo // Inspect Transforms for unsupported features. 
bypassedWindowingStrategies := map[string]bool{} ts := job.Pipeline.GetComponents().GetTransforms() - for _, t := range ts { + for tid, t := range ts { urn := t.GetSpec().GetUrn() switch urn { case urns.TransformImpulse, - urns.TransformParDo, urns.TransformGBK, urns.TransformFlatten, urns.TransformCombinePerKey, @@ -140,6 +140,22 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (*jo wsID := pcs[col].GetWindowingStrategyId() bypassedWindowingStrategies[wsID] = true } + + case urns.TransformParDo: + var pardo pipepb.ParDoPayload + if err := proto.Unmarshal(t.GetSpec().GetPayload(), &pardo); err != nil { + return nil, fmt.Errorf("unable to unmarshal ParDoPayload for %v - %q: %w", tid, t.GetUniqueName(), err) + } + + // Validate all the state features + for _, spec := range pardo.GetStateSpecs() { + check("StateSpec.Protocol.Urn", spec.GetProtocol().GetUrn(), urns.UserStateBag, urns.UserStateMultiMap) + } + // Validate all the timer features + for _, spec := range pardo.GetTimerFamilySpecs() { + check("TimerFamilySpecs.TimeDomain.Urn", spec.GetTimeDomain()) + } + case "": // Composites can often have no spec if len(t.GetSubtransforms()) > 0 { diff --git a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go index 494baa5b4a930..ea4cf2c996953 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go +++ b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go @@ -26,6 +26,7 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "golang.org/x/exp/maps" "golang.org/x/exp/slog" + "google.golang.org/protobuf/proto" ) // transformPreparer is an interface for handling different urns in the preprocessor @@ -440,7 +441,18 @@ func finalizeStage(stg *stage, comps *pipepb.Components, pipelineFacts *fusionFa inputs[pid] = true for _, link := range plinks { t := comps.GetTransforms()[link.Transform] - sis, _ := getSideInputs(t) + + var sis 
map[string]*pipepb.SideInput + if t.GetSpec().GetUrn() == urns.TransformParDo { + pardo := &pipepb.ParDoPayload{} + if err := (proto.UnmarshalOptions{}).Unmarshal(t.GetSpec().GetPayload(), pardo); err != nil { + return fmt.Errorf("unable to decode ParDoPayload for %v: %w", link.Transform, err) + } + if len(pardo.GetTimerFamilySpecs())+len(pardo.GetStateSpecs()) > 0 { + stg.stateful = true + } + sis = pardo.GetSideInputs() + } if _, ok := sis[link.Local]; ok { sideInputs = append(sideInputs, engine.LinkID{Transform: link.Transform, Global: link.Global, Local: link.Local}) } else { diff --git a/sdks/go/pkg/beam/runners/prism/internal/stage.go b/sdks/go/pkg/beam/runners/prism/internal/stage.go index 1ce1024063810..b415f5c241de1 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/stage.go +++ b/sdks/go/pkg/beam/runners/prism/internal/stage.go @@ -19,6 +19,7 @@ import ( "bytes" "context" "fmt" + "io" "time" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime" @@ -61,6 +62,7 @@ type stage struct { sideInputs []engine.LinkID // Non-parallel input PCollections and their consumers internalCols []string // PCollections that escape. Used for precise coder sending. envID string + stateful bool exe transformExecuter inputTransformID string @@ -77,6 +79,7 @@ func (s *stage) Execute(ctx context.Context, j *jobservices.Job, wk *worker.W, c var b *worker.B inputData := em.InputForBundle(rb, s.inputInfo) + initialState := em.StateForBundle(rb) var dataReady <-chan struct{} switch s.envID { case "": // Runner Transforms @@ -102,8 +105,8 @@ func (s *stage) Execute(ctx context.Context, j *jobservices.Job, wk *worker.W, c InputTransformID: s.inputTransformID, - // TODO Here's where we can split data for processing in multiple bundles. 
- InputData: inputData, + InputData: inputData, + OutputData: initialState, SinkToPCollection: s.SinkToPCollection, OutputCount: len(s.outputs), @@ -300,6 +303,12 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng } sinkID := o.Transform + "_" + o.Local ed := collectionPullDecoder(col.GetCoderId(), coders, comps) + + var kd func(io.Reader) []byte + if kcid, ok := extractKVCoderID(col.GetCoderId(), coders); ok { + kd = collectionPullDecoder(kcid, coders, comps) + } + wDec, wEnc := getWindowValueCoders(comps, col, coders) sink2Col[sinkID] = o.Global col2Coders[o.Global] = engine.PColInfo{ @@ -307,6 +316,7 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng WDec: wDec, WEnc: wEnc, EDec: ed, + KeyDec: kd, } transforms[sinkID] = sinkTransform(sinkID, portFor(wOutCid, wk), o.Global) } @@ -350,14 +360,20 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng if err != nil { return fmt.Errorf("buildDescriptor: failed to handle coder on stage %v for primary input, pcol %q %v:\n%w\n%v", stg.ID, stg.primaryInput, prototext.Format(col), err, stg.transforms) } - ed := collectionPullDecoder(col.GetCoderId(), coders, comps) wDec, wEnc := getWindowValueCoders(comps, col, coders) + + var kd func(io.Reader) []byte + if kcid, ok := extractKVCoderID(col.GetCoderId(), coders); ok { + kd = collectionPullDecoder(kcid, coders, comps) + } + inputInfo := engine.PColInfo{ GlobalID: stg.primaryInput, WDec: wDec, WEnc: wEnc, EDec: ed, + KeyDec: kd, } stg.inputTransformID = stg.ID + "_source" diff --git a/sdks/go/pkg/beam/runners/prism/internal/unimplemented_test.go b/sdks/go/pkg/beam/runners/prism/internal/unimplemented_test.go index b8a04a7306b2a..323773bd4cd62 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/unimplemented_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/unimplemented_test.go @@ -27,7 +27,7 @@ import ( // This file covers pipelines with features that aren't yet supported by 
Prism. -func intTestName(fn any) string { +func initTestName(fn any) string { name := reflectx.FunctionName(fn) n := strings.LastIndex(name, "/") return name[n+1:] @@ -68,23 +68,11 @@ func TestUnimplemented(t *testing.T) { {pipeline: primitives.TriggerOrFinally}, {pipeline: primitives.TriggerRepeat}, - // State API - {pipeline: primitives.BagStateParDo}, - {pipeline: primitives.BagStateParDoClear}, - {pipeline: primitives.MapStateParDo}, - {pipeline: primitives.MapStateParDoClear}, - {pipeline: primitives.SetStateParDo}, - {pipeline: primitives.SetStateParDoClear}, - {pipeline: primitives.CombiningStateParDo}, - {pipeline: primitives.ValueStateParDo}, - {pipeline: primitives.ValueStateParDoClear}, - {pipeline: primitives.ValueStateParDoWindowed}, - // TODO: Timers integration tests. } for _, test := range tests { - t.Run(intTestName(test.pipeline), func(t *testing.T) { + t.Run(initTestName(test.pipeline), func(t *testing.T) { p, s := beam.NewPipelineWithRoot() test.pipeline(s) _, err := executeWithT(context.Background(), t, p) @@ -113,7 +101,37 @@ func TestImplemented(t *testing.T) { } for _, test := range tests { - t.Run(intTestName(test.pipeline), func(t *testing.T) { + t.Run(initTestName(test.pipeline), func(t *testing.T) { + p, s := beam.NewPipelineWithRoot() + test.pipeline(s) + _, err := executeWithT(context.Background(), t, p) + if err != nil { + t.Fatalf("pipeline failed, but feature should be implemented in Prism: %v", err) + } + }) + } +} + +func TestStateAPI(t *testing.T) { + initRunner(t) + + tests := []struct { + pipeline func(s beam.Scope) + }{ + {pipeline: primitives.BagStateParDo}, + {pipeline: primitives.BagStateParDoClear}, + {pipeline: primitives.CombiningStateParDo}, + {pipeline: primitives.ValueStateParDo}, + {pipeline: primitives.ValueStateParDoClear}, + {pipeline: primitives.ValueStateParDoWindowed}, + {pipeline: primitives.MapStateParDo}, + {pipeline: primitives.MapStateParDoClear}, + {pipeline: primitives.SetStateParDo}, + {pipeline: 
primitives.SetStateParDoClear}, + } + + for _, test := range tests { + t.Run(initTestName(test.pipeline), func(t *testing.T) { p, s := beam.NewPipelineWithRoot() test.pipeline(s) _, err := executeWithT(context.Background(), t, p) diff --git a/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go b/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go index bf1e36656661b..5312fd799c89c 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go +++ b/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go @@ -51,6 +51,7 @@ var ( reqUrn = toUrn[pipepb.StandardRequirements_Enum]() runProcUrn = toUrn[pipepb.StandardRunnerProtocols_Enum]() envUrn = toUrn[pipepb.StandardEnvironments_Environments]() + usUrn = toUrn[pipepb.StandardUserStateTypes_Enum]() ) var ( @@ -93,6 +94,10 @@ var ( SideInputIterable = siUrn(pipepb.StandardSideInputTypes_ITERABLE) SideInputMultiMap = siUrn(pipepb.StandardSideInputTypes_MULTIMAP) + // UserState kinds + UserStateBag = usUrn(pipepb.StandardUserStateTypes_BAG) + UserStateMultiMap = usUrn(pipepb.StandardUserStateTypes_MULTIMAP) + // WindowsFns WindowFnGlobal = quickUrn(pipepb.GlobalWindowsPayload_PROPERTIES) WindowFnFixed = quickUrn(pipepb.FixedWindowsPayload_PROPERTIES) diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go index 97250092940df..6ef3a81e6239a 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go @@ -42,11 +42,13 @@ type B struct { InputTransformID string InputData [][]byte // Data specifically for this bundle. - // IterableSideInputData is a map from transformID, to inputID, to window, to data. + // IterableSideInputData is a map from transformID + inputID, to window, to data. IterableSideInputData map[SideInputKey]map[typex.Window][][]byte - // MultiMapSideInputData is a map from transformID, to inputID, to window, to data key, to data values. 
+ // MultiMapSideInputData is a map from transformID + inputID, to window, to data key, to data values. MultiMapSideInputData map[SideInputKey]map[typex.Window]map[string][][]byte + // State lives in OutputData + // OutputCount is the number of data or timer outputs this bundle has. // We need to see this many closed data channels before the bundle is complete. OutputCount int diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go index beee5e896ffc1..2859dfe2356d4 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go @@ -36,6 +36,7 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" + "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "golang.org/x/exp/slog" "google.golang.org/grpc" @@ -412,21 +413,21 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error { panic(err) } } + + // State requests are always for an active ProcessBundle instruction + wk.mu.Lock() + b, ok := wk.activeInstructions[req.GetInstructionId()].(*B) + wk.mu.Unlock() + if !ok { + slog.Warn("state request after bundle inactive", "instruction", req.GetInstructionId(), "worker", wk) + continue + } switch req.GetRequest().(type) { case *fnpb.StateRequest_Get: // TODO: move data handling to be pcollection based. 
- // State requests are always for an active ProcessBundle instruction - wk.mu.Lock() - b, ok := wk.activeInstructions[req.GetInstructionId()].(*B) - wk.mu.Unlock() - if !ok { - slog.Warn("state request after bundle inactive", "instruction", req.GetInstructionId(), "worker", wk) - continue - } key := req.GetStateKey() slog.Debug("StateRequest_Get", prototext.Format(req), "bundle", b) - var data [][]byte switch key.GetType().(type) { case *fnpb.StateKey_IterableSideInput_: @@ -442,11 +443,13 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error { } } winMap := b.IterableSideInputData[SideInputKey{TransformID: ikey.GetTransformId(), Local: ikey.GetSideInputId()}] + var wins []typex.Window for w := range winMap { wins = append(wins, w) } slog.Debug(fmt.Sprintf("side input[%v][%v] I Key: %v Windows: %v", req.GetId(), req.GetInstructionId(), w, wins)) + data = winMap[w] case *fnpb.StateKey_MultimapSideInput_: @@ -458,37 +461,81 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error { } else { w, err = exec.MakeWindowDecoder(coder.NewIntervalWindow()).DecodeSingle(bytes.NewBuffer(wKey)) if err != nil { - panic(fmt.Sprintf("error decoding iterable side input window key %v: %v", wKey, err)) + panic(fmt.Sprintf("error decoding multimap side input window key %v: %v", wKey, err)) } } dKey := mmkey.GetKey() winMap := b.MultiMapSideInputData[SideInputKey{TransformID: mmkey.GetTransformId(), Local: mmkey.GetSideInputId()}] - var wins []typex.Window - for w := range winMap { - wins = append(wins, w) - } - slog.Debug(fmt.Sprintf("side input[%v][%v] MM Key: %v Windows: %v", req.GetId(), req.GetInstructionId(), w, wins)) + + slog.Debug(fmt.Sprintf("side input[%v][%v] MultiMap Window: %v", req.GetId(), req.GetInstructionId(), w)) data = winMap[w][string(dKey)] + case *fnpb.StateKey_BagUserState_: + bagkey := key.GetBagUserState() + data = b.OutputData.GetBagState(engine.LinkID{Transform: bagkey.GetTransformId(), Local: bagkey.GetUserStateId()}, bagkey.GetWindow(), 
bagkey.GetKey()) + case *fnpb.StateKey_MultimapUserState_: + mmkey := key.GetMultimapUserState() + data = b.OutputData.GetMultimapState(engine.LinkID{Transform: mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), mmkey.GetKey(), mmkey.GetMapKey()) + case *fnpb.StateKey_MultimapKeysUserState_: + mmkey := key.GetMultimapKeysUserState() + data = b.OutputData.GetMultimapKeysState(engine.LinkID{Transform: mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), mmkey.GetKey()) default: - panic(fmt.Sprintf("unsupported StateKey Access type: %T: %v", key.GetType(), prototext.Format(key))) + panic(fmt.Sprintf("unsupported StateKey Get type: %T: %v", key.GetType(), prototext.Format(key))) } // Encode the runner iterable (no length, just consecutive elements), and send it out. // This is also where we can handle things like State Backed Iterables. - var buf bytes.Buffer - for _, value := range data { - buf.Write(value) - } responses <- &fnpb.StateResponse{ Id: req.GetId(), Response: &fnpb.StateResponse_Get{ Get: &fnpb.StateGetResponse{ - Data: buf.Bytes(), + Data: bytes.Join(data, []byte{}), }, }, } + + case *fnpb.StateRequest_Append: + key := req.GetStateKey() + switch key.GetType().(type) { + case *fnpb.StateKey_BagUserState_: + bagkey := key.GetBagUserState() + b.OutputData.AppendBagState(engine.LinkID{Transform: bagkey.GetTransformId(), Local: bagkey.GetUserStateId()}, bagkey.GetWindow(), bagkey.GetKey(), req.GetAppend().GetData()) + case *fnpb.StateKey_MultimapUserState_: + mmkey := key.GetMultimapUserState() + b.OutputData.AppendMultimapState(engine.LinkID{Transform: mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), mmkey.GetKey(), mmkey.GetMapKey(), req.GetAppend().GetData()) + default: + panic(fmt.Sprintf("unsupported StateKey Append type: %T: %v", key.GetType(), prototext.Format(key))) + } + responses <- &fnpb.StateResponse{ + Id: req.GetId(), + Response: &fnpb.StateResponse_Append{ + Append: 
&fnpb.StateAppendResponse{}, + }, + } + + case *fnpb.StateRequest_Clear: + key := req.GetStateKey() + switch key.GetType().(type) { + case *fnpb.StateKey_BagUserState_: + bagkey := key.GetBagUserState() + b.OutputData.ClearBagState(engine.LinkID{Transform: bagkey.GetTransformId(), Local: bagkey.GetUserStateId()}, bagkey.GetWindow(), bagkey.GetKey()) + case *fnpb.StateKey_MultimapUserState_: + mmkey := key.GetMultimapUserState() + b.OutputData.ClearMultimapState(engine.LinkID{Transform: mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), mmkey.GetKey(), mmkey.GetMapKey()) + case *fnpb.StateKey_MultimapKeysUserState_: + mmkey := key.GetMultimapKeysUserState() + b.OutputData.ClearMultimapKeysState(engine.LinkID{Transform: mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), mmkey.GetKey()) + default: + panic(fmt.Sprintf("unsupported StateKey Clear type: %T: %v", key.GetType(), prototext.Format(key))) + } + responses <- &fnpb.StateResponse{ + Id: req.GetId(), + Response: &fnpb.StateResponse_Clear{ + Clear: &fnpb.StateClearResponse{}, + }, + } + default: panic(fmt.Sprintf("unsupported StateRequest kind %T: %v", req.GetRequest(), prototext.Format(req))) } diff --git a/sdks/go/test/integration/primitives/state.go b/sdks/go/test/integration/primitives/state.go index 5f105597ba379..acf1bf8fa6651 100644 --- a/sdks/go/test/integration/primitives/state.go +++ b/sdks/go/test/integration/primitives/state.go @@ -39,6 +39,7 @@ func init() { register.DoFn3x1[state.Provider, string, int, string](&mapStateClearFn{}) register.DoFn3x1[state.Provider, string, int, string](&setStateFn{}) register.DoFn3x1[state.Provider, string, int, string](&setStateClearFn{}) + register.Function2x0(pairWithOne) register.Emitter2[string, int]() register.Combiner1[int](&combine1{}) register.Combiner2[string, int](&combine2{}) @@ -78,12 +79,14 @@ func (f *valueStateFn) ProcessElement(s state.Provider, w string, c int) string return fmt.Sprintf("%s: %v, %s", w, i, j) 
} +func pairWithOne(w string, emit func(string, int)) { + emit(w, 1) +} + // ValueStateParDo tests a DoFn that uses value state. func ValueStateParDo(s beam.Scope) { in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear") - keyed := beam.ParDo(s, func(w string, emit func(string, int)) { - emit(w, 1) - }, in) + keyed := beam.ParDo(s, pairWithOne, in) counts := beam.ParDo(s, &valueStateFn{}, keyed) passert.Equals(s, counts, "apple: 1, I", "pear: 1, I", "peach: 1, I", "apple: 2, II", "apple: 3, III", "pear: 2, II") } @@ -124,9 +127,7 @@ func (f *valueStateClearFn) ProcessElement(s state.Provider, w string, c int) st // ValueStateParDoClear tests that a DoFn that uses value state can be cleared. func ValueStateParDoClear(s beam.Scope) { in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear", "pear", "apple") - keyed := beam.ParDo(s, func(w string, emit func(string, int)) { - emit(w, 1) - }, in) + keyed := beam.ParDo(s, pairWithOne, in) counts := beam.ParDo(s, &valueStateClearFn{State1: state.MakeValueState[int]("key1")}, keyed) passert.Equals(s, counts, "apple: 0,false", "pear: 0,false", "peach: 0,false", "apple: 1,true", "apple: 0,false", "pear: 1,true", "pear: 0,false", "apple: 1,true") } @@ -170,9 +171,7 @@ func (f *bagStateFn) ProcessElement(s state.Provider, w string, c int) string { // BagStateParDo tests a DoFn that uses bag state. func BagStateParDo(s beam.Scope) { in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear") - keyed := beam.ParDo(s, func(w string, emit func(string, int)) { - emit(w, 1) - }, in) + keyed := beam.ParDo(s, pairWithOne, in) counts := beam.ParDo(s, &bagStateFn{}, keyed) passert.Equals(s, counts, "apple: 0, ", "pear: 0, ", "peach: 0, ", "apple: 1, I", "apple: 2, I,I", "pear: 1, I") } @@ -207,9 +206,7 @@ func (f *bagStateClearFn) ProcessElement(s state.Provider, w string, c int) stri // BagStateParDoClear tests a DoFn that uses bag state. 
func BagStateParDoClear(s beam.Scope) { in := beam.Create(s, "apple", "pear", "apple", "apple", "pear", "apple", "apple", "pear", "pear", "pear", "apple", "pear") - keyed := beam.ParDo(s, func(w string, emit func(string, int)) { - emit(w, 1) - }, in) + keyed := beam.ParDo(s, pairWithOne, in) counts := beam.ParDo(s, &bagStateClearFn{State1: state.MakeBagState[int]("key1")}, keyed) passert.Equals(s, counts, "apple: 0", "pear: 0", "apple: 1", "apple: 2", "pear: 1", "apple: 3", "apple: 0", "pear: 2", "pear: 3", "pear: 0", "apple: 1", "pear: 1") } @@ -312,16 +309,20 @@ func (f *combiningStateFn) ProcessElement(s state.Provider, w string, c int) str return fmt.Sprintf("%s: %v %v %v %v %v", w, i, i1, i2, i3, i4) } +func init() { + register.Function2x1(sumInt) +} + +func sumInt(a, b int) int { + return a + b +} + // CombiningStateParDo tests a DoFn that uses value state. func CombiningStateParDo(s beam.Scope) { in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear") - keyed := beam.ParDo(s, func(w string, emit func(string, int)) { - emit(w, 1) - }, in) + keyed := beam.ParDo(s, pairWithOne, in) counts := beam.ParDo(s, &combiningStateFn{ - State0: state.MakeCombiningState[int, int, int]("key0", func(a, b int) int { - return a + b - }), + State0: state.MakeCombiningState[int, int, int]("key0", sumInt), State1: state.Combining[int, int, int](state.MakeCombiningState[int, int, int]("key1", &combine1{})), State2: state.Combining[string, string, int](state.MakeCombiningState[string, string, int]("key2", &combine2{})), State3: state.Combining[string, string, int](state.MakeCombiningState[string, string, int]("key3", &combine3{})), @@ -369,9 +370,7 @@ func (f *mapStateFn) ProcessElement(s state.Provider, w string, c int) string { // MapStateParDo tests a DoFn that uses value state. 
func MapStateParDo(s beam.Scope) { in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear") - keyed := beam.ParDo(s, func(w string, emit func(string, int)) { - emit(w, 1) - }, in) + keyed := beam.ParDo(s, pairWithOne, in) counts := beam.ParDo(s, &mapStateFn{State1: state.MakeMapState[string, int]("key1")}, keyed) passert.Equals(s, counts, "apple: 1, keys: [apple apple1]", "pear: 1, keys: [pear pear1]", "peach: 1, keys: [peach peach1]", "apple: 2, keys: [apple apple1 apple2]", "apple: 3, keys: [apple apple1 apple2 apple3]", "pear: 2, keys: [pear pear1 pear2]") } @@ -425,9 +424,7 @@ func (f *mapStateClearFn) ProcessElement(s state.Provider, w string, c int) stri // MapStateParDoClear tests clearing and removing from a DoFn that uses map state. func MapStateParDoClear(s beam.Scope) { in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear") - keyed := beam.ParDo(s, func(w string, emit func(string, int)) { - emit(w, 1) - }, in) + keyed := beam.ParDo(s, pairWithOne, in) counts := beam.ParDo(s, &mapStateClearFn{State1: state.MakeMapState[string, int]("key1")}, keyed) passert.Equals(s, counts, "apple: [apple]", "pear: [pear]", "peach: [peach]", "apple: [apple1 apple2 apple3]", "apple: []", "pear: [pear1 pear2 pear3]") } @@ -465,9 +462,7 @@ func (f *setStateFn) ProcessElement(s state.Provider, w string, c int) string { // SetStateParDo tests a DoFn that uses set state. 
func SetStateParDo(s beam.Scope) { in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear") - keyed := beam.ParDo(s, func(w string, emit func(string, int)) { - emit(w, 1) - }, in) + keyed := beam.ParDo(s, pairWithOne, in) counts := beam.ParDo(s, &setStateFn{State1: state.MakeSetState[string]("key1")}, keyed) passert.Equals(s, counts, "apple: false, keys: [apple]", "pear: false, keys: [pear]", "peach: false, keys: [peach]", "apple: true, keys: [apple apple1]", "apple: true, keys: [apple apple1]", "pear: true, keys: [pear pear1]") } @@ -521,9 +516,7 @@ func (f *setStateClearFn) ProcessElement(s state.Provider, w string, c int) stri // SetStateParDoClear tests clearing and removing from a DoFn that uses set state. func SetStateParDoClear(s beam.Scope) { in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear") - keyed := beam.ParDo(s, func(w string, emit func(string, int)) { - emit(w, 1) - }, in) + keyed := beam.ParDo(s, pairWithOne, in) counts := beam.ParDo(s, &setStateClearFn{State1: state.MakeSetState[string]("key1")}, keyed) passert.Equals(s, counts, "apple: [apple]", "pear: [pear]", "peach: [peach]", "apple: [apple1 apple2 apple3]", "apple: []", "pear: [pear1 pear2 pear3]") }