Skip to content

Commit

Permalink
[#32929] Add OrderedListState support to Prism. (#33350)
Browse files Browse the repository at this point in the history
  • Loading branch information
lostluck authored Dec 17, 2024
1 parent 9232cd8 commit 8e1e124
Show file tree
Hide file tree
Showing 10 changed files with 385 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Support OnWindowExpiration in Prism ([#32211](https://github.com/apache/beam/issues/32211)).
* This enables initial Java GroupIntoBatches support.
* Support OrderedListState in Prism ([#32929](https://github.com/apache/beam/issues/32929)).

## Breaking Changes

Expand Down
4 changes: 0 additions & 4 deletions runners/prism/java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,6 @@ def createPrismValidatesRunnerTask = { name, environmentType ->
excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService'
excludeCategories 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment'

// Not yet implemented in Prism
// https://github.com/apache/beam/issues/32929
excludeCategories 'org.apache.beam.sdk.testing.UsesOrderedListState'

// Not supported in Portable Java SDK yet.
// https://github.com/apache/beam/issues?q=is%3Aissue+is%3Aopen+MultimapState
excludeCategories 'org.apache.beam.sdk.testing.UsesMultimapState'
Expand Down
97 changes: 97 additions & 0 deletions sdks/go/pkg/beam/runners/prism/internal/engine/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@ package engine

import (
"bytes"
"cmp"
"fmt"
"log/slog"
"slices"
"sort"

"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
"google.golang.org/protobuf/encoding/protowire"
)

// StateData is a "union" between Bag state and MultiMap state to increase common code.
Expand All @@ -42,6 +46,10 @@ type TimerKey struct {
type TentativeData struct {
Raw map[string][][]byte

// stateTypeLen is a map from LinkID to valueLen function for parsing data.
// Only used by OrderedListState, since Prism must manipulate these datavalues,
// which isn't expected, or a requirement of other state values.
stateTypeLen map[LinkID]func([]byte) int
// state is a map from transformID + UserStateID, to window, to userKey, to datavalues.
state map[LinkID]map[typex.Window]map[string]StateData
// timers is a map from the Timer transform+family to the encoded timer.
Expand Down Expand Up @@ -220,3 +228,92 @@ func (d *TentativeData) ClearMultimapKeysState(stateID LinkID, wKey, uKey []byte
kmap[string(uKey)] = StateData{}
slog.Debug("State() MultimapKeys.Clear", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("WindowKey", wKey))
}

// AppendOrderedListState appends the incoming timestamped data to the existing tentative data bundle.
// Assumes the data is TimestampedValue encoded, which has a BigEndian int64 suffixed to the data.
// This means we may always use the last 8 bytes to determine the value sorting.
//
// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively.
func (d *TentativeData) AppendOrderedListState(stateID LinkID, wKey, uKey []byte, data []byte) {
kmap := d.appendState(stateID, wKey)
typeLen := d.stateTypeLen[stateID]
var datums [][]byte

// We need to parse out all values individually for later sorting.
//
// OrderedListState is encoded as KVs with varint encoded millis followed by the value.
// This is not the standard TimestampValueCoder encoding, which
// uses a big-endian long as a suffix to the value. This is important since
// values may be concatenated, and we'll need to split them out out.
//
// The TentativeData.stateTypeLen is populated with a function to extract
// the length of a the next value, so we can skip through elements individually.
for i := 0; i < len(data); {
// Get the length of the VarInt for the timestamp.
_, tn := protowire.ConsumeVarint(data[i:])

// Get the length of the encoded value.
vn := typeLen(data[i+tn:])
prev := i
i += tn + vn
datums = append(datums, data[prev:i])
}

s := StateData{Bag: append(kmap[string(uKey)].Bag, datums...)}
sort.SliceStable(s.Bag, func(i, j int) bool {
vi := s.Bag[i]
vj := s.Bag[j]
return compareTimestampSuffixes(vi, vj)
})
kmap[string(uKey)] = s
slog.Debug("State() OrderedList.Append", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", wKey), slog.Any("NewData", s))
}

func compareTimestampSuffixes(vi, vj []byte) bool {
ims, _ := protowire.ConsumeVarint(vi)
jms, _ := protowire.ConsumeVarint(vj)
return (int64(ims)) < (int64(jms))
}

// GetOrderedListState available state from the tentative bundle data.
// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively.
func (d *TentativeData) GetOrderedListState(stateID LinkID, wKey, uKey []byte, start, end int64) [][]byte {
winMap := d.state[stateID]
w := d.toWindow(wKey)
data := winMap[w][string(uKey)]

lo, hi := findRange(data.Bag, start, end)
slog.Debug("State() OrderedList.Get", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", wKey), slog.Group("range", slog.Int64("start", start), slog.Int64("end", end)), slog.Group("outrange", slog.Int("lo", lo), slog.Int("hi", hi)), slog.Any("Data", data.Bag[lo:hi]))
return data.Bag[lo:hi]
}

func cmpSuffix(vs [][]byte, target int64) func(i int) int {
return func(i int) int {
v := vs[i]
ims, _ := protowire.ConsumeVarint(v)
tvsbi := cmp.Compare(target, int64(ims))
slog.Debug("cmpSuffix", "target", target, "bi", ims, "tvsbi", tvsbi)
return tvsbi
}
}

func findRange(bag [][]byte, start, end int64) (int, int) {
lo, _ := sort.Find(len(bag), cmpSuffix(bag, start))
hi, _ := sort.Find(len(bag), cmpSuffix(bag, end))
return lo, hi
}

func (d *TentativeData) ClearOrderedListState(stateID LinkID, wKey, uKey []byte, start, end int64) {
winMap := d.state[stateID]
w := d.toWindow(wKey)
kMap := winMap[w]
data := kMap[string(uKey)]

lo, hi := findRange(data.Bag, start, end)
slog.Debug("State() OrderedList.Clear", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", wKey), slog.Group("range", slog.Int64("start", start), slog.Int64("end", end)), "lo", lo, "hi", hi, slog.Any("PreClearData", data.Bag))

cleared := slices.Delete(data.Bag, lo, hi)
// Zero the current entry to clear.
// Delete makes it difficult to delete the persisted stage state for the key.
kMap[string(uKey)] = StateData{Bag: cleared}
}
222 changes: 222 additions & 0 deletions sdks/go/pkg/beam/runners/prism/internal/engine/data_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package engine

import (
"bytes"
"encoding/binary"
"math"
"testing"

"github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/encoding/protowire"
)

func TestCompareTimestampSuffixes(t *testing.T) {
t.Run("simple", func(t *testing.T) {
loI := int64(math.MinInt64)
hiI := int64(math.MaxInt64)

loB := binary.BigEndian.AppendUint64(nil, uint64(loI))
hiB := binary.BigEndian.AppendUint64(nil, uint64(hiI))

if compareTimestampSuffixes(loB, hiB) != (loI < hiI) {
t.Errorf("lo vs Hi%v < %v: bytes %v vs %v, %v %v", loI, hiI, loB, hiB, loI < hiI, compareTimestampSuffixes(loB, hiB))
}
})
}

func TestOrderedListState(t *testing.T) {
time1 := protowire.AppendVarint(nil, 11)
time2 := protowire.AppendVarint(nil, 22)
time3 := protowire.AppendVarint(nil, 33)
time4 := protowire.AppendVarint(nil, 44)
time5 := protowire.AppendVarint(nil, 55)

wKey := []byte{} // global window.
uKey := []byte("\u0007userkey")
linkID := LinkID{
Transform: "dofn",
Local: "localStateName",
}
cc := func(a []byte, b ...byte) []byte {
return bytes.Join([][]byte{a, b}, []byte{})
}

t.Run("bool", func(t *testing.T) {
d := TentativeData{
stateTypeLen: map[LinkID]func([]byte) int{
linkID: func(_ []byte) int {
return 1
},
},
}

d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, 1))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, 0))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, 1))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, 1))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, 0))

got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
want := [][]byte{
cc(time1, 1),
cc(time2, 0),
cc(time3, 1),
cc(time4, 0),
cc(time5, 1),
}
if d := cmp.Diff(want, got); d != "" {
t.Errorf("OrderedList booleans \n%v", d)
}

d.ClearOrderedListState(linkID, wKey, uKey, 12, 54)
got = d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
want = [][]byte{
cc(time1, 1),
cc(time5, 1),
}
if d := cmp.Diff(want, got); d != "" {
t.Errorf("OrderedList booleans, after clear\n%v", d)
}
})
t.Run("float64", func(t *testing.T) {
d := TentativeData{
stateTypeLen: map[LinkID]func([]byte) int{
linkID: func(_ []byte) int {
return 8
},
},
}

d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, 0, 0, 0, 0, 0, 0, 0, 1))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, 0, 0, 0, 0, 0, 0, 1, 0))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, 0, 0, 0, 0, 0, 1, 0, 0))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, 0, 0, 0, 0, 1, 0, 0, 0))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, 0, 0, 0, 1, 0, 0, 0, 0))

