Skip to content

Commit

Permalink
[Go SDK]: Allow SDF methods to have context param and error return va…
Browse files Browse the repository at this point in the history
…lue (#25437)

* Allow context param and error return value in SDF validation

* Use context param and error return value in SDF method invocation

* Run go fmt

* Clean up error messages from "fn reflect.methodValueCall"

* Validate return value count in a more correct way
  • Loading branch information
johannaojeling authored Feb 16, 2023
1 parent 9fb9d2c commit 29ea6e0
Show file tree
Hide file tree
Showing 12 changed files with 1,227 additions and 314 deletions.
108 changes: 73 additions & 35 deletions sdks/go/pkg/beam/core/graph/fn.go
Original file line number Diff line number Diff line change
Expand Up @@ -866,14 +866,15 @@ func validateSdfSignatures(fn *Fn, numMainIn mainInputs) error {
// CreateInitialRestriction.
if numMainIn == MainUnknown {
initialRestFn := fn.methods[createInitialRestrictionName]
paramNum := len(initialRestFn.Param)
paramNum := len(initialRestFn.Params(funcx.FnValue))

switch paramNum {
case int(MainSingle), int(MainKv):
num = paramNum
default: // Can't infer because method has invalid # of main inputs.
err := errors.Errorf("invalid number of params in method %v. got: %v, want: %v or %v",
err := errors.Errorf("invalid number of main input params in method %v. got: %v, want: %v or %v",
createInitialRestrictionName, paramNum, int(MainSingle), int(MainKv))
return errors.SetTopLevelMsgf(err, "Invalid number of parameters in method %v. "+
return errors.SetTopLevelMsgf(err, "Invalid number of main input parameters in method %v. "+
"Got: %v, Want: %v or %v. Check that the signature conforms to the expected signature for %v, "+
"and that elements in SDF method parameters match elements in %v.",
createInitialRestrictionName, paramNum, int(MainSingle), int(MainKv), createInitialRestrictionName, processElementName)
Expand All @@ -894,7 +895,7 @@ func validateSdfSignatures(fn *Fn, numMainIn mainInputs) error {
// in each SDF method in the given Fn, and returns an error if a method has an
// invalid/unexpected number.
func validateSdfSigNumbers(fn *Fn, num int) error {
paramNums := map[string]int{
reqParamNums := map[string]int{
createInitialRestrictionName: num,
splitRestrictionName: num + 1,
restrictionSizeName: num + 1,
Expand All @@ -904,32 +905,52 @@ func validateSdfSigNumbers(fn *Fn, num int) error {
optionalSdfs := map[string]bool{
truncateRestrictionName: true,
}
returnNum := 1 // TODO(BEAM-3301): Enable optional error params in SDF methods.
reqReturnNum := 1

for _, name := range sdfNames {
method, ok := fn.methods[name]
if !ok && optionalSdfs[name] {
continue
}
if len(method.Param) != paramNums[name] {
err := errors.Errorf("unexpected number of params in method %v. got: %v, want: %v",
name, len(method.Param), paramNums[name])

reqParamNum := reqParamNums[name]
if !sdfHasValidParamNum(method.Param, reqParamNum) {
err := errors.Errorf("unexpected number of params in method %v. got: %v, want: %v or optionally %v "+
"if first param is of type context.Context", name, len(method.Param), reqParamNum, reqParamNum+1)
return errors.SetTopLevelMsgf(err, "Unexpected number of parameters in method %v. "+
"Got: %v, Want: %v. Check that the signature conforms to the expected signature for %v, "+
"and that elements in SDF method parameters match elements in %v.",
name, len(method.Param), paramNums[name], name, processElementName)
"Got: %v, Want: %v or optionally %v if first param is of type context.Context. "+
"Check that the signature conforms to the expected signature for %v, and that elements in SDF method "+
"parameters match elements in %v.", name, len(method.Param), reqParamNum, reqParamNum+1,
name, processElementName)
}
if len(method.Ret) != returnNum {
err := errors.Errorf("unexpected number of returns in method %v. got: %v, want: %v",
name, len(method.Ret), returnNum)
if !sdfHasValidReturnNum(method.Ret, reqReturnNum) {
err := errors.Errorf("unexpected number of returns in method %v. got: %v, want: %v or optionally %v "+
"if last value is of type error", name, len(method.Ret), reqReturnNum, reqReturnNum+1)
return errors.SetTopLevelMsgf(err, "Unexpected number of return values in method %v. "+
"Got: %v, Want: %v. Check that the signature conforms to the expected signature for %v.",
name, len(method.Ret), returnNum, name)
"Got: %v, Want: %v or optionally %v if last value is of type error. "+
"Check that the signature conforms to the expected signature for %v.",
name, len(method.Ret), reqReturnNum, reqReturnNum+1, name)
}
}
return nil
}

func sdfHasValidParamNum(params []funcx.FnParam, requiredNum int) bool {
if len(params) == requiredNum {
return true
}

return len(params) == requiredNum+1 && params[0].Kind == funcx.FnContext
}

func sdfHasValidReturnNum(returns []funcx.ReturnParam, requiredNum int) bool {
if len(returns) == requiredNum {
return true
}

return len(returns) == requiredNum+1 && returns[len(returns)-1].Kind == funcx.RetError
}

// validateSdfSigTypes validates the types of the parameters and return values
// in each SDF method in the given Fn, and returns an error if a method has an
// invalid/mismatched type. Assumes that the number of parameters and return
Expand All @@ -940,22 +961,25 @@ func validateSdfSigTypes(fn *Fn, num int) error {

for _, name := range requiredSdfNames {
method := fn.methods[name]
startIdx := sdfRequiredParamStartIndex(method)

switch name {
case createInitialRestrictionName:
if err := validateSdfElementT(fn, createInitialRestrictionName, method, num, 0); err != nil {
if err := validateSdfElementT(fn, createInitialRestrictionName, method, num, startIdx); err != nil {
return err
}
case splitRestrictionName:
if err := validateSdfElementT(fn, splitRestrictionName, method, num, 0); err != nil {
if err := validateSdfElementT(fn, splitRestrictionName, method, num, startIdx); err != nil {
return err
}
if method.Param[num].T != restrictionT {
idx := num + startIdx
if method.Param[idx].T != restrictionT {
err := errors.Errorf("mismatched restriction type in method %v, param %v. got: %v, want: %v",
splitRestrictionName, num, method.Param[num].T, restrictionT)
splitRestrictionName, idx, method.Param[idx].T, restrictionT)
return errors.SetTopLevelMsgf(err, "Mismatched restriction type in method %v, "+
"parameter at index %v. Got: %v, Want: %v (from method %v). "+
"Ensure that all restrictions in an SDF are the same type.",
splitRestrictionName, num, method.Param[num].T, restrictionT, createInitialRestrictionName)
splitRestrictionName, idx, method.Param[idx].T, restrictionT, createInitialRestrictionName)
}
if method.Ret[0].T.Kind() != reflect.Slice ||
method.Ret[0].T.Elem() != restrictionT {
Expand All @@ -967,16 +991,17 @@ func validateSdfSigTypes(fn *Fn, num int) error {
splitRestrictionName, 0, method.Ret[0].T, reflect.SliceOf(restrictionT), createInitialRestrictionName, splitRestrictionName)
}
case restrictionSizeName:
if err := validateSdfElementT(fn, restrictionSizeName, method, num, 0); err != nil {
if err := validateSdfElementT(fn, restrictionSizeName, method, num, startIdx); err != nil {
return err
}
if method.Param[num].T != restrictionT {
idx := num + startIdx
if method.Param[idx].T != restrictionT {
err := errors.Errorf("mismatched restriction type in method %v, param %v. got: %v, want: %v",
restrictionSizeName, num, method.Param[num].T, restrictionT)
restrictionSizeName, idx, method.Param[idx].T, restrictionT)
return errors.SetTopLevelMsgf(err, "Mismatched restriction type in method %v, "+
"parameter at index %v. Got: %v, Want: %v (from method %v). "+
"Ensure that all restrictions in an SDF are the same type.",
restrictionSizeName, num, method.Param[num].T, restrictionT, createInitialRestrictionName)
restrictionSizeName, idx, method.Param[idx].T, restrictionT, createInitialRestrictionName)
}
if method.Ret[0].T != reflectx.Float64 {
err := errors.Errorf("invalid output type in method %v, return %v. got: %v, want: %v",
Expand All @@ -986,13 +1011,13 @@ func validateSdfSigTypes(fn *Fn, num int) error {
restrictionSizeName, 0, method.Ret[0].T, reflectx.Float64)
}
case createTrackerName:
if method.Param[0].T != restrictionT {
if method.Param[startIdx].T != restrictionT {
err := errors.Errorf("mismatched restriction type in method %v, param %v. got: %v, want: %v",
createTrackerName, 0, method.Param[0].T, restrictionT)
createTrackerName, startIdx, method.Param[startIdx].T, restrictionT)
return errors.SetTopLevelMsgf(err, "Mismatched restriction type in method %v, "+
"parameter at index %v. Got: %v, Want: %v (from method %v). "+
"Ensure that all restrictions in an SDF are the same type.",
createTrackerName, 0, method.Param[0].T, restrictionT, createInitialRestrictionName)
createTrackerName, startIdx, method.Param[startIdx].T, restrictionT, createInitialRestrictionName)
}
if !method.Ret[0].T.Implements(rTrackerT) {
err := errors.Errorf("invalid output type in method %v, return %v: %v does not implement sdf.RTracker",
Expand Down Expand Up @@ -1020,15 +1045,18 @@ func validateSdfSigTypes(fn *Fn, num int) error {
if !ok {
continue
}

startIdx := sdfRequiredParamStartIndex(method)

switch name {
case truncateRestrictionName:
if method.Param[0].T != rTrackerImplT {
if method.Param[startIdx].T != rTrackerImplT {
err := errors.Errorf("mismatched restriction tracker type in method %v, param %v. got: %v, want: %v",
truncateRestrictionName, 0, method.Param[0].T, rTrackerImplT)
truncateRestrictionName, startIdx, method.Param[startIdx].T, rTrackerImplT)
return errors.SetTopLevelMsgf(err, "Mismatched restriction tracker type in method %v, "+
"parameter at index %v. Got: %v, Want: %v (from method %v). "+
"Ensure that restriction tracker is the first parameter.",
truncateRestrictionName, 0, method.Param[0].T, rTrackerImplT, createTrackerName)
truncateRestrictionName, startIdx, method.Param[startIdx].T, rTrackerImplT, createTrackerName)
}
if method.Ret[0].T != restrictionT {
err := errors.Errorf("invalid output type in method %v, return %v. got: %v, want: %v",
Expand All @@ -1052,6 +1080,14 @@ func validateSdfSigTypes(fn *Fn, num int) error {
return nil
}

func sdfRequiredParamStartIndex(method *funcx.Fn) int {
if ctxIndex, ok := method.Context(); ok {
return ctxIndex + 1
}

return 0
}

// validateSdfElementT validates that element types in an SDF method are
// consistent with the ProcessElement method. This method assumes that the
// first 'num' parameters starting with startIndex are the elements.
Expand All @@ -1062,13 +1098,14 @@ func validateSdfElementT(fn *Fn, name string, method *funcx.Fn, num int, startIn
pos, _, _ := processFn.Inputs()

for i := 0; i < num; i++ {
if method.Param[i+startIndex].T != processFn.Param[pos+i].T {
idx := i + startIndex
if method.Param[idx].T != processFn.Param[pos+i].T {
err := errors.Errorf("mismatched element type in method %v, param %v. got: %v, want: %v",
name, i, method.Param[i].T, processFn.Param[pos+i].T)
name, idx, method.Param[idx].T, processFn.Param[pos+i].T)
return errors.SetTopLevelMsgf(err, "Mismatched element type in method %v, "+
"parameter at index %v. Got: %v, Want: %v (from method %v). "+
"Ensure that element parameters in SDF methods have consistent types with element parameters in %v.",
name, i, method.Param[i].T, processFn.Param[pos+i].T, processElementName, processElementName)
name, idx, method.Param[idx].T, processFn.Param[pos+i].T, processElementName, processElementName)
}
}
return nil
Expand Down Expand Up @@ -1178,7 +1215,8 @@ func validateStatefulWatermarkSig(fn *Fn, numMainIn int) error {
// CreateInitialRestriction.
if numMainIn == int(MainUnknown) {
initialRestFn := fn.methods[createInitialRestrictionName]
paramNum := len(initialRestFn.Param)
paramNum := len(initialRestFn.Params(funcx.FnValue))

switch paramNum {
case int(MainSingle), int(MainKv):
numMainIn = paramNum
Expand Down
87 changes: 87 additions & 0 deletions sdks/go/pkg/beam/core/graph/fn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,9 @@ func TestNewDoFnSdf(t *testing.T) {
}{
{dfn: &GoodSdf{}, main: MainSingle},
{dfn: &GoodSdfKv{}, main: MainKv},
{dfn: &GoodSdfWContext{}, main: MainSingle},
{dfn: &GoodSdfKvWContext{}, main: MainKv},
{dfn: &GoodSdfWErr{}, main: MainSingle},
{dfn: &GoodIgnoreOtherExportedMethods{}, main: MainSingle},
}

Expand Down Expand Up @@ -987,6 +990,90 @@ func (fn *GoodSdfKv) TruncateRestriction(*RTrackerT, int, int) RestT {
return RestT{}
}

type GoodSdfWContext struct {
*GoodDoFn
}

func (fn *GoodSdfWContext) CreateInitialRestriction(context.Context, int) RestT {
return RestT{}
}

func (fn *GoodSdfWContext) SplitRestriction(context.Context, int, RestT) []RestT {
return []RestT{}
}

func (fn *GoodSdfWContext) RestrictionSize(context.Context, int, RestT) float64 {
return 0
}

func (fn *GoodSdfWContext) CreateTracker(context.Context, RestT) *RTrackerT {
return &RTrackerT{}
}

func (fn *GoodSdfWContext) ProcessElement(context.Context, *RTrackerT, int) (int, sdf.ProcessContinuation) {
return 0, sdf.StopProcessing()
}

func (fn *GoodSdfWContext) TruncateRestriction(context.Context, *RTrackerT, int) RestT {
return RestT{}
}

type GoodSdfKvWContext struct {
*GoodDoFnKv
}

func (fn *GoodSdfKvWContext) CreateInitialRestriction(context.Context, int, int) RestT {
return RestT{}
}

func (fn *GoodSdfKvWContext) SplitRestriction(context.Context, int, int, RestT) []RestT {
return []RestT{}
}

func (fn *GoodSdfKvWContext) RestrictionSize(context.Context, int, int, RestT) float64 {
return 0
}

func (fn *GoodSdfKvWContext) CreateTracker(context.Context, RestT) *RTrackerT {
return &RTrackerT{}
}

func (fn *GoodSdfKvWContext) ProcessElement(context.Context, *RTrackerT, int, int) (int, sdf.ProcessContinuation) {
return 0, sdf.StopProcessing()
}

func (fn *GoodSdfKvWContext) TruncateRestriction(context.Context, *RTrackerT, int, int) RestT {
return RestT{}
}

type GoodSdfWErr struct {
*GoodDoFn
}

func (fn *GoodSdfWErr) CreateInitialRestriction(int) (RestT, error) {
return RestT{}, nil
}

func (fn *GoodSdfWErr) SplitRestriction(int, RestT) ([]RestT, error) {
return []RestT{}, nil
}

func (fn *GoodSdfWErr) RestrictionSize(int, RestT) (float64, error) {
return 0, nil
}

func (fn *GoodSdfWErr) CreateTracker(RestT) (*RTrackerT, error) {
return &RTrackerT{}, nil
}

func (fn *GoodSdfWErr) ProcessElement(*RTrackerT, int) (int, sdf.ProcessContinuation, error) {
return 0, sdf.StopProcessing(), nil
}

func (fn *GoodSdfWErr) TruncateRestriction(*RTrackerT, int) (RestT, error) {
return RestT{}, nil
}

type GoodIgnoreOtherExportedMethods struct {
*GoodSdf
}
Expand Down
8 changes: 4 additions & 4 deletions sdks/go/pkg/beam/core/runtime/exec/datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func (n *DataSource) Process(ctx context.Context) ([]*Checkpoint, error) {
// Check if there's a continuation and return residuals
// Needs to be done immeadiately after processing to not lose the element.
if c := n.getProcessContinuation(); c != nil {
cp, err := n.checkpointThis(c)
cp, err := n.checkpointThis(ctx, c)
if err != nil {
// Errors during checkpointing should fail a bundle.
return nil, err
Expand Down Expand Up @@ -422,7 +422,7 @@ type Checkpoint struct {
// splittable or has not returned a resuming continuation, the function returns an empty
// SplitResult, a negative resumption time, and a false boolean to indicate that no split
// occurred.
func (n *DataSource) checkpointThis(pc sdf.ProcessContinuation) (*Checkpoint, error) {
func (n *DataSource) checkpointThis(ctx context.Context, pc sdf.ProcessContinuation) (*Checkpoint, error) {
n.mu.Lock()
defer n.mu.Unlock()

Expand All @@ -435,7 +435,7 @@ func (n *DataSource) checkpointThis(pc sdf.ProcessContinuation) (*Checkpoint, er
ow := su.GetOutputWatermark()

// Checkpointing is functionally a split at fraction 0.0
rs, err := su.Checkpoint()
rs, err := su.Checkpoint(ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -530,7 +530,7 @@ func (n *DataSource) Split(ctx context.Context, splits []int64, frac float64, bu
// Get the output watermark before splitting to avoid accidentally overestimating
ow := su.GetOutputWatermark()
// Otherwise, perform a sub-element split.
ps, rs, err := su.Split(fr)
ps, rs, err := su.Split(ctx, fr)
if err != nil {
return SplitResult{}, err
}
Expand Down
Loading

0 comments on commit 29ea6e0

Please sign in to comment.