diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go new file mode 100644 index 000000000000..6fc192ac83be --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package engine + +// TentativeData is where data for in progress bundles is put +// until the bundle executes successfully. +type TentativeData struct { + Raw map[string][][]byte +} + +// WriteData adds data to a given global collectionID. +func (d *TentativeData) WriteData(colID string, data []byte) { + if d.Raw == nil { + d.Raw = map[string][][]byte{} + } + d.Raw[colID] = append(d.Raw[colID], data) +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go new file mode 100644 index 000000000000..f6fbf1293f47 --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package worker + +import ( + "sync" + + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" + fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" + "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" + "golang.org/x/exp/slog" +) + +// B represents an extant ProcessBundle instruction sent to an SDK worker. +// Generally manipulated by another package to interact with a worker. +type B struct { + InstID string // ID for the instruction processing this bundle. + PBDID string // ID for the ProcessBundleDescriptor + + // InputTransformID is data being sent to the SDK. + InputTransformID string + InputData [][]byte // Data specifically for this bundle. + + // TODO change to a single map[tid] -> map[input] -> map[window] -> struct { Iter data, MultiMap data } instead of all maps. + // IterableSideInputData is a map from transformID, to inputID, to window, to data. 
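Before continuing with the rest of the B type, a quick illustration of the TentativeData contract from data.go above. This is a sketch, not part of the change; the collection IDs are hypothetical:

```go
package engine

import (
	"bytes"
	"testing"
)

// TestTentativeData_WriteData_sketch shows the expected accumulation behavior:
// the Raw map is lazily initialized on the first write, and chunks for the same
// (hypothetical) collection ID are appended in order.
func TestTentativeData_WriteData_sketch(t *testing.T) {
	var td TentativeData
	td.WriteData("pcol1", []byte{1, 2})
	td.WriteData("pcol1", []byte{3})
	td.WriteData("pcol2", []byte{4})

	if got, want := len(td.Raw["pcol1"]), 2; got != want {
		t.Errorf("pcol1 chunks = %d, want %d", got, want)
	}
	if !bytes.Equal(td.Raw["pcol1"][1], []byte{3}) {
		t.Errorf("second pcol1 chunk = %v, want [3]", td.Raw["pcol1"][1])
	}
}
```

Nothing written here is durable on its own; the data only becomes visible downstream once the bundle succeeds and the runner commits it (see DataService.Commit later in this change).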
+	IterableSideInputData map[string]map[string]map[typex.Window][][]byte
+	// MultiMapSideInputData is a map from transformID, to inputID, to window, to data key, to data values.
+	MultiMapSideInputData map[string]map[string]map[typex.Window]map[string][][]byte
+
+	// OutputCount is the number of data outputs this bundle has.
+	// We need to see this many closed data channels before the bundle is complete.
+	OutputCount int
+	// dataWait is how we determine if a bundle is finished, by waiting for each of
+	// a Bundle's DataSinks to produce their last output.
+	// After this point we can "commit" the bundle's output for downstream use.
+	dataWait sync.WaitGroup
+	OutputData engine.TentativeData
+	Resp chan *fnpb.ProcessBundleResponse
+
+	SinkToPCollection map[string]string
+
+	// TODO: Metrics for this bundle, can be handled after the fact.
+}
+
+// Init initializes the bundle's internal state for waiting on all
+// data and for relaying a response back.
+func (b *B) Init() {
+	// We need to see final data signals that match the number of
+	// outputs the stage this bundle executes possesses.
+	b.dataWait.Add(b.OutputCount)
+	b.Resp = make(chan *fnpb.ProcessBundleResponse, 1)
+}
+
+func (b *B) LogValue() slog.Value {
+	return slog.GroupValue(
+		slog.String("ID", b.InstID),
+		slog.String("stage", b.PBDID))
+}
+
+// ProcessOn executes the given bundle on the given W, blocking
+// until all data is complete.
+//
+// Assumes the bundle is initialized (all maps are non-nil, the data waitgroup is set, and the response channel is initialized),
+// and that the bundle descriptor is already registered with the W.
+//
+// While this method mostly manipulates a W, putting it on a B avoids mixing the worker's
+// public gRPC APIs up with local calls.
+func (b *B) ProcessOn(wk *W) {
+	wk.mu.Lock()
+	wk.bundles[b.InstID] = b
+	wk.mu.Unlock()
+
+	slog.Debug("processing", "bundle", b, "worker", wk)
+
+	// Tell the SDK to start processing the bundle.
+	wk.InstReqs <- &fnpb.InstructionRequest{
+		InstructionId: b.InstID,
+		Request: &fnpb.InstructionRequest_ProcessBundle{
+			ProcessBundle: &fnpb.ProcessBundleRequest{
+				ProcessBundleDescriptorId: b.PBDID,
+			},
+		},
+	}
+
+	// TODO: make batching decisions.
+	for i, d := range b.InputData {
+		wk.DataReqs <- &fnpb.Elements{
+			Data: []*fnpb.Elements_Data{
+				{
+					InstructionId: b.InstID,
+					TransformId: b.InputTransformID,
+					Data: d,
+					IsLast: i+1 == len(b.InputData),
+				},
+			},
+		}
+	}
+
+	slog.Debug("waiting on data", "bundle", b)
+	b.dataWait.Wait() // Wait until data is ready.
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go
new file mode 100644
index 000000000000..154306c3f6ba
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go
@@ -0,0 +1,52 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
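For orientation, a sketch of the intended bundle lifecycle from the runner side, written as it could appear inside this same package. It assumes a live SDK harness is connected and the descriptor is already registered; the stage, transform, sink, and collection IDs are made up for illustration:

```go
// executeBundleSketch wires the pieces together: build a B, Init it, run it on
// a live worker, then commit its tentative output. All IDs here are hypothetical.
func executeBundleSketch(wk *W, d *DataService) {
	b := &B{
		InstID:            wk.NextInst(),
		PBDID:             "stage001", // assumed to already be in wk.Descriptors
		InputTransformID:  "transform1",
		InputData:         [][]byte{{1, 2, 3}},
		OutputCount:       1,
		SinkToPCollection: map[string]string{"sink1": "pcol1"},
	}
	b.Init()

	b.ProcessOn(wk) // blocks until the Data handler sees IsLast for every output

	resp := <-b.Resp       // ProcessBundleResponse relayed by the Control handler
	_ = resp               // e.g. inspect RequiresFinalization
	d.Commit(b.OutputData) // fold the bundle's tentative output into the store
}
```

ProcessOn returns once dataWait has been released OutputCount times, which is what the manual dataWait.Done() in the test below simulates in place of a real data stream.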
+// See the License for the specific language governing permissions and +// limitations under the License. + +package worker + +import ( + "bytes" + "sync" + "testing" +) + +func TestBundle_ProcessOn(t *testing.T) { + wk := New("test") + b := &B{ + InstID: "testInst", + PBDID: "testPBDID", + OutputCount: 1, + InputData: [][]byte{{1, 2, 3}}, + } + b.Init() + var completed sync.WaitGroup + completed.Add(1) + go func() { + b.ProcessOn(wk) + completed.Done() + }() + b.dataWait.Done() + gotData := <-wk.DataReqs + if got, want := gotData.GetData()[0].GetData(), []byte{1, 2, 3}; !bytes.EqualFold(got, want) { + t.Errorf("ProcessOn(): data not sent; got %v, want %v", got, want) + } + + gotInst := <-wk.InstReqs + if got, want := gotInst.GetInstructionId(), b.InstID; got != want { + t.Errorf("ProcessOn(): bad instruction ID; got %v, want %v", got, want) + } + if got, want := gotInst.GetProcessBundle().GetProcessBundleDescriptorId(), b.PBDID; got != want { + t.Errorf("ProcessOn(): bad process bundle descriptor ID; got %v, want %v", got, want) + } +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go new file mode 100644 index 000000000000..8458ce39e116 --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go @@ -0,0 +1,421 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package worker handles interactions with SDK side workers, representing +// the worker services, communicating with those services, and SDK environments. +package worker + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" + fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" + "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" + "golang.org/x/exp/slog" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/encoding/prototext" +) + +// A W manages worker environments, sending them work +// that they're able to execute, and manages the server +// side handlers for FnAPI RPCs. 
+type W struct { + fnpb.UnimplementedBeamFnControlServer + fnpb.UnimplementedBeamFnDataServer + fnpb.UnimplementedBeamFnStateServer + fnpb.UnimplementedBeamFnLoggingServer + + ID string + + // Server management + lis net.Listener + server *grpc.Server + + // These are the ID sources + inst, bund uint64 + + InstReqs chan *fnpb.InstructionRequest + DataReqs chan *fnpb.Elements + + mu sync.Mutex + bundles map[string]*B // Bundles keyed by InstructionID + Descriptors map[string]*fnpb.ProcessBundleDescriptor // Stages keyed by PBDID + + D *DataService +} + +// New starts the worker server components of FnAPI Execution. +func New(id string) *W { + lis, err := net.Listen("tcp", ":0") + if err != nil { + panic(fmt.Sprintf("failed to listen: %v", err)) + } + var opts []grpc.ServerOption + wk := &W{ + ID: id, + lis: lis, + server: grpc.NewServer(opts...), + + InstReqs: make(chan *fnpb.InstructionRequest, 10), + DataReqs: make(chan *fnpb.Elements, 10), + + bundles: make(map[string]*B), + Descriptors: make(map[string]*fnpb.ProcessBundleDescriptor), + + D: &DataService{}, + } + slog.Info("Serving Worker components", slog.String("endpoint", wk.Endpoint())) + fnpb.RegisterBeamFnControlServer(wk.server, wk) + fnpb.RegisterBeamFnDataServer(wk.server, wk) + fnpb.RegisterBeamFnLoggingServer(wk.server, wk) + fnpb.RegisterBeamFnStateServer(wk.server, wk) + return wk +} + +func (wk *W) Endpoint() string { + return wk.lis.Addr().String() +} + +// Serve serves on the started listener. Blocks. +func (wk *W) Serve() { + wk.server.Serve(wk.lis) +} + +func (wk *W) String() string { + return "worker[" + wk.ID + "]" +} + +func (wk *W) LogValue() slog.Value { + return slog.GroupValue( + slog.String("ID", wk.ID), + slog.String("endpoint", wk.Endpoint()), + ) +} + +// Stop the GRPC server. +func (wk *W) Stop() { + slog.Debug("stopping", "worker", wk) + close(wk.InstReqs) + close(wk.DataReqs) + wk.server.Stop() + wk.lis.Close() + slog.Debug("stopped", "worker", wk) +} + +func (wk *W) NextInst() string { + return fmt.Sprintf("inst%03d", atomic.AddUint64(&wk.inst, 1)) +} + +func (wk *W) NextStage() string { + return fmt.Sprintf("stage%03d", atomic.AddUint64(&wk.bund, 1)) +} + +// TODO set logging level. +var minsev = fnpb.LogEntry_Severity_DEBUG + +// Logging relates SDK worker messages back to the job that spawned them. +// Messages are received from the SDK, +func (wk *W) Logging(stream fnpb.BeamFnLogging_LoggingServer) error { + for { + in, err := stream.Recv() + if err == io.EOF { + return nil + } + if err != nil { + slog.Error("logging.Recv", err, "worker", wk) + return err + } + for _, l := range in.GetLogEntries() { + if l.Severity >= minsev { + // TODO: Connect to the associated Job for this worker instead of + // logging locally for SDK side logging. 
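As a sketch of the worker lifecycle described above (not part of the change, and the environment ID is arbitrary), using only the exported and package-local API shown in this file:

```go
// workerLifecycleSketch stands up a worker endpoint, serves the FnAPI services
// registered in New, and mints the monotonically increasing instruction and
// stage IDs.
func workerLifecycleSketch() {
	wk := New("env-1")
	go wk.Serve()   // Serve blocks, so run it on its own goroutine.
	defer wk.Stop() // Closes the request channels and stops the gRPC server.

	endpoint := wk.Endpoint() // handed to the SDK harness as its control endpoint
	instID := wk.NextInst()   // "inst001"
	stageID := wk.NextStage() // "stage001"
	_, _, _ = endpoint, instID, stageID
}
```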
+ slog.Log(toSlogSev(l.GetSeverity()), l.GetMessage(), + slog.String(slog.SourceKey, l.GetLogLocation()), + slog.Time(slog.TimeKey, l.GetTimestamp().AsTime()), + "worker", wk, + ) + } + } + } +} + +func toSlogSev(sev fnpb.LogEntry_Severity_Enum) slog.Level { + switch sev { + case fnpb.LogEntry_Severity_TRACE: + return slog.Level(-8) + case fnpb.LogEntry_Severity_DEBUG: + return slog.LevelDebug // -4 + case fnpb.LogEntry_Severity_INFO: + return slog.LevelInfo // 0 + case fnpb.LogEntry_Severity_NOTICE: + return slog.Level(2) + case fnpb.LogEntry_Severity_WARN: + return slog.LevelWarn // 4 + case fnpb.LogEntry_Severity_ERROR: + return slog.LevelError // 8 + case fnpb.LogEntry_Severity_CRITICAL: + return slog.Level(10) + } + return slog.LevelInfo +} + +func (wk *W) GetProcessBundleDescriptor(ctx context.Context, req *fnpb.GetProcessBundleDescriptorRequest) (*fnpb.ProcessBundleDescriptor, error) { + desc, ok := wk.Descriptors[req.GetProcessBundleDescriptorId()] + if !ok { + return nil, fmt.Errorf("descriptor %v not found", req.GetProcessBundleDescriptorId()) + } + return desc, nil +} + +// Control relays instructions to SDKs and back again, coordinated via unique instructionIDs. +// +// Requests come from the runner, and are sent to the client in the SDK. +func (wk *W) Control(ctrl fnpb.BeamFnControl_ControlServer) error { + done := make(chan bool) + go func() { + for { + resp, err := ctrl.Recv() + if err == io.EOF { + slog.Debug("ctrl.Recv finished; marking done", "worker", wk) + done <- true // means stream is finished + return + } + if err != nil { + switch status.Code(err) { + case codes.Canceled: // Might ignore this all the time instead. + slog.Error("ctrl.Recv Canceled", err, "worker", wk) + done <- true // means stream is finished + return + default: + slog.Error("ctrl.Recv failed", err, "worker", wk) + panic(err) + } + } + + // TODO: Do more than assume these are ProcessBundleResponses. + wk.mu.Lock() + if b, ok := wk.bundles[resp.GetInstructionId()]; ok { + // TODO. Better pipeline error handling. + if resp.Error != "" { + slog.Log(slog.LevelError, "ctrl.Recv pipeline error", slog.ErrorKey, resp.GetError()) + panic(resp.GetError()) + } + b.Resp <- resp.GetProcessBundle() + } else { + slog.Debug("ctrl.Recv: %v", resp) + } + wk.mu.Unlock() + } + }() + + for req := range wk.InstReqs { + ctrl.Send(req) + } + slog.Debug("ctrl.Send finished waiting on done") + <-done + slog.Debug("Control done") + return nil +} + +// Data relays elements and timer bytes to SDKs and back again, coordinated via +// ProcessBundle instructionIDs, and receiving input transforms. +// +// Data is multiplexed on a single stream for all active bundles on a worker. +func (wk *W) Data(data fnpb.BeamFnData_DataServer) error { + go func() { + for { + resp, err := data.Recv() + if err == io.EOF { + return + } + if err != nil { + switch status.Code(err) { + case codes.Canceled: + slog.Error("data.Recv Canceled", err, "worker", wk) + return + default: + slog.Error("data.Recv failed", err, "worker", wk) + panic(err) + } + } + wk.mu.Lock() + for _, d := range resp.GetData() { + b, ok := wk.bundles[d.GetInstructionId()] + if !ok { + slog.Info("data.Recv for unknown bundle", "response", resp) + continue + } + colID := b.SinkToPCollection[d.GetTransformId()] + + // There might not be data, eg. for side inputs, so we need to reconcile this elsewhere for + // downstream side inputs. 
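A test-style sketch of the toSlogSev mapping defined earlier in this file; it is illustrative only and just spot-checks the values the switch handles explicitly:

```go
package worker

import (
	"testing"

	fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
	"golang.org/x/exp/slog"
)

// TestToSlogSev_sketch spot-checks the severity translation: FnAPI severities
// map onto slog levels, and anything unhandled falls back to Info.
func TestToSlogSev_sketch(t *testing.T) {
	cases := map[fnpb.LogEntry_Severity_Enum]slog.Level{
		fnpb.LogEntry_Severity_TRACE:    slog.Level(-8),
		fnpb.LogEntry_Severity_DEBUG:    slog.LevelDebug,
		fnpb.LogEntry_Severity_INFO:     slog.LevelInfo,
		fnpb.LogEntry_Severity_NOTICE:   slog.Level(2),
		fnpb.LogEntry_Severity_WARN:     slog.LevelWarn,
		fnpb.LogEntry_Severity_ERROR:    slog.LevelError,
		fnpb.LogEntry_Severity_CRITICAL: slog.Level(10),
	}
	for sev, want := range cases {
		if got := toSlogSev(sev); got != want {
			t.Errorf("toSlogSev(%v) = %v, want %v", sev, got, want)
		}
	}
}
```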
+ if len(d.GetData()) > 0 { + b.OutputData.WriteData(colID, d.GetData()) + } + if d.GetIsLast() { + b.dataWait.Done() + } + } + wk.mu.Unlock() + } + }() + + for req := range wk.DataReqs { + if err := data.Send(req); err != nil { + slog.Log(slog.LevelDebug, "data.Send error", slog.ErrorKey, err) + } + } + return nil +} + +// State relays elements and timer bytes to SDKs and back again, coordinated via +// ProcessBundle instructionIDs, and receiving input transforms. +// +// State requests come from SDKs, and the runner responds. +func (wk *W) State(state fnpb.BeamFnState_StateServer) error { + responses := make(chan *fnpb.StateResponse) + go func() { + // This go routine creates all responses to state requests from the worker + // so we want to close the State handler when it's all done. + defer close(responses) + for { + req, err := state.Recv() + if err == io.EOF { + return + } + if err != nil { + switch status.Code(err) { + case codes.Canceled: + slog.Error("state.Recv Canceled", err, "worker", wk) + return + default: + slog.Error("state.Recv failed", err, "worker", wk) + panic(err) + } + } + switch req.GetRequest().(type) { + case *fnpb.StateRequest_Get: + // TODO: move data handling to be pcollection based. + b := wk.bundles[req.GetInstructionId()] + key := req.GetStateKey() + slog.Debug("StateRequest_Get", prototext.Format(req), "bundle", b) + + var data [][]byte + switch key.GetType().(type) { + case *fnpb.StateKey_IterableSideInput_: + ikey := key.GetIterableSideInput() + wKey := ikey.GetWindow() + var w typex.Window + if len(wKey) == 0 { + w = window.GlobalWindow{} + } else { + w, err = exec.MakeWindowDecoder(coder.NewIntervalWindow()).DecodeSingle(bytes.NewBuffer(wKey)) + if err != nil { + panic(fmt.Sprintf("error decoding iterable side input window key %v: %v", wKey, err)) + } + } + winMap := b.IterableSideInputData[ikey.GetTransformId()][ikey.GetSideInputId()] + var wins []typex.Window + for w := range winMap { + wins = append(wins, w) + } + slog.Debug(fmt.Sprintf("side input[%v][%v] I Key: %v Windows: %v", req.GetId(), req.GetInstructionId(), w, wins)) + data = winMap[w] + + case *fnpb.StateKey_MultimapSideInput_: + mmkey := key.GetMultimapSideInput() + wKey := mmkey.GetWindow() + var w typex.Window + if len(wKey) == 0 { + w = window.GlobalWindow{} + } else { + w, err = exec.MakeWindowDecoder(coder.NewIntervalWindow()).DecodeSingle(bytes.NewBuffer(wKey)) + if err != nil { + panic(fmt.Sprintf("error decoding iterable side input window key %v: %v", wKey, err)) + } + } + dKey := mmkey.GetKey() + winMap := b.MultiMapSideInputData[mmkey.GetTransformId()][mmkey.GetSideInputId()] + var wins []typex.Window + for w := range winMap { + wins = append(wins, w) + } + slog.Debug(fmt.Sprintf("side input[%v][%v] MM Key: %v Windows: %v", req.GetId(), req.GetInstructionId(), w, wins)) + + data = winMap[w][string(dKey)] + + default: + panic(fmt.Sprintf("unsupported StateKey Access type: %T: %v", key.GetType(), prototext.Format(key))) + } + + // Encode the runner iterable (no length, just consecutive elements), and send it out. + // This is also where we can handle things like State Backed Iterables. 
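The State handler above answers iterable side input Gets out of B.IterableSideInputData. A sketch of the layout it expects, with hypothetical transform and side input IDs (typex and window are the packages already imported by this file); an empty window key in the request maps to the global window:

```go
// primeIterableSideInputSketch shows the layout the State handler reads from:
// transform ID -> side input ID -> window -> encoded values.
func primeIterableSideInputSketch(b *B) {
	b.IterableSideInputData = map[string]map[string]map[typex.Window][][]byte{
		"transform1": {
			"sideInput1": {
				window.GlobalWindow{}: {
					{42}, // encoded elements; concatenated into the StateGetResponse
				},
			},
		},
	}
}
```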
+ var buf bytes.Buffer + for _, value := range data { + buf.Write(value) + } + responses <- &fnpb.StateResponse{ + Id: req.GetId(), + Response: &fnpb.StateResponse_Get{ + Get: &fnpb.StateGetResponse{ + Data: buf.Bytes(), + }, + }, + } + default: + panic(fmt.Sprintf("unsupported StateRequest kind %T: %v", req.GetRequest(), prototext.Format(req))) + } + } + }() + for resp := range responses { + if err := state.Send(resp); err != nil { + slog.Error("state.Send error", err) + } + } + return nil +} + +// DataService is slated to be deleted in favour of stage based state +// management for side inputs. +type DataService struct { + // TODO actually quick process the data to windows here as well. + raw map[string][][]byte +} + +// Commit tentative data to the datastore. +func (d *DataService) Commit(tent engine.TentativeData) { + if d.raw == nil { + d.raw = map[string][][]byte{} + } + for colID, data := range tent.Raw { + d.raw[colID] = append(d.raw[colID], data...) + } +} + +// GetAllData is a hack for Side Inputs until watermarks are sorted out. +func (d *DataService) GetAllData(colID string) [][]byte { + return d.raw[colID] +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go new file mode 100644 index 000000000000..29b3fab92d64 --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go @@ -0,0 +1,281 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
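A sketch of the tentative-to-committed flow that DataService above is meant to back (illustrative only; "pcol1" is a hypothetical collection ID, and engine is the package already imported by worker.go):

```go
// commitSketch shows the flow: output accumulates in engine.TentativeData while
// a bundle runs, and only a successful bundle's data is folded into the
// DataService, so failed bundles leave nothing behind.
func commitSketch(d *DataService) {
	var tent engine.TentativeData
	tent.WriteData("pcol1", []byte{1, 2, 3})

	d.Commit(tent)               // merge into the committed store
	got := d.GetAllData("pcol1") // [][]byte{{1, 2, 3}}
	_ = got
}
```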
+ +package worker + +import ( + "bytes" + "context" + "net" + "sync" + "testing" + "time" + + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" + fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" +) + +func TestWorker_New(t *testing.T) { + w := New("test") + if got, want := w.ID, "test"; got != want { + t.Errorf("New(%q) = %v, want %v", want, got, want) + } +} + +func TestWorker_NextInst(t *testing.T) { + w := New("test") + + instIDs := map[string]struct{}{} + for i := 0; i < 100; i++ { + instIDs[w.NextInst()] = struct{}{} + } + if got, want := len(instIDs), 100; got != want { + t.Errorf("calling w.NextInst() got %v unique ids, want %v", got, want) + } +} + +func TestWorker_NextStage(t *testing.T) { + w := New("test") + + stageIDs := map[string]struct{}{} + for i := 0; i < 100; i++ { + stageIDs[w.NextStage()] = struct{}{} + } + if got, want := len(stageIDs), 100; got != want { + t.Errorf("calling w.NextStage() got %v unique ids, want %v", got, want) + } +} + +func TestWorker_GetProcessBundleDescriptor(t *testing.T) { + w := New("test") + + id := "available" + w.Descriptors[id] = &fnpb.ProcessBundleDescriptor{ + Id: id, + } + + pbd, err := w.GetProcessBundleDescriptor(context.Background(), &fnpb.GetProcessBundleDescriptorRequest{ + ProcessBundleDescriptorId: id, + }) + if err != nil { + t.Errorf("got GetProcessBundleDescriptor(%q) error: %v, want nil", id, err) + } + if got, want := pbd.GetId(), id; got != want { + t.Errorf("got GetProcessBundleDescriptor(%q) = %v, want id %v", id, got, want) + } + + pbd, err = w.GetProcessBundleDescriptor(context.Background(), &fnpb.GetProcessBundleDescriptorRequest{ + ProcessBundleDescriptorId: "unknown", + }) + if err == nil { + t.Errorf("got GetProcessBundleDescriptor(%q) = %v, want error", "unknown", pbd) + } +} + +func serveTestWorker(t *testing.T) (context.Context, *W, *grpc.ClientConn) { + t.Helper() + ctx, cancelFn := context.WithCancel(context.Background()) + t.Cleanup(cancelFn) + + w := New("test") + lis := bufconn.Listen(2048) + w.lis = lis + t.Cleanup(func() { w.Stop() }) + go w.Serve() + + clientConn, err := grpc.DialContext(ctx, "", grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { + return lis.DialContext(ctx) + }), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) + if err != nil { + t.Fatal("couldn't create bufconn grpc connection:", err) + } + return ctx, w, clientConn +} + +func TestWorker_Logging(t *testing.T) { + ctx, _, clientConn := serveTestWorker(t) + + logCli := fnpb.NewBeamFnLoggingClient(clientConn) + logStream, err := logCli.Logging(ctx) + if err != nil { + t.Fatal("couldn't create log client:", err) + } + + logStream.Send(&fnpb.LogEntry_List{ + LogEntries: []*fnpb.LogEntry{{ + Severity: fnpb.LogEntry_Severity_INFO, + Message: "squeamish ossiphrage", + }}, + }) + + // TODO: Connect to the job management service. + // At this point job messages are just logged to wherever the prism runner executes + // But this should pivot to anyone connecting to the Job Management service for the + // job. + // In the meantime, sleep to validate execution via coverage. 
+ time.Sleep(20 * time.Millisecond) +} + +func TestWorker_Control_HappyPath(t *testing.T) { + ctx, wk, clientConn := serveTestWorker(t) + + ctrlCli := fnpb.NewBeamFnControlClient(clientConn) + ctrlStream, err := ctrlCli.Control(ctx) + if err != nil { + t.Fatal("couldn't create control client:", err) + } + + instID := wk.NextInst() + + b := &B{} + b.Init() + wk.bundles[instID] = b + b.ProcessOn(wk) + + ctrlStream.Send(&fnpb.InstructionResponse{ + InstructionId: instID, + Response: &fnpb.InstructionResponse_ProcessBundle{ + ProcessBundle: &fnpb.ProcessBundleResponse{ + RequiresFinalization: true, // Simple thing to check. + }, + }, + }) + + if err := ctrlStream.CloseSend(); err != nil { + t.Errorf("ctrlStream.CloseSend() = %v", err) + } + resp := <-b.Resp + + if !resp.RequiresFinalization { + t.Errorf("got %v, want response that Requires Finalization", resp) + } +} + +func TestWorker_Data_HappyPath(t *testing.T) { + ctx, wk, clientConn := serveTestWorker(t) + + dataCli := fnpb.NewBeamFnDataClient(clientConn) + dataStream, err := dataCli.Data(ctx) + if err != nil { + t.Fatal("couldn't create data client:", err) + } + + instID := wk.NextInst() + + b := &B{ + InstID: instID, + PBDID: wk.NextStage(), + InputData: [][]byte{ + {1, 1, 1, 1, 1, 1}, + }, + OutputCount: 1, + } + b.Init() + wk.bundles[instID] = b + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + b.ProcessOn(wk) + }() + + wk.InstReqs <- &fnpb.InstructionRequest{ + InstructionId: instID, + } + + elements, err := dataStream.Recv() + if err != nil { + t.Fatal("couldn't receive data elements:", err) + } + + if got, want := elements.GetData()[0].GetInstructionId(), b.InstID; got != want { + t.Fatalf("couldn't receive data elements ID: got %v, want %v", got, want) + } + if got, want := elements.GetData()[0].GetData(), []byte{1, 1, 1, 1, 1, 1}; !bytes.Equal(got, want) { + t.Fatalf("client Data received %v, want %v", got, want) + } + if got, want := elements.GetData()[0].GetIsLast(), true; got != want { + t.Fatalf("client Data received wasn't last: got %v, want %v", got, want) + } + + dataStream.Send(elements) + + if err := dataStream.CloseSend(); err != nil { + t.Errorf("ctrlStream.CloseSend() = %v", err) + } + + wg.Wait() + t.Log("ProcessOn successfully exited") +} + +func TestWorker_State_Iterable(t *testing.T) { + ctx, wk, clientConn := serveTestWorker(t) + + stateCli := fnpb.NewBeamFnStateClient(clientConn) + stateStream, err := stateCli.State(ctx) + if err != nil { + t.Fatal("couldn't create state client:", err) + } + + instID := wk.NextInst() + wk.bundles[instID] = &B{ + IterableSideInputData: map[string]map[string]map[typex.Window][][]byte{ + "transformID": { + "i1": { + window.GlobalWindow{}: [][]byte{ + {42}, + }, + }, + }, + }, + } + + stateStream.Send(&fnpb.StateRequest{ + Id: "first", + InstructionId: instID, + Request: &fnpb.StateRequest_Get{ + Get: &fnpb.StateGetRequest{}, + }, + StateKey: &fnpb.StateKey{Type: &fnpb.StateKey_IterableSideInput_{ + IterableSideInput: &fnpb.StateKey_IterableSideInput{ + TransformId: "transformID", + SideInputId: "i1", + Window: []byte{}, // Global Windows + }, + }}, + }) + + resp, err := stateStream.Recv() + if err != nil { + t.Fatal("couldn't receive state response:", err) + } + + if got, want := resp.GetId(), "first"; got != want { + t.Fatalf("didn't receive expected state response: got %v, want %v", got, want) + } + + if got, want := resp.GetGet().GetData(), []byte{42}; !bytes.Equal(got, want) { + t.Fatalf("didn't receive expected state response data: got %v, want %v", got, 
want) + } + + if err := stateStream.CloseSend(); err != nil { + t.Errorf("stateStream.CloseSend() = %v", err) + } +}