got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
want := [][]byte{
cc(time1, 0, 0, 0, 0, 0, 0, 1, 0),
cc(time2, 0, 0, 0, 0, 1, 0, 0, 0),
cc(time3, 0, 0, 0, 0, 0, 1, 0, 0),
cc(time4, 0, 0, 0, 1, 0, 0, 0, 0),
cc(time5, 0, 0, 0, 0, 0, 0, 0, 1),
}
if d := cmp.Diff(want, got); d != "" {
t.Errorf("OrderedList float64s \n%v", d)
}

d.ClearOrderedListState(linkID, wKey, uKey, 11, 12)
d.ClearOrderedListState(linkID, wKey, uKey, 33, 34)
d.ClearOrderedListState(linkID, wKey, uKey, 55, 56)

got = d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
want = [][]byte{
cc(time2, 0, 0, 0, 0, 1, 0, 0, 0),
cc(time4, 0, 0, 0, 1, 0, 0, 0, 0),
}
if d := cmp.Diff(want, got); d != "" {
t.Errorf("OrderedList float64s, after clear \n%v", d)
}
})

t.Run("varint", func(t *testing.T) {
d := TentativeData{
stateTypeLen: map[LinkID]func([]byte) int{
linkID: func(b []byte) int {
_, n := protowire.ConsumeVarint(b)
return int(n)
},
},
}

d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, protowire.AppendVarint(nil, 56)...))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, protowire.AppendVarint(nil, 20067)...))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, protowire.AppendVarint(nil, 7777777)...))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, protowire.AppendVarint(nil, 424242)...))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, protowire.AppendVarint(nil, 0)...))

