diff --git a/sdks/go/pkg/beam/core/timers/timers.go b/sdks/go/pkg/beam/core/timers/timers.go index 9e188dcbc69de..f55b03c802781 100644 --- a/sdks/go/pkg/beam/core/timers/timers.go +++ b/sdks/go/pkg/beam/core/timers/timers.go @@ -51,8 +51,10 @@ type TimerMap struct { FireTimestamp, HoldTimestamp mtime.Time } +// timerConfig is used transiently to hold configuration from the functional options. type timerConfig struct { Tag string + HoldSet bool // Whether the HoldTimestamp was set. HoldTimestamp mtime.Time } @@ -68,6 +70,7 @@ func WithTag(tag string) timerOptions { // WithOutputTimestamp sets the output timestamp for the timer. func WithOutputTimestamp(outputTimestamp time.Time) timerOptions { return func(tm *timerConfig) { + tm.HoldSet = true tm.HoldTimestamp = mtime.FromTime(outputTimestamp) } } @@ -108,7 +111,7 @@ func (et EventTime) Set(p Provider, FiringTimestamp time.Time, opts ...timerOpti opt(&tc) } tm := TimerMap{Family: et.Family, Tag: tc.Tag, FireTimestamp: mtime.FromTime(FiringTimestamp), HoldTimestamp: mtime.FromTime(FiringTimestamp)} - if !tc.HoldTimestamp.ToTime().IsZero() { + if tc.HoldSet { tm.HoldTimestamp = tc.HoldTimestamp } p.Set(tm) @@ -142,7 +145,7 @@ func (pt ProcessingTime) Set(p Provider, FiringTimestamp time.Time, opts ...time opt(&tc) } tm := TimerMap{Family: pt.Family, Tag: tc.Tag, FireTimestamp: mtime.FromTime(FiringTimestamp), HoldTimestamp: mtime.FromTime(FiringTimestamp)} - if !tc.HoldTimestamp.ToTime().IsZero() { + if tc.HoldSet { tm.HoldTimestamp = tc.HoldTimestamp } diff --git a/sdks/go/pkg/beam/runners/prism/internal/coders.go b/sdks/go/pkg/beam/runners/prism/internal/coders.go index 64005177b94b7..6deaab65d3661 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/coders.go +++ b/sdks/go/pkg/beam/runners/prism/internal/coders.go @@ -200,7 +200,7 @@ func pullDecoder(c *pipepb.Coder, coders map[string]*pipepb.Coder) func(io.Reade } } -// pullDecoderNoAlloc returns a function that decodes a single eleemnt of the given coder. +// pullDecoderNoAlloc returns a function that decodes a single element of the given coder. // Intended to only be used as an internal function for pullDecoder, which will use a io.TeeReader // to extract the bytes. func pullDecoderNoAlloc(c *pipepb.Coder, coders map[string]*pipepb.Coder) func(io.Reader) { diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go index 6679f484aa2d3..eaaf7f831712d 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go @@ -32,6 +32,11 @@ type StateData struct { Multimap map[string][][]byte } +// TimerKey is for use as a key for timers. +type TimerKey struct { + Transform, Family string +} + // TentativeData is where data for in progress bundles is put // until the bundle executes successfully. type TentativeData struct { @@ -39,6 +44,8 @@ type TentativeData struct { // state is a map from transformID + UserStateID, to window, to userKey, to datavalues. state map[LinkID]map[typex.Window]map[string]StateData + // timers is a map from the Timer transform+family to the encoded timer. + timers map[TimerKey][][]byte } // WriteData adds data to a given global collectionID. @@ -49,6 +56,15 @@ func (d *TentativeData) WriteData(colID string, data []byte) { d.Raw[colID] = append(d.Raw[colID], data) } +// WriteTimers adds timers to the associated transform handler. +func (d *TentativeData) WriteTimers(transformID, familyID string, timers []byte) { + if d.timers == nil { + d.timers = map[TimerKey][][]byte{} + } + link := TimerKey{Transform: transformID, Family: familyID} + d.timers[link] = append(d.timers[link], timers) +} + func (d *TentativeData) toWindow(wKey []byte) typex.Window { if len(wKey) == 0 { return window.GlobalWindow{} diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go index 6cb5523541863..077d6386315ad 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go @@ -23,26 +23,46 @@ import ( "context" "fmt" "io" + "sort" "strings" "sync" "sync/atomic" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" + "golang.org/x/exp/maps" "golang.org/x/exp/slog" ) type element struct { - window typex.Window - timestamp mtime.Time - pane typex.PaneInfo + window typex.Window + timestamp mtime.Time + holdTimestamp mtime.Time // only used for Timers + pane typex.PaneInfo + transform, family, tag string // only used for Timers. - elmBytes []byte + elmBytes []byte // When nil, indicates this is a timer. keyBytes []byte } +func (e *element) IsTimer() bool { + return e.elmBytes == nil +} + +func (e *element) IsData() bool { + return !e.IsTimer() +} + +func (e element) String() string { + if e.IsTimer() { + return fmt.Sprintf("{Timer - Window %v, EventTime %v, Hold %v, %q %q %q %q}", e.window, e.timestamp, e.holdTimestamp, e.transform, e.family, e.tag, e.keyBytes) + } + return fmt.Sprintf("{Data - Window %v, EventTime %v, Element %v}", e.window, e.timestamp, e.elmBytes) +} + type elements struct { es []element minTimestamp mtime.Time @@ -72,9 +92,21 @@ func (es elements) ToData(info PColInfo) [][]byte { // so we can always find the minimum timestamp of pending elements. type elementHeap []element -func (h elementHeap) Len() int { return len(h) } -func (h elementHeap) Less(i, j int) bool { return h[i].timestamp < h[j].timestamp } -func (h elementHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } +func (h elementHeap) Len() int { return len(h) } +func (h elementHeap) Less(i, j int) bool { + // If the timestamps are the same, data comes before timers. + if h[i].timestamp == h[j].timestamp { + if h[i].IsTimer() && h[j].IsData() { + return false // j before i + } else if h[i].IsData() && h[j].IsTimer() { + return true // i before j. + } + // They're the same kind, fall through to timestamp less for consistency. + } + // Otherwise compare by timestamp. + return h[i].timestamp < h[j].timestamp +} +func (h elementHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } func (h *elementHeap) Push(x any) { // Push and Pop use pointer receivers because they modify the slice's length, @@ -205,12 +237,12 @@ func (em *ElementManager) Impulse(stageID string) { consumers := em.consumers[stage.outputIDs[0]] slog.Debug("Impulse", slog.String("stageID", stageID), slog.Any("outputs", stage.outputIDs), slog.Any("consumers", consumers)) - em.addPending(len(consumers)) for _, sID := range consumers { consumer := em.stages[sID] - consumer.AddPending(newPending) + count := consumer.AddPending(newPending) + em.addPending(count) } - refreshes := stage.updateWatermarks(mtime.MaxTimestamp, mtime.MaxTimestamp, em) + refreshes := stage.updateWatermarks(em) em.addRefreshes(refreshes) } @@ -267,12 +299,13 @@ func (em *ElementManager) Bundles(ctx context.Context, nextBundID func() string) watermark, ready := ss.bundleReady(em) if ready { bundleID, ok, reschedule := ss.startBundle(watermark, nextBundID) - if !ok { - continue - } + // Handle the reschedule even when there's no bundle. if reschedule { em.watermarkRefreshes.insert(stageID) } + if !ok { + continue + } rb := RunBundle{StageID: stageID, BundleID: bundleID, Watermark: watermark} em.inprogressBundles.insert(rb.BundleID) @@ -291,8 +324,15 @@ func (em *ElementManager) Bundles(ctx context.Context, nextBundID func() string) slog.Debug("Bundles: nothing in progress and no refreshes", slog.Int64("pendingElementCount", v)) if v > 0 { var stageState []string - for id, ss := range em.stages { - stageState = append(stageState, fmt.Sprintln(id, ss.pending, ss.pendingByKeys, ss.inprogressKeys, ss.inprogressKeysByBundle)) + ids := maps.Keys(em.stages) + sort.Strings(ids) + for _, id := range ids { + ss := em.stages[id] + inW := ss.InputWatermark() + outW := ss.OutputWatermark() + upPCol, upW := ss.UpstreamWatermark() + upS := em.pcolParents[upPCol] + stageState = append(stageState, fmt.Sprintln(id, "watermark in", inW, "out", outW, "upstream", upW, "from", upS, "pending", ss.pending, "byKey", ss.pendingByKeys, "inprogressKeys", ss.inprogressKeys, "byBundle", ss.inprogressKeysByBundle, "holds", ss.watermarkHoldHeap, "holdCounts", ss.watermarkHoldsCounts)) } panic(fmt.Sprintf("nothing in progress and no refreshes with non zero pending elements: %v\n%v", v, strings.Join(stageState, ""))) } @@ -320,6 +360,81 @@ func (em *ElementManager) InputForBundle(rb RunBundle, info PColInfo) [][]byte { return es.ToData(info) } +// DataAndTimerInputForBundle returns pre-allocated data for the given bundle and the estimated number of elements. +// Elements are encoded with the PCollection's coders. +func (em *ElementManager) DataAndTimerInputForBundle(rb RunBundle, info PColInfo) ([]*Block, int) { + ss := em.stages[rb.StageID] + ss.mu.Lock() + defer ss.mu.Unlock() + es := ss.inprogress[rb.BundleID] + + var total int + + var ret []*Block + cur := &Block{} + for _, e := range es.es { + switch { + case e.IsTimer() && (cur.Kind != BlockTimer || e.family != cur.Family || cur.Transform != e.transform): + total += len(cur.Bytes) + cur = &Block{ + Kind: BlockTimer, + Transform: e.transform, + Family: e.family, + } + ret = append(ret, cur) + fallthrough + case e.IsTimer() && cur.Kind == BlockTimer: + var buf bytes.Buffer + // Key + buf.Write(e.keyBytes) // Includes the length prefix if any. + // Tag + coder.EncodeVarInt(int64(len(e.tag)), &buf) + buf.WriteString(e.tag) + // Windows + info.WEnc.Encode([]typex.Window{e.window}, &buf) + // Clear + buf.Write([]byte{0}) + // Firing timestamp + coder.EncodeEventTime(e.timestamp, &buf) + // Hold timestamp + coder.EncodeEventTime(e.holdTimestamp, &buf) + // Pane + coder.EncodePane(e.pane, &buf) + + cur.Bytes = append(cur.Bytes, buf.Bytes()) + case cur.Kind != BlockData: + total += len(cur.Bytes) + cur = &Block{ + Kind: BlockData, + } + ret = append(ret, cur) + fallthrough + default: + var buf bytes.Buffer + exec.EncodeWindowedValueHeader(info.WEnc, []typex.Window{e.window}, e.timestamp, e.pane, &buf) + buf.Write(e.elmBytes) + cur.Bytes = append(cur.Bytes, buf.Bytes()) + } + } + total += len(cur.Bytes) + return ret, total +} + +// BlockKind indicates how the block is to be handled. +type BlockKind int32 + +const ( + blockUnset BlockKind = iota // blockUnset + BlockData // BlockData represents data for the bundle. + BlockTimer // BlockTimer represents timers for the bundle. +) + +type Block struct { + Kind BlockKind + Bytes [][]byte + Transform, Family string +} + // StateForBundle retreives relevant state for the given bundle, WRT the data in the bundle. // // TODO(lostluck): Consider unifiying with InputForBundle, to reduce lock contention. @@ -422,7 +537,6 @@ func reElementResiduals(residuals [][]byte, inputInfo PColInfo, rb RunBundle) [] // PersistBundle takes in the stage ID, ID of the bundle associated with the pending // input elements, and the committed output elements. func (em *ElementManager) PersistBundle(rb RunBundle, col2Coders map[string]PColInfo, d TentativeData, inputInfo PColInfo, residuals [][]byte, estimatedOWM map[string]mtime.Time) { - stage := em.stages[rb.StageID] for output, data := range d.Raw { info := col2Coders[output] var newPending []element @@ -469,9 +583,9 @@ func (em *ElementManager) PersistBundle(rb RunBundle, col2Coders map[string]PCol consumers := em.consumers[output] slog.Debug("PersistBundle: bundle has downstream consumers.", "bundle", rb, slog.Int("newPending", len(newPending)), "consumers", consumers) for _, sID := range consumers { - em.addPending(len(newPending)) consumer := em.stages[sID] - consumer.AddPending(newPending) + count := consumer.AddPending(newPending) + em.addPending(count) } sideConsumers := em.sideConsumers[output] for _, link := range sideConsumers { @@ -480,12 +594,48 @@ func (em *ElementManager) PersistBundle(rb RunBundle, col2Coders map[string]PCol } } + // Process each timer family in the order we received them, so we can filter to the last one. + // Since we're process each timer family individually, use a unique key for each userkey, tag, window. + // The last timer set for each combination is the next one we're keeping. + type timerKey struct { + key string + tag string + win typex.Window + } + + var pendingTimers []element + for tentativeKey, timers := range d.timers { + keyToTimers := map[timerKey]element{} + for _, t := range timers { + key, tag, elms := decodeTimer(inputInfo.KeyDec, true, t) + for _, e := range elms { + keyToTimers[timerKey{key: string(key), tag: tag, win: e.window}] = e + } + if len(elms) == 0 { + // TODO(lostluck): Determine best way to mark clear a timer cleared. + continue + } + } + + for _, elm := range keyToTimers { + elm.transform = tentativeKey.Transform + elm.family = tentativeKey.Family + pendingTimers = append(pendingTimers, elm) + } + } + + stage := em.stages[rb.StageID] + if len(pendingTimers) > 0 { + count := stage.AddPending(pendingTimers) + em.addPending(count) + } + // Return unprocessed to this stage's pending unprocessedElements := reElementResiduals(residuals, inputInfo, rb) // Add unprocessed back to the pending stack. if len(unprocessedElements) > 0 { - em.addPending(len(unprocessedElements)) - stage.AddPending(unprocessedElements) + count := stage.AddPending(unprocessedElements) + em.addPending(count) } // Clear out the inprogress elements associated with the completed bundle. // Must be done after adding the new pending elements to avoid an incorrect @@ -498,6 +648,23 @@ func (em *ElementManager) PersistBundle(rb RunBundle, col2Coders map[string]PCol delete(stage.inprogressKeys, k) } delete(stage.inprogressKeysByBundle, rb.BundleID) + + for hold, v := range stage.inprogressHoldsByBundle[rb.BundleID] { + n := stage.watermarkHoldsCounts[hold] - v + if n == 0 { + delete(stage.watermarkHoldsCounts, hold) + for i, h := range stage.watermarkHoldHeap { + if hold == h { + heap.Remove(&stage.watermarkHoldHeap, i) + break + } + } + } else { + stage.watermarkHoldsCounts[hold] = n + } + } + delete(stage.inprogressHoldsByBundle, rb.BundleID) + // If there are estimated output watermarks, set the estimated // output watermark for the stage. if len(estimatedOWM) > 0 { @@ -528,7 +695,6 @@ func (em *ElementManager) PersistBundle(rb RunBundle, col2Coders map[string]PCol } stage.mu.Unlock() - // TODO support state/timer watermark holds. em.addRefreshAndClearBundle(stage.ID, rb.BundleID) } @@ -552,8 +718,8 @@ func (em *ElementManager) ReturnResiduals(rb RunBundle, firstRsIndex int, inputI unprocessedElements := reElementResiduals(residuals, inputInfo, rb) if len(unprocessedElements) > 0 { slog.Debug("ReturnResiduals: unprocessed elements", "bundle", rb, "count", len(unprocessedElements)) - em.addPending(len(unprocessedElements)) - stage.AddPending(unprocessedElements) + count := stage.AddPending(unprocessedElements) + em.addPending(count) } em.addRefreshes(singleSet(rb.StageID)) } @@ -587,9 +753,7 @@ func (em *ElementManager) refreshWatermarks() set[string] { ss := em.stages[stageID] refreshed.insert(stageID) - dummyStateHold := mtime.MaxTimestamp - - refreshes := ss.updateWatermarks(ss.minPendingTimestamp(), dummyStateHold, em) + refreshes := ss.updateWatermarks(em) nextUpdates.merge(refreshes) // cap refreshes incrementally. if i < 10 { @@ -650,20 +814,68 @@ type stageState struct { sideInputs map[LinkID]map[typex.Window][][]byte // side input data for this stage, from {tid, inputID} -> window // Fields for stateful stages which need to be per key. - pendingByKeys map[string]elementHeap // pending input elements by Key, if stateful. + pendingByKeys map[string]*dataAndTimers // pending input elements by Key, if stateful. inprogressKeys set[string] // all keys that are assigned to bundles. inprogressKeysByBundle map[string]set[string] // bundle to key assignments. state map[LinkID]map[typex.Window]map[string]StateData // state data for this stage, from {tid, stateID} -> window -> userKey + + // Accounting for handling watermark holds for timers. + // We track the count of timers with the same hold, and clear it from + // the map and heap when the count goes to zero. + // This avoids scanning the heap to remove or access a hold for each element. + watermarkHoldsCounts map[mtime.Time]int + watermarkHoldHeap holdHeap + inprogressHoldsByBundle map[string]map[mtime.Time]int // bundle to associated holds. +} + +// timerKey uniquely identifies a given timer within the space of a user key. +type timerKey struct { + family, tag string + window typex.Window +} + +type timerTimes struct { + firing, hold mtime.Time +} + +// dataAndTimers represents all elements for a single user key and the latest +// eventTime for a given family and tag. +type dataAndTimers struct { + elements elementHeap + timers map[timerKey]timerTimes +} + +// holdHeap orders holds based on their timestamps +// so we can always find the minimum timestamp of pending holds. +type holdHeap []mtime.Time + +func (h holdHeap) Len() int { return len(h) } +func (h holdHeap) Less(i, j int) bool { return h[i] < h[j] } +func (h holdHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } + +func (h *holdHeap) Push(x any) { + // Push and Pop use pointer receivers because they modify the slice's length, + // not just its contents. + *h = append(*h, x.(mtime.Time)) +} + +func (h *holdHeap) Pop() any { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x } // makeStageState produces an initialized stageState. func makeStageState(ID string, inputIDs, outputIDs []string, sides []LinkID) *stageState { ss := &stageState{ - ID: ID, - outputIDs: outputIDs, - sides: sides, - strat: defaultStrat{}, - state: map[LinkID]map[typex.Window]map[string]StateData{}, + ID: ID, + outputIDs: outputIDs, + sides: sides, + strat: defaultStrat{}, + state: map[LinkID]map[typex.Window]map[string]StateData{}, + watermarkHoldsCounts: map[mtime.Time]int{}, input: mtime.MinTimestamp, output: mtime.MinTimestamp, @@ -681,25 +893,65 @@ func makeStageState(ID string, inputIDs, outputIDs []string, sides []LinkID) *st } // AddPending adds elements to the pending heap. -func (ss *stageState) AddPending(newPending []element) { +func (ss *stageState) AddPending(newPending []element) int { ss.mu.Lock() defer ss.mu.Unlock() if ss.stateful { if ss.pendingByKeys == nil { - ss.pendingByKeys = map[string]elementHeap{} + ss.pendingByKeys = map[string]*dataAndTimers{} } + count := 0 for _, e := range newPending { + count++ if len(e.keyBytes) == 0 { panic(fmt.Sprintf("zero length key: %v %v", ss.ID, ss.inputID)) } - h := ss.pendingByKeys[string(e.keyBytes)] - h.Push(e) - ss.pendingByKeys[string(e.keyBytes)] = h // (Is this necessary, with the way the heap interface works over a slice?) + dnt, ok := ss.pendingByKeys[string(e.keyBytes)] + if !ok { + dnt = &dataAndTimers{ + timers: map[timerKey]timerTimes{}, + } + ss.pendingByKeys[string(e.keyBytes)] = dnt + } + dnt.elements.Push(e) + + if e.IsTimer() { + if lastSet, ok := dnt.timers[timerKey{family: e.family, tag: e.tag, window: e.window}]; ok { + // existing timer! + // don't increase the count this time, as "this" timer is already pending. + count-- + // clear out the existing hold for accounting purposes. + v := ss.watermarkHoldsCounts[lastSet.hold] - 1 + if v == 0 { + delete(ss.watermarkHoldsCounts, lastSet.hold) + for i, hold := range ss.watermarkHoldHeap { + if hold == lastSet.hold { + heap.Remove(&ss.watermarkHoldHeap, i) + break + } + } + } else { + ss.watermarkHoldsCounts[lastSet.hold] = v + } + } + // Update the last set time on the timer. + dnt.timers[timerKey{family: e.family, tag: e.tag, window: e.window}] = timerTimes{firing: e.timestamp, hold: e.holdTimestamp} + + // Mark the hold in the heap. + ss.watermarkHoldsCounts[e.holdTimestamp] = ss.watermarkHoldsCounts[e.holdTimestamp] + 1 + + if len(ss.watermarkHoldsCounts) != len(ss.watermarkHoldHeap) { + // The hold should not be in the heap, so we add it. + heap.Push(&ss.watermarkHoldHeap, e.holdTimestamp) + } + } } - } else { - ss.pending = append(ss.pending, newPending...) - heap.Init(&ss.pending) + return count } + // Default path. + ss.pending = append(ss.pending, newPending...) + heap.Init(&ss.pending) + return len(newPending) } // AddPendingSide adds elements to be consumed as side inputs. @@ -817,41 +1069,63 @@ func (ss *stageState) startBundle(watermark mtime.Time, genBundID func() string) newKeys := set[string]{} stillSchedulable := true + holdsInBundle := map[mtime.Time]int{} + + // If timers are cleared, and we end up with nothing to process + // we need to reschedule a watermark refresh, since those vestigial + // timers might have held back the minimum pending watermark. + timerCleared := false + keysPerBundle: - for k, h := range ss.pendingByKeys { + for k, dnt := range ss.pendingByKeys { if ss.inprogressKeys.present(k) { continue } newKeys.insert(k) // Track the min-timestamp for later watermark handling. - if h[0].timestamp < minTs { - minTs = h[0].timestamp + if dnt.elements[0].timestamp < minTs { + minTs = dnt.elements[0].timestamp } - if OneElementPerKey { - hp := &h - toProcess = append(toProcess, heap.Pop(hp).(element)) - if hp.Len() == 0 { - // Once we've taken all the elements for a key, - // we must delete them from pending as well. - delete(ss.pendingByKeys, k) - } else { - ss.pendingByKeys[k] = *hp + // Can we pre-compute this bit when adding to pendingByKeys? + // startBundle is in run in a single scheduling goroutine, so moving per-element code + // to be computed by the bundle parallel goroutines will speed things up a touch. + for dnt.elements.Len() > 0 { + e := heap.Pop(&dnt.elements).(element) + if e.IsTimer() { + lastSet, ok := dnt.timers[timerKey{family: e.family, tag: e.tag, window: e.window}] + if !ok { + timerCleared = true + continue // Timer has "fired" already, so this can be ignored. + } + if e.timestamp != lastSet.firing { + timerCleared = true + continue + } + holdsInBundle[e.holdTimestamp] = holdsInBundle[e.holdTimestamp] + 1 + // Clear the "fired" timer so subsequent matches can be ignored. + delete(dnt.timers, timerKey{family: e.family, tag: e.tag, window: e.window}) } - } else { - toProcess = append(toProcess, h...) + toProcess = append(toProcess, e) + if OneElementPerKey { + break + } + } + if dnt.elements.Len() == 0 { delete(ss.pendingByKeys, k) } if OneKeyPerBundle { break keysPerBundle } } - if len(ss.pendingByKeys) == 0 { + if len(ss.pendingByKeys) == 0 && !timerCleared { + // If we're out of data, and timers were not cleared then the watermark is are accurate. stillSchedulable = false } if len(toProcess) == 0 { - return "", false, false + // If we have nothing + return "", false, stillSchedulable } if toProcess[0].timestamp < minTs { @@ -869,10 +1143,14 @@ keysPerBundle: if ss.inprogressKeysByBundle == nil { ss.inprogressKeysByBundle = make(map[string]set[string]) } + if ss.inprogressHoldsByBundle == nil { + ss.inprogressHoldsByBundle = make(map[string]map[mtime.Time]int) + } bundID := genBundID() ss.inprogress[bundID] = es ss.inprogressKeysByBundle[bundID] = newKeys ss.inprogressKeys.merge(newKeys) + ss.inprogressHoldsByBundle[bundID] = holdsInBundle return bundID, true, stillSchedulable } @@ -899,14 +1177,19 @@ func (ss *stageState) splitBundle(rb RunBundle, firstResidual int) { func (ss *stageState) minPendingTimestamp() mtime.Time { ss.mu.Lock() defer ss.mu.Unlock() + return ss.minPendingTimestampLocked() +} + +// minPendingTimestampLocked must be called under the ss.mu Lock. +func (ss *stageState) minPendingTimestampLocked() mtime.Time { minPending := mtime.MaxTimestamp if len(ss.pending) != 0 { minPending = ss.pending[0].timestamp } if len(ss.pendingByKeys) != 0 { // TODO(lostluck): Can we figure out how to avoid checking every key on every watermark refresh? - for _, h := range ss.pendingByKeys { - minPending = mtime.Min(minPending, h[0].timestamp) + for _, dnt := range ss.pendingByKeys { + minPending = mtime.Min(minPending, dnt.elements[0].timestamp) } } for _, es := range ss.inprogress { @@ -923,12 +1206,18 @@ func (ss *stageState) String() string { // updateWatermarks performs the following operations: // // Watermark_In' = MAX(Watermark_In, MIN(U(TS_Pending), U(Watermark_InputPCollection))) -// Watermark_Out' = MAX(Watermark_Out, MIN(Watermark_In', U(StateHold))) +// Watermark_Out' = MAX(Watermark_Out, MIN(Watermark_In', U(minWatermarkHold))) // Watermark_PCollection = Watermark_Out_ProducingPTransform -func (ss *stageState) updateWatermarks(minPending, minStateHold mtime.Time, em *ElementManager) set[string] { +func (ss *stageState) updateWatermarks(em *ElementManager) set[string] { ss.mu.Lock() defer ss.mu.Unlock() + minPending := ss.minPendingTimestampLocked() + minWatermarkHold := mtime.MaxTimestamp + if ss.watermarkHoldHeap.Len() > 0 { + minWatermarkHold = ss.watermarkHoldHeap[0] + } + // PCollection watermarks are based on their parents's output watermark. _, newIn := ss.UpstreamWatermark() @@ -951,8 +1240,9 @@ func (ss *stageState) updateWatermarks(minPending, minStateHold mtime.Time, em * } // We adjust based on the minimum state hold. - if minStateHold < newOut { - newOut = minStateHold + // If we hold it, mark this stage as refreshable? + if minWatermarkHold < newOut { + newOut = minWatermarkHold } refreshes := set[string]{} // If bigger, advance the output watermark diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager_test.go b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager_test.go index 0005ca8ed881b..2dcdff0925a11 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager_test.go @@ -294,7 +294,9 @@ func TestStageState_updateWatermarks(t *testing.T) { ss.input = test.initInput ss.output = test.initOutput ss.updateUpstreamWatermark(inputCol, test.upstream) - ss.updateWatermarks(test.minPending, test.minStateHold, em) + ss.pending = append(ss.pending, element{timestamp: test.minPending}) + ss.watermarkHoldHeap = append(ss.watermarkHoldHeap, test.minStateHold) + ss.updateWatermarks(em) if got, want := ss.input, test.wantInput; got != want { pcol, up := ss.UpstreamWatermark() t.Errorf("ss.updateWatermarks(%v,%v); ss.input = %v, want %v (upstream %v %v)", test.minPending, test.minStateHold, got, want, pcol, up) diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/engine_test.go b/sdks/go/pkg/beam/runners/prism/internal/engine/engine_test.go index af41e089a2e91..6a39b9d207022 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/engine_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/engine_test.go @@ -89,7 +89,17 @@ func initTestName(fn any) string { return name[n+1:] } -func TestStateAPI(t *testing.T) { +// TestStatefulStages validates that stateful transform execution is correct in +// four different modes for producing bundles: +// +// - Greedily batching all ready keys and elements. +// - All elements for a single key. +// - Only one element for each available key. +// - Only one element. +// +// Executing these pipeline here ensures their coverage is reflected in the +// engine package. +func TestStatefulStages(t *testing.T) { initRunner(t) tests := []struct { @@ -105,6 +115,8 @@ func TestStateAPI(t *testing.T) { {pipeline: primitives.MapStateParDoClear}, {pipeline: primitives.SetStateParDo}, {pipeline: primitives.SetStateParDoClear}, + {pipeline: primitives.TimersEventTimeBounded}, + {pipeline: primitives.TimersEventTimeUnbounded}, } configs := []struct { diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/timers.go b/sdks/go/pkg/beam/runners/prism/internal/engine/timers.go new file mode 100644 index 0000000000000..245b82dd10ddc --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/timers.go @@ -0,0 +1,174 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package engine + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "math" + + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" + "google.golang.org/protobuf/encoding/protowire" +) + +// DecodeTimer extracts timers to elements for insertion into their keyed queues. +// Returns the key bytes, tag, window exploded elements, and the hold timestamp. +// If the timer has been cleared, no elements will be returned. Any existing timers +// for the tag *must* be cleared from the pending queue. +func decodeTimer(keyDec func(io.Reader) []byte, usesGlobalWindow bool, raw []byte) ([]byte, string, []element) { + keyBytes := keyDec(bytes.NewBuffer(raw)) + + d := decoder{raw: raw, cursor: len(keyBytes)} + tag := string(d.Bytes()) + + var ws []typex.Window + numWin := d.Fixed32() + if usesGlobalWindow { + for i := 0; i < int(numWin); i++ { + ws = append(ws, window.GlobalWindow{}) + } + } else { + // Assume interval windows here, since we don't understand custom windows yet. + for i := 0; i < int(numWin); i++ { + ws = append(ws, d.IntervalWindow()) + } + } + + clear := d.Bool() + hold := mtime.MaxTimestamp + if clear { + return keyBytes, tag, nil + } + + firing := d.Timestamp() + hold = d.Timestamp() + pane := d.Pane() + + var ret []element + for _, w := range ws { + ret = append(ret, element{ + tag: tag, + elmBytes: nil, // indicates this is a timer. + keyBytes: keyBytes, + window: w, + timestamp: firing, + holdTimestamp: hold, + pane: pane, + }) + } + return keyBytes, tag, ret +} + +type decoder struct { + raw []byte + cursor int +} + +// Varint consumes a varint from the bytes, returning the decoded length. +func (d *decoder) Varint() (l int64) { + v, n := protowire.ConsumeVarint(d.raw[d.cursor:]) + if n < 0 { + panic("invalid varint") + } + d.cursor += n + return int64(v) +} + +// Uint64 decodes a value of type uint64. +func (d *decoder) Uint64() uint64 { + defer func() { + d.cursor += 8 + }() + return binary.BigEndian.Uint64(d.raw[d.cursor : d.cursor+8]) +} + +func (d *decoder) Timestamp() mtime.Time { + msec := d.Uint64() + return mtime.Time((int64)(msec) + math.MinInt64) +} + +// Fixed32 decodes a fixed length encoding of uint32, for window decoding. +func (d *decoder) Fixed32() uint32 { + defer func() { + d.cursor += 4 + }() + return binary.BigEndian.Uint32(d.raw[d.cursor : d.cursor+4]) +} + +func (d *decoder) IntervalWindow() window.IntervalWindow { + end := d.Timestamp() + dur := d.Varint() + return window.IntervalWindow{ + End: end, + Start: mtime.FromMilliseconds(end.Milliseconds() - dur), + } +} + +func (d *decoder) Byte() byte { + defer func() { + d.cursor += 1 + }() + return d.raw[d.cursor] +} + +func (d *decoder) Bytes() []byte { + l := d.Varint() + end := d.cursor + int(l) + b := d.raw[d.cursor:end] + d.cursor = end + return b +} + +func (d *decoder) Bool() bool { + if b := d.Byte(); b == 0 { + return false + } else if b == 1 { + return true + } else { + panic(fmt.Sprintf("unable to decode bool; expected {0, 1} got %v", b)) + } +} + +func (d *decoder) Pane() typex.PaneInfo { + first := d.Byte() + pn := coder.NewPane(first & 0x0f) + + switch first >> 4 { + case 0: + // Result encoded in only one pane. + return pn + case 1: + // Result encoded in one pane plus a VarInt encoded integer. + index := d.Varint() + pn.Index = index + if pn.Timing == typex.PaneEarly { + pn.NonSpeculativeIndex = -1 + } else { + pn.NonSpeculativeIndex = pn.Index + } + case 2: + // Result encoded in one pane plus two VarInt encoded integer. + index := d.Varint() + pn.Index = index + pn.NonSpeculativeIndex = d.Varint() + } + return pn +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/handlepardo.go b/sdks/go/pkg/beam/runners/prism/internal/handlepardo.go index 38e7e9454df55..c60a8bf2a3f56 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/handlepardo.go +++ b/sdks/go/pkg/beam/runners/prism/internal/handlepardo.go @@ -82,7 +82,6 @@ func (h *pardo) PrepareTransform(tid string, t *pipepb.PTransform, comps *pipepb !pdo.RequestsFinalization && !pdo.RequiresStableInput && !pdo.RequiresTimeSortedInput && - len(pdo.TimerFamilySpecs) == 0 && pdo.RestrictionCoderId == "" { // Which inputs are Side inputs don't change the graph further, // so they're not included here. Any nearly any ParDo can have them. diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go index d3727b6508601..323d8c46efb13 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go @@ -153,9 +153,11 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (*jo } // Validate all the timer features for _, spec := range pardo.GetTimerFamilySpecs() { - check("TimerFamilySpecs.TimeDomain.Urn", spec.GetTimeDomain()) + check("TimerFamilySpecs.TimeDomain.Urn", spec.GetTimeDomain(), pipepb.TimeDomain_EVENT_TIME) } + check("OnWindowExpirationTimerFamily", pardo.GetOnWindowExpirationTimerFamilySpec(), "") // Unsupported for now. + case "": // Composites can often have no spec if len(t.GetSubtransforms()) > 0 { diff --git a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go index ea4cf2c996953..e357714166a5e 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go +++ b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go @@ -448,7 +448,7 @@ func finalizeStage(stg *stage, comps *pipepb.Components, pipelineFacts *fusionFa if err := (proto.UnmarshalOptions{}).Unmarshal(t.GetSpec().GetPayload(), pardo); err != nil { return fmt.Errorf("unable to decode ParDoPayload for %v", link.Transform) } - if len(pardo.GetTimerFamilySpecs())+len(pardo.GetStateSpecs()) > 0 { + if len(pardo.GetTimerFamilySpecs())+len(pardo.GetStateSpecs())+len(pardo.GetOnWindowExpirationTimerFamilySpec()) > 0 { stg.stateful = true } sis = pardo.GetSideInputs() diff --git a/sdks/go/pkg/beam/runners/prism/internal/stage.go b/sdks/go/pkg/beam/runners/prism/internal/stage.go index 17c2a67b95c60..d677a0cd4cfed 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/stage.go +++ b/sdks/go/pkg/beam/runners/prism/internal/stage.go @@ -63,6 +63,7 @@ type stage struct { internalCols []string // PCollections that escape. Used for precise coder sending. envID string stateful bool + hasTimers []string exe transformExecuter inputTransformID string @@ -78,7 +79,6 @@ func (s *stage) Execute(ctx context.Context, j *jobservices.Job, wk *worker.W, c slog.Debug("Execute: starting bundle", "bundle", rb) var b *worker.B - inputData := em.InputForBundle(rb, s.inputInfo) initialState := em.StateForBundle(rb) var dataReady <-chan struct{} switch s.envID { @@ -88,7 +88,7 @@ func (s *stage) Execute(ctx context.Context, j *jobservices.Job, wk *worker.W, c } tid := s.transforms[0] // Runner transforms are processed immeadiately. - b = s.exe.ExecuteTransform(s.ID, tid, comps.GetTransforms()[tid], comps, rb.Watermark, inputData) + b = s.exe.ExecuteTransform(s.ID, tid, comps.GetTransforms()[tid], comps, rb.Watermark, em.InputForBundle(rb, s.inputInfo)) b.InstID = rb.BundleID slog.Debug("Execute: runner transform", "bundle", rb, slog.String("tid", tid)) @@ -99,14 +99,18 @@ func (s *stage) Execute(ctx context.Context, j *jobservices.Job, wk *worker.W, c close(closed) dataReady = closed case wk.Env: + input, estimatedElements := em.DataAndTimerInputForBundle(rb, s.inputInfo) b = &worker.B{ PBDID: s.ID, InstID: rb.BundleID, InputTransformID: s.inputTransformID, - InputData: inputData, + Input: input, + EstimatedInputElements: estimatedElements, + OutputData: initialState, + HasTimers: s.hasTimers, SinkToPCollection: s.SinkToPCollection, OutputCount: len(s.outputs), @@ -196,8 +200,9 @@ progress: cs := sr.GetChannelSplits()[0] fr := cs.GetFirstResidualElement() // The first residual can be after the end of data, so filter out those cases. - if len(b.InputData) >= int(fr) { - b.InputData = b.InputData[:int(fr)] + if b.EstimatedInputElements >= int(fr) { + b.EstimatedInputElements = int(fr) // Update the estimate for the next split. + // Residuals are returned right away for rescheduling. em.ReturnResiduals(rb, int(fr), s.inputInfo, residualData) } } else { @@ -286,7 +291,71 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng transforms := map[string]*pipepb.PTransform{} for _, tid := range stg.transforms { - transforms[tid] = comps.GetTransforms()[tid] + t := comps.GetTransforms()[tid] + + transforms[tid] = t + + if t.GetSpec().GetUrn() != urns.TransformParDo { + continue + } + + pardo := &pipepb.ParDoPayload{} + if err := (proto.UnmarshalOptions{}).Unmarshal(t.GetSpec().GetPayload(), pardo); err != nil { + return fmt.Errorf("unable to decode ParDoPayload for %v in stage %v", tid, stg.ID) + } + + // We need to ensure the coders can be handled by prism, and are available in the bundle descriptor. + // So we rewrite the transform's Payload with updated coder ids here. + var rewrite bool + var rewriteErr error + for stateID, s := range pardo.GetStateSpecs() { + rewrite = true + rewriteCoder := func(cid *string) { + newCid, err := lpUnknownCoders(*cid, coders, comps.GetCoders()) + if err != nil { + rewriteErr = fmt.Errorf("unable to rewrite coder %v for state %v for transform %v in stage %v:%w", *cid, stateID, tid, stg.ID, err) + return + } + *cid = newCid + } + switch s := s.GetSpec().(type) { + case *pipepb.StateSpec_BagSpec: + rewriteCoder(&s.BagSpec.ElementCoderId) + case *pipepb.StateSpec_SetSpec: + rewriteCoder(&s.SetSpec.ElementCoderId) + case *pipepb.StateSpec_OrderedListSpec: + rewriteCoder(&s.OrderedListSpec.ElementCoderId) + case *pipepb.StateSpec_CombiningSpec: + rewriteCoder(&s.CombiningSpec.AccumulatorCoderId) + case *pipepb.StateSpec_MapSpec: + rewriteCoder(&s.MapSpec.KeyCoderId) + rewriteCoder(&s.MapSpec.ValueCoderId) + case *pipepb.StateSpec_MultimapSpec: + rewriteCoder(&s.MultimapSpec.KeyCoderId) + rewriteCoder(&s.MultimapSpec.ValueCoderId) + case *pipepb.StateSpec_ReadModifyWriteSpec: + rewriteCoder(&s.ReadModifyWriteSpec.CoderId) + } + if rewriteErr != nil { + return rewriteErr + } + } + for timerID, v := range pardo.GetTimerFamilySpecs() { + stg.hasTimers = append(stg.hasTimers, tid) + rewrite = true + newCid, err := lpUnknownCoders(v.GetTimerFamilyCoderId(), coders, comps.GetCoders()) + if err != nil { + return fmt.Errorf("unable to rewrite coder %v for timer %v for transform %v in stage %v: %w", v.GetTimerFamilyCoderId(), timerID, tid, stg.ID, err) + } + v.TimerFamilyCoderId = newCid + } + if rewrite { + pyld, err := proto.MarshalOptions{}.Marshal(pardo) + if err != nil { + return fmt.Errorf("unable to encode ParDoPayload for %v in stage %v after rewrite", tid, stg.ID) + } + t.Spec.Payload = pyld + } } if len(transforms) == 0 { return fmt.Errorf("buildDescriptor: invalid stage - no transforms at all %v", stg.ID) @@ -386,6 +455,13 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng reconcileCoders(coders, comps.GetCoders()) + var timerServiceDescriptor *pipepb.ApiServiceDescriptor + if len(stg.hasTimers) > 0 { + timerServiceDescriptor = &pipepb.ApiServiceDescriptor{ + Url: wk.Endpoint(), + } + } + desc := &fnpb.ProcessBundleDescriptor{ Id: stg.ID, Transforms: transforms, @@ -395,6 +471,7 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng StateApiServiceDescriptor: &pipepb.ApiServiceDescriptor{ Url: wk.Endpoint(), }, + TimerApiServiceDescriptor: timerServiceDescriptor, } stg.desc = desc diff --git a/sdks/go/pkg/beam/runners/prism/internal/unimplemented_test.go b/sdks/go/pkg/beam/runners/prism/internal/unimplemented_test.go index 323773bd4cd62..a50a7fe21b0c8 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/unimplemented_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/unimplemented_test.go @@ -141,3 +141,25 @@ func TestStateAPI(t *testing.T) { }) } } + +func TestTimers(t *testing.T) { + initRunner(t) + + tests := []struct { + pipeline func(s beam.Scope) + }{ + {pipeline: primitives.TimersEventTimeBounded}, + {pipeline: primitives.TimersEventTimeUnbounded}, + } + + for _, test := range tests { + t.Run(initTestName(test.pipeline), func(t *testing.T) { + p, s := beam.NewPipelineWithRoot() + test.pipeline(s) + _, err := executeWithT(context.Background(), t, p) + if err != nil { + t.Fatalf("pipeline failed, but feature should be implemented in Prism: %v", err) + } + }) + } +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go index 6ef3a81e6239a..35d23e7024a5c 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go @@ -38,9 +38,11 @@ type B struct { InstID string // ID for the instruction processing this bundle. PBDID string // ID for the ProcessBundleDescriptor - // InputTransformID is data being sent to the SDK. - InputTransformID string - InputData [][]byte // Data specifically for this bundle. + // InputTransformID is where data is being sent to in the SDK. + InputTransformID string + Input []*engine.Block // Data and Timers for this bundle. + EstimatedInputElements int + HasTimers []string // IterableSideInputData is a map from transformID + inputID, to window, to data. IterableSideInputData map[SideInputKey]map[typex.Window][][]byte @@ -69,9 +71,9 @@ type B struct { // Init initializes the bundle's internal state for waiting on all // data and for relaying a response back. func (b *B) Init() { - // We need to see final data signals that match the number of + // We need to see final data and timer signals that match the number of // outputs the stage this bundle executes posesses - b.dataSema.Store(int32(b.OutputCount)) + b.dataSema.Store(int32(b.OutputCount + len(b.HasTimers))) b.DataWait = make(chan struct{}) if b.OutputCount == 0 { close(b.DataWait) // Can happen if there are no outputs for the bundle. @@ -79,8 +81,8 @@ func (b *B) Init() { b.Resp = make(chan *fnpb.ProcessBundleResponse, 1) } -// DataDone indicates a final element has been received from a Data or Timer output. -func (b *B) DataDone() { +// DataOrTimerDone indicates a final element has been received from a Data or Timer output. +func (b *B) DataOrTimerDone() { sema := b.dataSema.Add(-1) if sema == 0 { close(b.DataWait) @@ -132,23 +134,66 @@ func (b *B) ProcessOn(ctx context.Context, wk *W) <-chan struct{} { }, } - // TODO: make batching decisions. - dataBuf := bytes.Join(b.InputData, []byte{}) + // TODO: make batching decisions on the maxium to send per elements block, to reduce processing time overhead. + for _, block := range b.Input { + elms := &fnpb.Elements{} + + dataBuf := bytes.Join(block.Bytes, []byte{}) + switch block.Kind { + case engine.BlockData: + elms.Data = []*fnpb.Elements_Data{ + { + InstructionId: b.InstID, + TransformId: b.InputTransformID, + Data: dataBuf, + }, + } + case engine.BlockTimer: + elms.Timers = []*fnpb.Elements_Timers{ + { + InstructionId: b.InstID, + TransformId: block.Transform, + TimerFamilyId: block.Family, + Timers: dataBuf, + }, + } + default: + panic("unknown engine.Block kind") + } + + select { + case wk.DataReqs <- elms: + case <-ctx.Done(): + b.DataOrTimerDone() + return b.DataWait + } + } + + // Send last of everything for now. + timers := make([]*fnpb.Elements_Timers, 0, len(b.HasTimers)) + for _, tid := range b.HasTimers { + timers = append(timers, &fnpb.Elements_Timers{ + InstructionId: b.InstID, + TransformId: tid, + IsLast: true, + }) + } select { case wk.DataReqs <- &fnpb.Elements{ + Timers: timers, Data: []*fnpb.Elements_Data{ { InstructionId: b.InstID, TransformId: b.InputTransformID, - Data: dataBuf, IsLast: true, }, }, }: case <-ctx.Done(): - b.DataDone() + b.DataOrTimerDone() return b.DataWait } + return b.DataWait } @@ -184,7 +229,7 @@ func (b *B) Split(ctx context.Context, wk *W, fraction float64, allowedSplits [] b.InputTransformID: { FractionOfRemainder: fraction, AllowedSplitPoints: allowedSplits, - EstimatedInputElements: int64(len(b.InputData)), + EstimatedInputElements: int64(b.EstimatedInputElements), }, }, }, diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go index ba5b10f5fd39a..161fb199ce965 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go @@ -20,6 +20,8 @@ import ( "context" "sync" "testing" + + "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" ) func TestBundle_ProcessOn(t *testing.T) { @@ -28,7 +30,11 @@ func TestBundle_ProcessOn(t *testing.T) { InstID: "testInst", PBDID: "testPBDID", OutputCount: 1, - InputData: [][]byte{{1, 2, 3}}, + Input: []*engine.Block{ + { + Kind: engine.BlockData, + Bytes: [][]byte{{1, 2, 3}}, + }}, } b.Init() var completed sync.WaitGroup @@ -37,7 +43,7 @@ func TestBundle_ProcessOn(t *testing.T) { b.ProcessOn(context.Background(), wk) completed.Done() }() - b.DataDone() + b.DataOrTimerDone() gotData := <-wk.DataReqs if got, want := gotData.GetData()[0].GetData(), []byte{1, 2, 3}; !bytes.EqualFold(got, want) { t.Errorf("ProcessOn(): data not sent; got %v, want %v", got, want) diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go index 2859dfe2356d4..155f59a1487b6 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go @@ -367,7 +367,23 @@ func (wk *W) Data(data fnpb.BeamFnData_DataServer) error { b.OutputData.WriteData(colID, d.GetData()) } if d.GetIsLast() { - b.DataDone() + b.DataOrTimerDone() + } + } + for _, t := range resp.GetTimers() { + cr, ok := wk.activeInstructions[t.GetInstructionId()] + if !ok { + slog.Info("data.Recv for unknown bundle", "response", resp) + continue + } + // Received data is always for an active ProcessBundle instruction + b := cr.(*B) + + if len(t.GetTimers()) > 0 { + b.OutputData.WriteTimers(t.GetTransformId(), t.GetTimerFamilyId(), t.GetTimers()) + } + if t.GetIsLast() { + b.DataOrTimerDone() } } wk.mu.Unlock() @@ -507,6 +523,7 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error { default: panic(fmt.Sprintf("unsupported StateKey Append type: %T: %v", key.GetType(), prototext.Format(key))) } + responses <- &fnpb.StateResponse{ Id: req.GetId(), Response: &fnpb.StateResponse_Append{ diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go index c45d330168328..b87667eef387c 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go @@ -26,6 +26,7 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" + "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/test/bufconn" @@ -178,9 +179,11 @@ func TestWorker_Data_HappyPath(t *testing.T) { b := &B{ InstID: instID, PBDID: "teststageID", - InputData: [][]byte{ - {1, 1, 1, 1, 1, 1}, - }, + Input: []*engine.Block{ + { + Kind: engine.BlockData, + Bytes: [][]byte{{1, 1, 1, 1, 1, 1}}, + }}, OutputCount: 1, } b.Init() @@ -208,6 +211,20 @@ func TestWorker_Data_HappyPath(t *testing.T) { if got, want := elements.GetData()[0].GetData(), []byte{1, 1, 1, 1, 1, 1}; !bytes.Equal(got, want) { t.Fatalf("client Data received %v, want %v", got, want) } + if got, want := elements.GetData()[0].GetIsLast(), false; got != want { + t.Fatalf("client Data received was last: got %v, want %v", got, want) + } + + elements, err = dataStream.Recv() + if err != nil { + t.Fatal("expected 2nd data elements:", err) + } + if got, want := elements.GetData()[0].GetInstructionId(), b.InstID; got != want { + t.Fatalf("couldn't receive data elements ID: got %v, want %v", got, want) + } + if got, want := elements.GetData()[0].GetData(), []byte(nil); !bytes.Equal(got, want) { + t.Fatalf("client Data received %v, want %v", got, want) + } if got, want := elements.GetData()[0].GetIsLast(), true; got != want { t.Fatalf("client Data received wasn't last: got %v, want %v", got, want) } diff --git a/sdks/go/test/integration/integration.go b/sdks/go/test/integration/integration.go index 160af1ab741ff..622689c40d0a3 100644 --- a/sdks/go/test/integration/integration.go +++ b/sdks/go/test/integration/integration.go @@ -157,8 +157,8 @@ var prismFilters = []string{ // OOMs currently only lead to heap dumps on Dataflow runner "TestOomParDo", - // The prism runner does not support timers https://github.com/apache/beam/issues/29772. - "TestTimers.*", + // The prism runner does not support processing time timers https://github.com/apache/beam/issues/29772. + "TestTimers_ProcessingTime.*", } var flinkFilters = []string{