Skip to content

Commit

Permalink
[#24931][Go SDK] Make element checkpoints independent (#24932)
Browse files Browse the repository at this point in the history
  • Loading branch information
lostluck authored Jan 9, 2023
1 parent a803789 commit e0f463c
Show file tree
Hide file tree
Showing 13 changed files with 419 additions and 72 deletions.
49 changes: 34 additions & 15 deletions sdks/go/pkg/beam/core/runtime/exec/datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@ func (r *byteCountReader) reset() int {
}

// Process opens the data source, reads and decodes data, kicking off element processing.
func (n *DataSource) Process(ctx context.Context) error {
func (n *DataSource) Process(ctx context.Context) ([]*Checkpoint, error) {
r, err := n.source.OpenRead(ctx, n.SID)
if err != nil {
return err
return nil, err
}
defer r.Close()
n.PCol.resetSize() // initialize the size distribution for this bundle.
Expand All @@ -154,23 +154,24 @@ func (n *DataSource) Process(ctx context.Context) error {
cp = MakeElementDecoder(c)
}

var checkpoints []*Checkpoint
for {
if n.incrementIndexAndCheckSplit() {
return nil
break
}
// TODO(lostluck) 2020/02/22: Should we include window headers or just count the element sizes?
ws, t, pn, err := DecodeWindowedValueHeader(wc, r)
if err != nil {
if err == io.EOF {
return nil
break
}
return errors.Wrap(err, "source failed")
return nil, errors.Wrap(err, "source failed")
}

// Decode key or parallel element.
pe, err := cp.Decode(&bcr)
if err != nil {
return errors.Wrap(err, "source decode failed")
return nil, errors.Wrap(err, "source decode failed")
}
pe.Timestamp = t
pe.Windows = ws
Expand All @@ -180,18 +181,32 @@ func (n *DataSource) Process(ctx context.Context) error {
for _, cv := range cvs {
values, err := n.makeReStream(ctx, cv, &bcr, len(cvs) == 1 && n.singleIterate)
if err != nil {
return err
return nil, err
}
valReStreams = append(valReStreams, values)
}

if err := n.Out.ProcessElement(ctx, pe, valReStreams...); err != nil {
return err
return nil, err
}
// Collect the actual size of the element, and reset the bytecounter reader.
n.PCol.addSize(int64(bcr.reset()))
bcr.reader = r

// Check if there's a continuation and collect its residuals.
// This needs to be done immediately after processing so the element is not lost.
if c := n.getProcessContinuation(); c != nil {
cp, err := n.checkpointThis(c)
if err != nil {
// Errors during checkpointing should fail a bundle.
return nil, err
}
if cp != nil {
checkpoints = append(checkpoints, cp)
}
}
}
return checkpoints, nil
}

func (n *DataSource) makeReStream(ctx context.Context, cv ElementDecoder, bcr *byteCountReader, onlyStream bool) (ReStream, error) {
Expand Down Expand Up @@ -397,18 +412,22 @@ func (n *DataSource) makeEncodeElms() func([]*FullValue) ([][]byte, error) {
return encodeElms
}

// Checkpoint bundles the outcome of a single self-checkpoint: the split result
// and the resume delay reported by the ProcessContinuation that triggered it.
type Checkpoint struct {
// SR is the SplitResult assembled by checkpointThis for the checkpointed element.
SR SplitResult
// Reapply is populated from the ProcessContinuation's ResumeDelay().
Reapply time.Duration
}

// checkpointThis attempts to split an SDF that has self-checkpointed (e.g. returned a
// ProcessContinuation) and needs to be resumed later. If the underlying DoFn is not
// splittable or has not returned a resuming continuation, or if checkpointing produces
// no residuals, the function returns a nil Checkpoint and a nil error to indicate that
// no split occurred.
func (n *DataSource) Checkpoint() (SplitResult, time.Duration, bool, error) {
func (n *DataSource) checkpointThis(pc sdf.ProcessContinuation) (*Checkpoint, error) {
n.mu.Lock()
defer n.mu.Unlock()

pc := n.getProcessContinuation()
if pc == nil || !pc.ShouldResume() {
return SplitResult{}, -1 * time.Minute, false, nil
return nil, nil
}

su := SplittableUnit(n.Out.(*ProcessSizedElementsAndRestrictions))
Expand All @@ -418,17 +437,17 @@ func (n *DataSource) Checkpoint() (SplitResult, time.Duration, bool, error) {
// Checkpointing is functionally a split at fraction 0.0
rs, err := su.Checkpoint()
if err != nil {
return SplitResult{}, -1 * time.Minute, false, err
return nil, err
}
if len(rs) == 0 {
return SplitResult{}, -1 * time.Minute, false, nil
return nil, nil
}

encodeElms := n.makeEncodeElms()

rsEnc, err := encodeElms(rs)
if err != nil {
return SplitResult{}, -1 * time.Minute, false, err
return nil, err
}

res := SplitResult{
Expand All @@ -437,7 +456,7 @@ func (n *DataSource) Checkpoint() (SplitResult, time.Duration, bool, error) {
InId: su.GetInputId(),
OW: ow,
}
return res, pc.ResumeDelay(), true, nil
return &Checkpoint{SR: res, Reapply: pc.ResumeDelay()}, nil
}

// Split takes a sorted set of potential split indices and a fraction of the
Expand Down
155 changes: 152 additions & 3 deletions sdks/go/pkg/beam/core/runtime/exec/datasource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,25 @@
package exec

import (
"bytes"
"context"
"fmt"
"io"
"math"
"reflect"
"testing"
"time"

"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/coderx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
"github.com/apache/beam/sdks/v2/go/pkg/beam/io/rtrackers/offsetrange"
"google.golang.org/protobuf/types/known/timestamppb"
)

Expand Down Expand Up @@ -314,7 +321,10 @@ func TestDataSource_Split(t *testing.T) {
t.Fatalf("error in Split: got primary index = %v, want %v ", got, want)
}

runOnRoots(ctx, t, p, "Process", Root.Process)
runOnRoots(ctx, t, p, "Process", func(root Root, ctx context.Context) error {
_, err := root.Process(ctx)
return err
})
runOnRoots(ctx, t, p, "FinishBundle", Root.FinishBundle)

validateSource(t, out, source, makeValues(test.expected...))
Expand Down Expand Up @@ -449,7 +459,10 @@ func TestDataSource_Split(t *testing.T) {
if got, want := splitRes.PI, test.splitIdx-1; got != want {
t.Fatalf("error in Split: got primary index = %v, want %v ", got, want)
}
runOnRoots(ctx, t, p, "Process", Root.Process)
runOnRoots(ctx, t, p, "Process", func(root Root, ctx context.Context) error {
_, err := root.Process(ctx)
return err
})
runOnRoots(ctx, t, p, "FinishBundle", Root.FinishBundle)

validateSource(t, out, source, makeValues(test.expected...))
Expand Down Expand Up @@ -582,7 +595,10 @@ func TestDataSource_Split(t *testing.T) {
if sr, err := p.Split(ctx, SplitPoints{Splits: []int64{0}, Frac: -1}); err != nil || !sr.Unsuccessful {
t.Fatalf("p.Split(active) = %v,%v want unsuccessful split & nil err", sr, err)
}
runOnRoots(ctx, t, p, "Process", Root.Process)
runOnRoots(ctx, t, p, "Process", func(root Root, ctx context.Context) error {
_, err := root.Process(ctx)
return err
})
if sr, err := p.Split(ctx, SplitPoints{Splits: []int64{0}, Frac: -1}); err != nil || !sr.Unsuccessful {
t.Fatalf("p.Split(active, unable to get desired split) = %v,%v want unsuccessful split & nil err", sr, err)
}
Expand Down Expand Up @@ -858,6 +874,139 @@ func TestSplitHelper(t *testing.T) {
})
}

// TestCheckpointing covers DataSource.checkpointThis directly, and indirectly
// through DataSource.Process.
//
// The direct cases check that no checkpoint is produced for a nil
// continuation, a stopping continuation, or a resuming continuation with no
// residual work. The Process case feeds several encoded elements through a
// self-checkpointing SDF and checks that one checkpoint is returned per
// element, which is the regression guarded against for issue #24931.
func TestCheckpointing(t *testing.T) {
	t.Run("nil", func(t *testing.T) {
		cp, err := (&DataSource{}).checkpointThis(nil)
		if err != nil {
			t.Fatalf("checkpointThis() = %v, %v, want nil error", cp, err)
		}
		if cp != nil {
			t.Fatalf("checkpointThis() = %v, want nil", cp)
		}
	})
	t.Run("Stop", func(t *testing.T) {
		cp, err := (&DataSource{}).checkpointThis(sdf.StopProcessing())
		if err != nil {
			t.Fatalf("checkpointThis() = %v, %v, want nil error", cp, err)
		}
		if cp != nil {
			t.Fatalf("checkpointThis() = %v, want nil", cp)
		}
	})
	t.Run("Delay_no_residuals", func(t *testing.T) {
		wesInv, _ := newWatermarkEstimatorStateInvoker(nil)
		root := &DataSource{
			Out: &ProcessSizedElementsAndRestrictions{
				PDo:    &ParDo{},
				wesInv: wesInv,
				// A zero-value restriction leaves nothing to hand back as a residual.
				rt: offsetrange.NewTracker(offsetrange.Restriction{}),
				elm: &FullValue{
					Windows: window.SingleGlobalWindow,
				},
			},
		}
		cp, err := root.checkpointThis(sdf.ResumeProcessingIn(time.Second * 13))
		if err != nil {
			t.Fatalf("checkpointThis() = %v, %v, want nil", cp, err)
		}
		if cp != nil {
			t.Fatalf("checkpointThis() = %v, want nil", cp)
		}
	})

	dfn, err := graph.NewDoFn(&CheckpointingSdf{delay: time.Minute}, graph.NumMainInputs(graph.MainSingle))
	if err != nil {
		t.Fatalf("invalid function: %v", err)
	}

	// Coder for a windowed KV<KV<element, KV<restriction, watermark state>>, size>.
	intCoder, _ := coderx.NewVarIntZ(reflectx.Int)
	ERSCoder := coder.NewKV([]*coder.Coder{
		coder.NewKV([]*coder.Coder{
			coder.CoderFrom(intCoder), // Element
			coder.NewKV([]*coder.Coder{
				coder.NewR(typex.New(reflect.TypeOf((*offsetrange.Restriction)(nil)).Elem())), // Restriction
				coder.NewBool(), // Watermark State
			}),
		}),
		coder.NewDouble(), // Size
	})
	wvERSCoder := coder.NewW(
		ERSCoder,
		coder.NewGlobalWindow(),
	)

	rest := offsetrange.Restriction{Start: 1, End: 10}
	value := &FullValue{
		Elm: &FullValue{
			Elm: 42,
			Elm2: &FullValue{
				Elm:  rest,  // Restriction
				Elm2: false, // Watermark State falsie
			},
		},
		Elm2:      rest.Size(),
		Windows:   window.SingleGlobalWindow,
		Timestamp: mtime.MaxTimestamp,
		Pane:      typex.NoFiringPane(),
	}
	t.Run("Delay_residuals_Process", func(t *testing.T) {
		ctx := context.Background()
		wesInv, _ := newWatermarkEstimatorStateInvoker(nil)
		root := &DataSource{
			Coder: wvERSCoder,
			Out: &ProcessSizedElementsAndRestrictions{
				PDo: &ParDo{
					Fn:  dfn,
					Out: []Node{&Discard{}},
				},
				TfId:   "testTransformID",
				wesInv: wesInv,
				rt:     offsetrange.NewTracker(rest),
			},
		}
		if err := root.Up(ctx); err != nil {
			t.Fatalf("root.Up() = %v, want nil", err)
		}
		if err := root.Out.Up(ctx); err != nil {
			t.Fatalf("root.Out.Up() = %v, want nil", err)
		}

		enc := MakeElementEncoder(wvERSCoder)
		var buf bytes.Buffer

		// We encode the element several times to ensure we don't
		// drop any residuals, the root of issue #24931.
		wantCount := 3
		for i := 0; i < wantCount; i++ {
			if err := enc.Encode(value, &buf); err != nil {
				t.Fatalf("couldn't encode value: %v", err)
			}
		}

		if err := root.StartBundle(ctx, "testBund", DataContext{
			Data: &TestDataManager{
				R: io.NopCloser(&buf),
			},
		},
		); err != nil {
			t.Fatalf("root.StartBundle() = %v, want nil", err)
		}
		cps, err := root.Process(ctx)
		if err != nil {
			t.Fatalf("Process() = %v, %v, want nil", cps, err)
		}
		if got, want := len(cps), wantCount; got != want {
			t.Fatalf("Process() = len %v checkpoints, want %v", got, want)
		}
		// Check each checkpoint has the expected values.
		for _, cp := range cps {
			if got, want := cp.Reapply, time.Minute; got != want {
				t.Errorf("Process(delay(%v)) delay = %v, want %v", want, got, want)
			}
			if got, want := cp.SR.TId, root.Out.(*ProcessSizedElementsAndRestrictions).TfId; got != want {
				t.Errorf("Process() transformID = %v, want %v", got, want)
			}
			if got, want := cp.SR.InId, "i0"; got != want {
				t.Errorf("Process() inputID = %v, want %v", got, want)
			}
		}
	})
}

func runOnRoots(ctx context.Context, t *testing.T, p *Plan, name string, mthd func(Root, context.Context) error) {
t.Helper()
for i, root := range p.roots {
Expand Down
25 changes: 14 additions & 11 deletions sdks/go/pkg/beam/core/runtime/exec/plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@ import (
// from a part of a pipeline. A plan can be used to process multiple bundles
// serially.
type Plan struct {
id string // id of the bundle descriptor for this plan
roots []Root
units []Unit
pcols []*PCollection
bf *bundleFinalizer
id string // id of the bundle descriptor for this plan
roots []Root
units []Unit
pcols []*PCollection
bf *bundleFinalizer
checkpoints []*Checkpoint

status Status

Expand Down Expand Up @@ -126,7 +127,11 @@ func (p *Plan) Execute(ctx context.Context, id string, manager DataContext) erro
}
}
for _, root := range p.roots {
if err := callNoPanic(ctx, root.Process); err != nil {
if err := callNoPanic(ctx, func(ctx context.Context) error {
cps, err := root.Process(ctx)
p.checkpoints = cps
return err
}); err != nil {
p.status = Broken
return errors.Wrapf(err, "while executing Process for %v", p)
}
Expand Down Expand Up @@ -281,9 +286,7 @@ func (p *Plan) Split(ctx context.Context, s SplitPoints) (SplitResult, error) {
}

// Checkpoint returns the checkpoints gathered from SDFs that self-checkpointed
// during Execute, and clears them from the plan so they are only returned once.
func (p *Plan) Checkpoint() (SplitResult, time.Duration, bool, error) {
if p.source != nil {
return p.source.Checkpoint()
}
return SplitResult{}, -1 * time.Minute, false, nil
func (p *Plan) Checkpoint() []*Checkpoint {
defer func() { p.checkpoints = nil }()
return p.checkpoints
}
Loading

0 comments on commit e0f463c

Please sign in to comment.