got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
want := [][]byte{
cc(time1, protowire.AppendVarint(nil, 424242)...),
cc(time2, protowire.AppendVarint(nil, 56)...),
cc(time3, protowire.AppendVarint(nil, 7777777)...),
cc(time4, protowire.AppendVarint(nil, 20067)...),
cc(time5, protowire.AppendVarint(nil, 0)...),
}
if d := cmp.Diff(want, got); d != "" {
t.Errorf("OrderedList int32 \n%v", d)
}
})
t.Run("lp", func(t *testing.T) {
d := TentativeData{
stateTypeLen: map[LinkID]func([]byte) int{
linkID: func(b []byte) int {
l, n := protowire.ConsumeVarint(b)
return int(l) + n
},
},
}

d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, []byte("\u0003one")...))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, []byte("\u0003two")...))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, []byte("\u0005three")...))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, []byte("\u0004four")...))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, []byte("\u0019FourHundredAndEleventyTwo")...))

got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
want := [][]byte{
cc(time1, []byte("\u0003one")...),
cc(time2, []byte("\u0003two")...),
cc(time3, []byte("\u0005three")...),
cc(time4, []byte("\u0004four")...),
cc(time5, []byte("\u0019FourHundredAndEleventyTwo")...),
}
if d := cmp.Diff(want, got); d != "" {
t.Errorf("OrderedList int32 \n%v", d)
}
})
t.Run("lp_onecall", func(t *testing.T) {
d := TentativeData{
stateTypeLen: map[LinkID]func([]byte) int{
linkID: func(b []byte) int {
l, n := protowire.ConsumeVarint(b)
return int(l) + n
},
},
}
d.AppendOrderedListState(linkID, wKey, uKey, bytes.Join([][]byte{
time5, []byte("\u0019FourHundredAndEleventyTwo"),
time3, []byte("\u0005three"),
time2, []byte("\u0003two"),
time1, []byte("\u0003one"),
time4, []byte("\u0004four"),
}, nil))

got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
want := [][]byte{
cc(time1, []byte("\u0003one")...),
cc(time2, []byte("\u0003two")...),
cc(time3, []byte("\u0005three")...),
cc(time4, []byte("\u0004four")...),
cc(time5, []byte("\u0019FourHundredAndEleventyTwo")...),
}
if d := cmp.Diff(want, got); d != "" {
t.Errorf("OrderedList int32 \n%v", d)
}
})
}
11 changes: 8 additions & 3 deletions sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,10 @@ func (em *ElementManager) StageAggregates(ID string) {

// StageStateful marks the given stage as stateful, which means elements are
// processed by key.
func (em *ElementManager) StageStateful(ID string) {
em.stages[ID].stateful = true
func (em *ElementManager) StageStateful(ID string, stateTypeLen map[LinkID]func([]byte) int) {
ss := em.stages[ID]
ss.stateful = true
ss.stateTypeLen = stateTypeLen
}

// StageOnWindowExpiration marks the given stage as stateful, which means elements are
Expand Down Expand Up @@ -669,7 +671,9 @@ func (em *ElementManager) StateForBundle(rb RunBundle) TentativeData {
ss := em.stages[rb.StageID]
ss.mu.Lock()
defer ss.mu.Unlock()
var ret TentativeData
ret := TentativeData{
stateTypeLen: ss.stateTypeLen,
}
keys := ss.inprogressKeysByBundle[rb.BundleID]
// TODO(lostluck): Also track windows per bundle, to reduce copying.
if len(ss.state) > 0 {
Expand Down Expand Up @@ -1136,6 +1140,7 @@ type stageState struct {
inprogressKeys set[string] // all keys that are assigned to bundles.
inprogressKeysByBundle map[string]set[string] // bundle to key assignments.
state map[LinkID]map[typex.Window]map[string]StateData // state data for this stage, from {tid, stateID} -> window -> userKey
stateTypeLen map[LinkID]func([]byte) int // map from state to a function that will produce the total length of a single value in bytes.

// Accounting for handling watermark holds for timers.
// We track the count of timers with the same hold, and clear it from
Expand Down
Loading

0 comments on commit 8e1e124

Please sign in to comment.