diff --git a/sdks/go/cmd/prism/prism.go b/sdks/go/cmd/prism/prism.go index 5e3f42a9e5a5..4b3c15f285a9 100644 --- a/sdks/go/cmd/prism/prism.go +++ b/sdks/go/cmd/prism/prism.go @@ -23,6 +23,7 @@ import ( "fmt" "log" "log/slog" + "net" "os" "strings" "time" @@ -37,6 +38,7 @@ import ( var ( jobPort = flag.Int("job_port", 8073, "specify the job management service port") webPort = flag.Int("web_port", 8074, "specify the web ui port") + workerPoolPort = flag.Int("worker_pool_port", 8075, "specify the worker pool port") jobManagerEndpoint = flag.String("jm_override", "", "set to only stand up a web ui that refers to a seperate JobManagement endpoint") serveHTTP = flag.Bool("serve_http", true, "enable or disable the web ui") idleShutdownTimeout = flag.Duration("idle_shutdown_timeout", -1, "duration that prism will wait for a new job before shutting itself down. Negative durations disable auto shutdown. Defaults to never shutting down.") @@ -100,6 +102,7 @@ func main() { Port: *jobPort, IdleShutdownTimeout: *idleShutdownTimeout, CancelFn: cancel, + WorkerPoolEndpoint: fmt.Sprintf("localhost:%d", *workerPoolPort), }, *jobManagerEndpoint) if err != nil { @@ -110,6 +113,18 @@ func main() { log.Fatalf("error creating web server: %v", err) } } + g := prism.CreateWorkerPoolServer(ctx) + addr := fmt.Sprintf(":%d", *workerPoolPort) + lis, err := net.Listen("tcp", addr) + if err != nil { + log.Fatalf("error creating worker pool server: %v", err) + } + slog.Info("Serving Worker Pool", "endpoint", fmt.Sprintf("localhost:%d", *workerPoolPort)) + go g.Serve(lis) + go func() { + <-ctx.Done() + g.GracefulStop() + }() // Block main thread forever to keep main from exiting. <-ctx.Done() } diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/engine_test.go b/sdks/go/pkg/beam/runners/prism/internal/engine/engine_test.go index b6e4412c3a83..ac8b9b6f3c95 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/engine_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/engine_test.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "math/rand" + "net" "os" "strings" "testing" @@ -31,6 +32,7 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices" + "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/universal" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest" "github.com/apache/beam/sdks/v2/go/test/integration/primitives" @@ -80,6 +82,23 @@ func executeWithT(ctx context.Context, t testing.TB, p *beam.Pipeline) (beam.Pip s1 := rand.NewSource(time.Now().UnixNano()) r1 := rand.New(s1) *jobopts.JobName = fmt.Sprintf("%v-%v", strings.ToLower(t.Name()), r1.Intn(1000)) + lis, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + _, port, _ := net.SplitHostPort(lis.Addr().String()) + addr := "localhost:" + port + g := worker.NewMultiplexW() + t.Cleanup(g.Stop) + go g.Serve(lis) + s := jobservices.NewServer(0, internal.RunPipeline) + s.WorkerPoolEndpoint = addr + *jobopts.Endpoint = s.Endpoint() + go s.Serve() + t.Cleanup(func() { + *jobopts.Endpoint = "" + s.Stop() + }) return execute(ctx, p) } diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go b/sdks/go/pkg/beam/runners/prism/internal/execute.go index 8b56c30eb61b..0d6797ebcd63 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/execute.go +++ 
b/sdks/go/pkg/beam/runners/prism/internal/execute.go @@ -36,6 +36,9 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" "golang.org/x/exp/maps" "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + healthpb "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/protobuf/proto" ) @@ -88,28 +91,41 @@ func RunPipeline(j *jobservices.Job) { // makeWorker creates a worker for that environment. func makeWorker(env string, j *jobservices.Job) (*worker.W, error) { - wk := worker.New(j.String()+"_"+env, env) + wk := worker.Pool.NewWorker(j.String()+"_"+env, env) wk.EnvPb = j.Pipeline.GetComponents().GetEnvironments()[env] wk.PipelineOptions = j.PipelineOptions() wk.JobKey = j.JobKey() wk.ArtifactEndpoint = j.ArtifactEndpoint() - - go wk.Serve() + wk.WorkerPoolEndpoint = j.WorkerPoolEndpoint if err := runEnvironment(j.RootCtx, j, env, wk); err != nil { return nil, fmt.Errorf("failed to start environment %v for job %v: %w", env, j, err) } // Check for connection succeeding after we've created the environment successfully. timeout := 1 * time.Minute - time.AfterFunc(timeout, func() { - if wk.Connected() || wk.Stopped() { - return + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + go func() { + <-ctx.Done() + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + err := fmt.Errorf("prism %v didn't get control connection to %v after %v", wk, j.WorkerPoolEndpoint, timeout) + j.Failed(err) + j.CancelFn(err) } - err := fmt.Errorf("prism %v didn't get control connection to %v after %v", wk, wk.Endpoint(), timeout) + }() + conn, err := grpc.NewClient(j.WorkerPoolEndpoint, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { j.Failed(err) j.CancelFn(err) - }) + } + health := healthpb.NewHealthClient(conn) + _, err = health.Check(ctx, &healthpb.HealthCheckRequest{}) + if err != nil { + j.Failed(err) + j.CancelFn(err) + } + return wk, nil } diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute_test.go b/sdks/go/pkg/beam/runners/prism/internal/execute_test.go index 29fccaeb238e..b8aa7ee81e79 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/execute_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/execute_test.go @@ -19,6 +19,7 @@ import ( "context" "fmt" "math/rand" + "net" "os" "strings" "testing" @@ -30,6 +31,7 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices" + "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/universal" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest" @@ -43,8 +45,18 @@ func TestMain(m *testing.M) { func initRunner(t testing.TB) { t.Helper() + lis, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + _, port, _ := net.SplitHostPort(lis.Addr().String()) + addr := "localhost:" + port + g := worker.NewMultiplexW() + t.Cleanup(g.Stop) + go g.Serve(lis) if *jobopts.Endpoint == "" { s := jobservices.NewServer(0, internal.RunPipeline) + s.WorkerPoolEndpoint = addr *jobopts.Endpoint = s.Endpoint() go s.Serve() t.Cleanup(func() { diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go index 
6158cd6d612c..7a19812b23af 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go @@ -92,7 +92,8 @@ type Job struct { // Logger for this job. Logger *slog.Logger - metrics metricsStore + metrics metricsStore + WorkerPoolEndpoint string } func (j *Job) ArtifactEndpoint() string { diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go index af559a92ab46..7b99159263a8 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go @@ -92,8 +92,9 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (_ * cancelFn(err) terminalOnceWrap() }, - Logger: s.logger, // TODO substitute with a configured logger. - artifactEndpoint: s.Endpoint(), + Logger: s.logger, // TODO substitute with a configured logger. + artifactEndpoint: s.Endpoint(), + WorkerPoolEndpoint: s.WorkerPoolEndpoint, } // Stop the idle timer when a new job appears. if idleTimer := s.idleTimer.Load(); idleTimer != nil { diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go index bdfe2aff2dd4..e104f37c048d 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go @@ -59,7 +59,8 @@ type Server struct { execute func(*Job) // Artifact hack - artifacts map[string][]byte + artifacts map[string][]byte + WorkerPoolEndpoint string } // NewServer acquires the indicated port. diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go index 161fb199ce96..971df7fa2b78 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go @@ -25,7 +25,7 @@ import ( ) func TestBundle_ProcessOn(t *testing.T) { - wk := New("test", "testEnv") + wk := Pool.NewWorker("test", "testEnv") b := &B{ InstID: "testInst", PBDID: "testPBDID", diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/pool.go b/sdks/go/pkg/beam/runners/prism/internal/worker/pool.go new file mode 100644 index 000000000000..1491ebc147d0 --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/pool.go @@ -0,0 +1,166 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package worker + +import ( + "context" + "fmt" + "log/slog" + "math" + "sync" + + fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" + "github.com/apache/beam/sdks/v2/go/pkg/beam/util/grpcx" + "google.golang.org/grpc" + healthpb "google.golang.org/grpc/health/grpc_health_v1" +) + +var ( + mu sync.Mutex + + // Pool stores *W. + Pool = make(MapW) +) + +// MapW manages the creation and querying of *W. +type MapW map[string]*W + +func (m MapW) workerFromMetadataCtx(ctx context.Context) (*W, error) { + id, err := grpcx.ReadWorkerID(ctx) + if err != nil { + return nil, err + } + if id == "" { + return nil, fmt.Errorf("worker id in ctx metadata is an empty string") + } + mu.Lock() + defer mu.Unlock() + if w, ok := m[id]; ok { + return w, nil + } + return nil, fmt.Errorf("worker id: '%s' read from ctx but not registered in worker pool", id) +} + +// NewWorker instantiates and registers new *W instances. +func (m MapW) NewWorker(id, env string) *W { + wk := &W{ + ID: id, + Env: env, + + InstReqs: make(chan *fnpb.InstructionRequest, 10), + DataReqs: make(chan *fnpb.Elements, 10), + StoppedChan: make(chan struct{}), + + activeInstructions: make(map[string]controlResponder), + Descriptors: make(map[string]*fnpb.ProcessBundleDescriptor), + } + mu.Lock() + defer mu.Unlock() + m[wk.ID] = wk + return wk +} + +// NewMultiplexW instantiates a grpc.Server for multiplexing worker FnAPI requests. +func NewMultiplexW(opts ...grpc.ServerOption) *grpc.Server { + opts = append(opts, grpc.MaxSendMsgSize(math.MaxInt32)) + + g := grpc.NewServer(opts...) + wk := &MultiplexW{ + logger: slog.Default(), + } + + fnpb.RegisterBeamFnControlServer(g, wk) + fnpb.RegisterBeamFnDataServer(g, wk) + fnpb.RegisterBeamFnLoggingServer(g, wk) + fnpb.RegisterBeamFnStateServer(g, wk) + fnpb.RegisterProvisionServiceServer(g, wk) + healthpb.RegisterHealthServer(g, wk) + + return g +} + +// MultiplexW multiplexes FnAPI gRPC requests to *W stored in the Pool. 
+type MultiplexW struct { + fnpb.UnimplementedBeamFnControlServer + fnpb.UnimplementedBeamFnDataServer + fnpb.UnimplementedBeamFnStateServer + fnpb.UnimplementedBeamFnLoggingServer + fnpb.UnimplementedProvisionServiceServer + healthpb.UnimplementedHealthServer + + logger *slog.Logger +} + +func (mw *MultiplexW) Check(_ context.Context, _ *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) { + return &healthpb.HealthCheckResponse{Status: healthpb.HealthCheckResponse_SERVING}, nil +} + +func (mw *MultiplexW) GetProvisionInfo(ctx context.Context, req *fnpb.GetProvisionInfoRequest) (*fnpb.GetProvisionInfoResponse, error) { + w, err := Pool.workerFromMetadataCtx(ctx) + if err != nil { + return nil, err + } + return w.GetProvisionInfo(ctx, req) +} + +func (mw *MultiplexW) Logging(stream fnpb.BeamFnLogging_LoggingServer) error { + w, err := Pool.workerFromMetadataCtx(stream.Context()) + if err != nil { + return err + } + return w.Logging(stream) +} + +func (mw *MultiplexW) GetProcessBundleDescriptor(ctx context.Context, req *fnpb.GetProcessBundleDescriptorRequest) (*fnpb.ProcessBundleDescriptor, error) { + w, err := Pool.workerFromMetadataCtx(ctx) + if err != nil { + return nil, err + } + return w.GetProcessBundleDescriptor(ctx, req) +} + +func (mw *MultiplexW) Control(ctrl fnpb.BeamFnControl_ControlServer) error { + w, err := Pool.workerFromMetadataCtx(ctrl.Context()) + if err != nil { + return err + } + return w.Control(ctrl) +} + +func (mw *MultiplexW) Data(data fnpb.BeamFnData_DataServer) error { + w, err := Pool.workerFromMetadataCtx(data.Context()) + if err != nil { + return err + } + return w.Data(data) +} + +func (mw *MultiplexW) State(state fnpb.BeamFnState_StateServer) error { + w, err := Pool.workerFromMetadataCtx(state.Context()) + if err != nil { + return err + } + return w.State(state) +} + +func (mw *MultiplexW) MonitoringMetadata(ctx context.Context, unknownIDs []string) *fnpb.MonitoringInfosMetadataResponse { + w, err := Pool.workerFromMetadataCtx(ctx) + if err != nil { + mw.logger.Error(err.Error()) + return nil + } + return w.MonitoringMetadata(ctx, unknownIDs) +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/pool_test.go b/sdks/go/pkg/beam/runners/prism/internal/worker/pool_test.go new file mode 100644 index 000000000000..0b8531058169 --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/pool_test.go @@ -0,0 +1,511 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package worker + +import ( + "bytes" + "context" + "net" + "sort" + "sync" + "testing" + "time" + + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" + fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" + "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" + "github.com/apache/beam/sdks/v2/go/pkg/beam/util/grpcx" + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" +) + +func serveTestWorker(t *testing.T) (context.Context, *grpc.ClientConn) { + t.Helper() + ctx, cancelFn := context.WithCancel(context.Background()) + t.Cleanup(cancelFn) + w := Pool.NewWorker("test", "testEnv") + ctx = metadata.NewIncomingContext(ctx, metadata.Pairs("worker_id", w.ID)) + ctx = grpcx.WriteWorkerID(ctx, w.ID) + lis, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + _, port, _ := net.SplitHostPort(lis.Addr().String()) + addr := "localhost:" + port + g := NewMultiplexW() + t.Cleanup(func() { + w.Stop() + g.Stop() + }) + go g.Serve(lis) + + clientConn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatal(err) + } + return ctx, clientConn +} + +func serveTestWorkerStateStream(t *testing.T) (context.Context, fnpb.BeamFnState_StateClient, closeSend) { + ctx, clientConn := serveTestWorker(t) + + stateCli := fnpb.NewBeamFnStateClient(clientConn) + stateStream, err := stateCli.State(ctx) + if err != nil { + t.Fatal("couldn't create state client:", err) + } + return ctx, stateStream, func() { + if err := stateStream.CloseSend(); err != nil { + t.Errorf("stateStream.CloseSend() = %v", err) + } + } +} + +func TestWorker_Logging(t *testing.T) { + ctx, clientConn := serveTestWorker(t) + + logCli := fnpb.NewBeamFnLoggingClient(clientConn) + logStream, err := logCli.Logging(ctx) + if err != nil { + t.Fatal("couldn't create log client:", err) + } + + logStream.Send(&fnpb.LogEntry_List{ + LogEntries: []*fnpb.LogEntry{{ + Severity: fnpb.LogEntry_Severity_INFO, + Message: "squeamish ossiphrage", + LogLocation: "intentionally.go:124", + }}, + }) + + logStream.Send(&fnpb.LogEntry_List{ + LogEntries: []*fnpb.LogEntry{{ + Severity: fnpb.LogEntry_Severity_INFO, + Message: "squeamish ossiphrage the second", + LogLocation: "intentionally bad log location", + }}, + }) + + // TODO: Connect to the job management service. + // At this point job messages are just logged to wherever the prism runner executes + // But this should pivot to anyone connecting to the Job Management service for the + // job. + // In the meantime, sleep to validate execution via coverage. 
+ time.Sleep(20 * time.Millisecond) +} + +func TestWorker_Control_HappyPath(t *testing.T) { + ctx, clientConn := serveTestWorker(t) + wk, err := Pool.workerFromMetadataCtx(ctx) + if err != nil { + t.Fatal(err) + } + + ctrlCli := fnpb.NewBeamFnControlClient(clientConn) + ctrlStream, err := ctrlCli.Control(ctx) + if err != nil { + t.Fatal("couldn't create control client:", err) + } + + instID := wk.NextInst() + + b := &B{} + b.Init() + wk.activeInstructions[instID] = b + b.ProcessOn(ctx, wk) + + ctrlStream.Send(&fnpb.InstructionResponse{ + InstructionId: instID, + Response: &fnpb.InstructionResponse_ProcessBundle{ + ProcessBundle: &fnpb.ProcessBundleResponse{ + RequiresFinalization: true, // Simple thing to check. + }, + }, + }) + + if err := ctrlStream.CloseSend(); err != nil { + t.Errorf("ctrlStream.CloseSend() = %v", err) + } + resp := <-b.Resp + + if resp == nil { + t.Fatal("resp is nil from bundle") + } + + if !resp.RequiresFinalization { + t.Errorf("got %v, want response that Requires Finalization", resp) + } +} + +func TestWorker_Data_HappyPath(t *testing.T) { + ctx, clientConn := serveTestWorker(t) + wk, err := Pool.workerFromMetadataCtx(ctx) + if err != nil { + t.Fatal(err) + } + + dataCli := fnpb.NewBeamFnDataClient(clientConn) + dataStream, err := dataCli.Data(ctx) + if err != nil { + t.Fatal("couldn't create data client:", err) + } + + instID := wk.NextInst() + + b := &B{ + InstID: instID, + PBDID: "teststageID", + Input: []*engine.Block{ + { + Kind: engine.BlockData, + Bytes: [][]byte{{1, 1, 1, 1, 1, 1}}, + }}, + OutputCount: 1, + } + b.Init() + wk.activeInstructions[instID] = b + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + b.ProcessOn(ctx, wk) + }() + + wk.InstReqs <- &fnpb.InstructionRequest{ + InstructionId: instID, + } + + elements, err := dataStream.Recv() + if err != nil { + t.Fatal("couldn't receive data elements:", err) + } + + if got, want := elements.GetData()[0].GetInstructionId(), b.InstID; got != want { + t.Fatalf("couldn't receive data elements ID: got %v, want %v", got, want) + } + if got, want := elements.GetData()[0].GetData(), []byte{1, 1, 1, 1, 1, 1}; !bytes.Equal(got, want) { + t.Fatalf("client Data received %v, want %v", got, want) + } + if got, want := elements.GetData()[0].GetIsLast(), false; got != want { + t.Fatalf("client Data received was last: got %v, want %v", got, want) + } + + elements, err = dataStream.Recv() + if err != nil { + t.Fatal("expected 2nd data elements:", err) + } + if got, want := elements.GetData()[0].GetInstructionId(), b.InstID; got != want { + t.Fatalf("couldn't receive data elements ID: got %v, want %v", got, want) + } + if got, want := elements.GetData()[0].GetData(), []byte(nil); !bytes.Equal(got, want) { + t.Fatalf("client Data received %v, want %v", got, want) + } + if got, want := elements.GetData()[0].GetIsLast(), true; got != want { + t.Fatalf("client Data received wasn't last: got %v, want %v", got, want) + } + + dataStream.Send(elements) + + if err := dataStream.CloseSend(); err != nil { + t.Errorf("ctrlStream.CloseSend() = %v", err) + } + + wg.Wait() + t.Log("ProcessOn successfully exited") +} + +func TestWorker_State_Iterable(t *testing.T) { + ctx, clientConn := serveTestWorker(t) + wk, err := Pool.workerFromMetadataCtx(ctx) + if err != nil { + t.Fatal(err) + } + + stateCli := fnpb.NewBeamFnStateClient(clientConn) + stateStream, err := stateCli.State(ctx) + if err != nil { + t.Fatal("couldn't create state client:", err) + } + + instID := wk.NextInst() + wk.activeInstructions[instID] = &B{ + 
IterableSideInputData: map[SideInputKey]map[typex.Window][][]byte{ + {TransformID: "transformID", Local: "i1"}: { + window.GlobalWindow{}: [][]byte{ + {42}, + }, + }, + }, + } + + stateStream.Send(&fnpb.StateRequest{ + Id: "first", + InstructionId: instID, + Request: &fnpb.StateRequest_Get{ + Get: &fnpb.StateGetRequest{}, + }, + StateKey: &fnpb.StateKey{Type: &fnpb.StateKey_IterableSideInput_{ + IterableSideInput: &fnpb.StateKey_IterableSideInput{ + TransformId: "transformID", + SideInputId: "i1", + Window: []byte{}, // Global Windows + }, + }}, + }) + + resp, err := stateStream.Recv() + if err != nil { + t.Fatal("couldn't receive state response:", err) + } + + if got, want := resp.GetId(), "first"; got != want { + t.Fatalf("didn't receive expected state response: got %v, want %v", got, want) + } + + if got, want := resp.GetGet().GetData(), []byte{42}; !bytes.Equal(got, want) { + t.Fatalf("didn't receive expected state response data: got %v, want %v", got, want) + } + + if err := stateStream.CloseSend(); err != nil { + t.Errorf("stateStream.CloseSend() = %v", err) + } +} + +func TestWorker_State_MultimapKeysSideInput(t *testing.T) { + for _, tt := range []struct { + name string + w typex.Window + }{ + { + name: "global window", + w: window.GlobalWindow{}, + }, + { + name: "interval window", + w: window.IntervalWindow{ + Start: 1000, + End: 2000, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + var encW []byte + if !tt.w.Equals(window.GlobalWindow{}) { + buf := bytes.Buffer{} + if err := exec.MakeWindowEncoder(coder.NewIntervalWindow()).EncodeSingle(tt.w, &buf); err != nil { + t.Fatalf("error encoding window: %v, err: %v", tt.w, err) + } + encW = buf.Bytes() + } + ctx, stateStream, done := serveTestWorkerStateStream(t) + defer done() + + wk, err := Pool.workerFromMetadataCtx(ctx) + if err != nil { + t.Fatal(err) + } + + instID := wk.NextInst() + wk.activeInstructions[instID] = &B{ + MultiMapSideInputData: map[SideInputKey]map[typex.Window]map[string][][]byte{ + SideInputKey{ + TransformID: "transformID", + Local: "i1", + }: { + tt.w: map[string][][]byte{"a": {{1}}, "b": {{2}}}, + }, + }, + } + + stateStream.Send(&fnpb.StateRequest{ + Id: "first", + InstructionId: instID, + Request: &fnpb.StateRequest_Get{ + Get: &fnpb.StateGetRequest{}, + }, + StateKey: &fnpb.StateKey{Type: &fnpb.StateKey_MultimapKeysSideInput_{ + MultimapKeysSideInput: &fnpb.StateKey_MultimapKeysSideInput{ + TransformId: "transformID", + SideInputId: "i1", + Window: encW, + }, + }}, + }) + + resp, err := stateStream.Recv() + if err != nil { + t.Fatal("couldn't receive state response:", err) + } + + want := []int{97, 98} + var got []int + for _, b := range resp.GetGet().GetData() { + got = append(got, int(b)) + } + sort.Ints(got) + + if !cmp.Equal(got, want) { + t.Errorf("didn't receive expected state response data: got %v, want %v", got, want) + } + }) + } +} + +func TestWorker_State_MultimapSideInput(t *testing.T) { + for _, tt := range []struct { + name string + w typex.Window + }{ + { + name: "global window", + w: window.GlobalWindow{}, + }, + { + name: "interval window", + w: window.IntervalWindow{ + Start: 1000, + End: 2000, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + var encW []byte + if !tt.w.Equals(window.GlobalWindow{}) { + buf := bytes.Buffer{} + if err := exec.MakeWindowEncoder(coder.NewIntervalWindow()).EncodeSingle(tt.w, &buf); err != nil { + t.Fatalf("error encoding window: %v, err: %v", tt.w, err) + } + encW = buf.Bytes() + } + ctx, stateStream, done := serveTestWorkerStateStream(t) + 
defer done() + + wk, err := Pool.workerFromMetadataCtx(ctx) + if err != nil { + t.Fatal(err) + } + + instID := wk.NextInst() + wk.activeInstructions[instID] = &B{ + MultiMapSideInputData: map[SideInputKey]map[typex.Window]map[string][][]byte{ + SideInputKey{ + TransformID: "transformID", + Local: "i1", + }: { + tt.w: map[string][][]byte{"a": {{5}}, "b": {{12}}}, + }, + }, + } + var testKey = []string{"a", "b", "x"} + expectedResult := map[string][]int{ + "a": {5}, + "b": {12}, + } + for _, key := range testKey { + stateStream.Send(&fnpb.StateRequest{ + Id: "first", + InstructionId: instID, + Request: &fnpb.StateRequest_Get{ + Get: &fnpb.StateGetRequest{}, + }, + StateKey: &fnpb.StateKey{Type: &fnpb.StateKey_MultimapSideInput_{ + MultimapSideInput: &fnpb.StateKey_MultimapSideInput{ + TransformId: "transformID", + SideInputId: "i1", + Window: encW, + Key: []byte(key), + }, + }}, + }) + + resp, err := stateStream.Recv() + if err != nil { + t.Fatal("Couldn't receive state response:", err) + } + + var got []int + for _, b := range resp.GetGet().GetData() { + got = append(got, int(b)) + } + if !cmp.Equal(got, expectedResult[key]) { + t.Errorf("For test key: %v, didn't receive expected state response data: got %v, want %v", key, got, expectedResult[key]) + } + } + }) + } +} + +func TestMapW_workerFromMetadataCtx(t *testing.T) { + tests := []struct { + name string + ctx context.Context + m MapW + want *W + wantErr string + }{ + { + name: "missing metadata", + m: make(MapW), + wantErr: "failed to read metadata from context", + }, + { + name: "ctx metadata worker_id=''", + ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("worker_id", "")), + m: make(MapW), + wantErr: "worker id in ctx metadata is an empty string", + }, + { + name: "mismatched ctx metadata", + ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("worker_id", "wk1")), + m: map[string]*W{ + "wk2": {ID: "wk2"}, + }, + wantErr: "worker id: 'wk1' read from ctx but not registered in worker pool", + }, + { + name: "matching ctx metadata", + ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("worker_id", "wk1")), + m: map[string]*W{ + "wk1": {ID: "wk1"}, + }, + want: &W{ID: "wk1"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.ctx == nil { + tt.ctx = context.Background() + } + got, err := tt.m.workerFromMetadataCtx(tt.ctx) + if err != nil && err.Error() != tt.wantErr { + t.Errorf("workerFromMetadataCtx() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr != "" { + return + } + if got.ID != tt.want.ID { + t.Errorf("workerFromMetadataCtx() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go index 9d9058975b26..61ac3d38054c 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go @@ -23,11 +23,8 @@ import ( "fmt" "io" "log/slog" - "math" - "net" "sync" "sync/atomic" - "time" "github.com/apache/beam/sdks/v2/go/pkg/beam/core" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" @@ -38,7 +35,6 @@ import ( pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" 
"google.golang.org/protobuf/encoding/prototext" @@ -47,7 +43,7 @@ import ( // A W manages worker environments, sending them work // that they're able to execute, and manages the server -// side handlers for FnAPI RPCs. +// side handlers for FnAPI RPCs forwarded from a multiplex FnAPI service. type W struct { fnpb.UnimplementedBeamFnControlServer fnpb.UnimplementedBeamFnDataServer @@ -61,10 +57,6 @@ type W struct { EnvPb *pipepb.Environment PipelineOptions *structpb.Struct - // Server management - lis net.Listener - server *grpc.Server - // These are the ID sources inst uint64 connected, stopped atomic.Bool @@ -76,51 +68,17 @@ type W struct { mu sync.Mutex activeInstructions map[string]controlResponder // Active instructions keyed by InstructionID Descriptors map[string]*fnpb.ProcessBundleDescriptor // Stages keyed by PBDID + + WorkerPoolEndpoint string } type controlResponder interface { Respond(*fnpb.InstructionResponse) } -// New starts the worker server components of FnAPI Execution. -func New(id, env string) *W { - lis, err := net.Listen("tcp", ":0") - if err != nil { - panic(fmt.Sprintf("failed to listen: %v", err)) - } - opts := []grpc.ServerOption{ - grpc.MaxRecvMsgSize(math.MaxInt32), - } - wk := &W{ - ID: id, - Env: env, - lis: lis, - server: grpc.NewServer(opts...), - - InstReqs: make(chan *fnpb.InstructionRequest, 10), - DataReqs: make(chan *fnpb.Elements, 10), - StoppedChan: make(chan struct{}), - - activeInstructions: make(map[string]controlResponder), - Descriptors: make(map[string]*fnpb.ProcessBundleDescriptor), - } - slog.Debug("Serving Worker components", slog.String("endpoint", wk.Endpoint())) - fnpb.RegisterBeamFnControlServer(wk.server, wk) - fnpb.RegisterBeamFnDataServer(wk.server, wk) - fnpb.RegisterBeamFnLoggingServer(wk.server, wk) - fnpb.RegisterBeamFnStateServer(wk.server, wk) - fnpb.RegisterProvisionServiceServer(wk.server, wk) - return wk -} - +// Endpoint forwards the endpoint of the multiplex gRPC FnAPI service. func (wk *W) Endpoint() string { - _, port, _ := net.SplitHostPort(wk.lis.Addr().String()) - return fmt.Sprintf("localhost:%v", port) -} - -// Serve serves on the started listener. Blocks. -func (wk *W) Serve() { - wk.server.Serve(wk.lis) + return wk.WorkerPoolEndpoint } func (wk *W) String() string { @@ -135,10 +93,6 @@ func (wk *W) LogValue() slog.Value { } // shutdown safely closes channels, and can be called in the event of SDK crashes. -// -// Splitting this logic from the GRPC server Stop is necessary, since a worker -// crash would be handled in a streaming RPC context, which will block GRPC -// stop calls. func (wk *W) shutdown() { // If this is the first call to "stop" this worker, also close the channels. if wk.stopped.CompareAndSwap(false, true) { @@ -151,20 +105,10 @@ func (wk *W) shutdown() { } } -// Stop the GRPC server. +// Stop the worker and delete it from the Pool. func (wk *W) Stop() { wk.shutdown() - - // Give the SDK side 5 seconds to gracefully stop, before - // hard stopping all RPCs. 
- tim := time.AfterFunc(5*time.Second, func() { - wk.server.Stop() - }) - wk.server.GracefulStop() - tim.Stop() - - wk.lis.Close() - slog.Debug("stopped", "worker", wk) + delete(Pool, wk.ID) } func (wk *W) NextInst() string { diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go index 469e0e2f3d83..c2201ecd7dd1 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go @@ -16,36 +16,21 @@ package worker import ( - "bytes" "context" - "net" - "sort" - "sync" "testing" - "time" - "github.com/google/go-cmp/cmp" - - "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" - "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window" - "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" - "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" - "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/test/bufconn" ) func TestWorker_New(t *testing.T) { - w := New("test", "testEnv") + w := Pool.NewWorker("test", "testEnv") if got, want := w.ID, "test"; got != want { t.Errorf("New(%q) = %v, want %v", want, got, want) } } func TestWorker_NextInst(t *testing.T) { - w := New("test", "testEnv") + w := Pool.NewWorker("test", "testEnv") instIDs := map[string]struct{}{} for i := 0; i < 100; i++ { @@ -57,7 +42,7 @@ func TestWorker_NextInst(t *testing.T) { } func TestWorker_GetProcessBundleDescriptor(t *testing.T) { - w := New("test", "testEnv") + w := Pool.NewWorker("test", "testEnv") id := "available" w.Descriptors[id] = &fnpb.ProcessBundleDescriptor{ @@ -82,386 +67,4 @@ func TestWorker_GetProcessBundleDescriptor(t *testing.T) { } } -func serveTestWorker(t *testing.T) (context.Context, *W, *grpc.ClientConn) { - t.Helper() - ctx, cancelFn := context.WithCancel(context.Background()) - t.Cleanup(cancelFn) - - w := New("test", "testEnv") - lis := bufconn.Listen(2048) - w.lis = lis - t.Cleanup(func() { w.Stop() }) - go w.Serve() - - clientConn, err := grpc.DialContext(ctx, "", grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { - return lis.DialContext(ctx) - }), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) - if err != nil { - t.Fatal("couldn't create bufconn grpc connection:", err) - } - return ctx, w, clientConn -} - type closeSend func() - -func serveTestWorkerStateStream(t *testing.T) (*W, fnpb.BeamFnState_StateClient, closeSend) { - ctx, wk, clientConn := serveTestWorker(t) - - stateCli := fnpb.NewBeamFnStateClient(clientConn) - stateStream, err := stateCli.State(ctx) - if err != nil { - t.Fatal("couldn't create state client:", err) - } - return wk, stateStream, func() { - if err := stateStream.CloseSend(); err != nil { - t.Errorf("stateStream.CloseSend() = %v", err) - } - } -} - -func TestWorker_Logging(t *testing.T) { - ctx, _, clientConn := serveTestWorker(t) - - logCli := fnpb.NewBeamFnLoggingClient(clientConn) - logStream, err := logCli.Logging(ctx) - if err != nil { - t.Fatal("couldn't create log client:", err) - } - - logStream.Send(&fnpb.LogEntry_List{ - LogEntries: []*fnpb.LogEntry{{ - Severity: fnpb.LogEntry_Severity_INFO, - Message: "squeamish ossiphrage", - LogLocation: "intentionally.go:124", - }}, - }) - - 
logStream.Send(&fnpb.LogEntry_List{ - LogEntries: []*fnpb.LogEntry{{ - Severity: fnpb.LogEntry_Severity_INFO, - Message: "squeamish ossiphrage the second", - LogLocation: "intentionally bad log location", - }}, - }) - - // TODO: Connect to the job management service. - // At this point job messages are just logged to wherever the prism runner executes - // But this should pivot to anyone connecting to the Job Management service for the - // job. - // In the meantime, sleep to validate execution via coverage. - time.Sleep(20 * time.Millisecond) -} - -func TestWorker_Control_HappyPath(t *testing.T) { - ctx, wk, clientConn := serveTestWorker(t) - - ctrlCli := fnpb.NewBeamFnControlClient(clientConn) - ctrlStream, err := ctrlCli.Control(ctx) - if err != nil { - t.Fatal("couldn't create control client:", err) - } - - instID := wk.NextInst() - - b := &B{} - b.Init() - wk.activeInstructions[instID] = b - b.ProcessOn(ctx, wk) - - ctrlStream.Send(&fnpb.InstructionResponse{ - InstructionId: instID, - Response: &fnpb.InstructionResponse_ProcessBundle{ - ProcessBundle: &fnpb.ProcessBundleResponse{ - RequiresFinalization: true, // Simple thing to check. - }, - }, - }) - - if err := ctrlStream.CloseSend(); err != nil { - t.Errorf("ctrlStream.CloseSend() = %v", err) - } - resp := <-b.Resp - - if !resp.RequiresFinalization { - t.Errorf("got %v, want response that Requires Finalization", resp) - } -} - -func TestWorker_Data_HappyPath(t *testing.T) { - ctx, wk, clientConn := serveTestWorker(t) - - dataCli := fnpb.NewBeamFnDataClient(clientConn) - dataStream, err := dataCli.Data(ctx) - if err != nil { - t.Fatal("couldn't create data client:", err) - } - - instID := wk.NextInst() - - b := &B{ - InstID: instID, - PBDID: "teststageID", - Input: []*engine.Block{ - { - Kind: engine.BlockData, - Bytes: [][]byte{{1, 1, 1, 1, 1, 1}}, - }}, - OutputCount: 1, - } - b.Init() - wk.activeInstructions[instID] = b - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - b.ProcessOn(ctx, wk) - }() - - wk.InstReqs <- &fnpb.InstructionRequest{ - InstructionId: instID, - } - - elements, err := dataStream.Recv() - if err != nil { - t.Fatal("couldn't receive data elements:", err) - } - - if got, want := elements.GetData()[0].GetInstructionId(), b.InstID; got != want { - t.Fatalf("couldn't receive data elements ID: got %v, want %v", got, want) - } - if got, want := elements.GetData()[0].GetData(), []byte{1, 1, 1, 1, 1, 1}; !bytes.Equal(got, want) { - t.Fatalf("client Data received %v, want %v", got, want) - } - if got, want := elements.GetData()[0].GetIsLast(), false; got != want { - t.Fatalf("client Data received was last: got %v, want %v", got, want) - } - - elements, err = dataStream.Recv() - if err != nil { - t.Fatal("expected 2nd data elements:", err) - } - if got, want := elements.GetData()[0].GetInstructionId(), b.InstID; got != want { - t.Fatalf("couldn't receive data elements ID: got %v, want %v", got, want) - } - if got, want := elements.GetData()[0].GetData(), []byte(nil); !bytes.Equal(got, want) { - t.Fatalf("client Data received %v, want %v", got, want) - } - if got, want := elements.GetData()[0].GetIsLast(), true; got != want { - t.Fatalf("client Data received wasn't last: got %v, want %v", got, want) - } - - dataStream.Send(elements) - - if err := dataStream.CloseSend(); err != nil { - t.Errorf("ctrlStream.CloseSend() = %v", err) - } - - wg.Wait() - t.Log("ProcessOn successfully exited") -} - -func TestWorker_State_Iterable(t *testing.T) { - ctx, wk, clientConn := serveTestWorker(t) - - stateCli := 
fnpb.NewBeamFnStateClient(clientConn) - stateStream, err := stateCli.State(ctx) - if err != nil { - t.Fatal("couldn't create state client:", err) - } - - instID := wk.NextInst() - wk.activeInstructions[instID] = &B{ - IterableSideInputData: map[SideInputKey]map[typex.Window][][]byte{ - {TransformID: "transformID", Local: "i1"}: { - window.GlobalWindow{}: [][]byte{ - {42}, - }, - }, - }, - } - - stateStream.Send(&fnpb.StateRequest{ - Id: "first", - InstructionId: instID, - Request: &fnpb.StateRequest_Get{ - Get: &fnpb.StateGetRequest{}, - }, - StateKey: &fnpb.StateKey{Type: &fnpb.StateKey_IterableSideInput_{ - IterableSideInput: &fnpb.StateKey_IterableSideInput{ - TransformId: "transformID", - SideInputId: "i1", - Window: []byte{}, // Global Windows - }, - }}, - }) - - resp, err := stateStream.Recv() - if err != nil { - t.Fatal("couldn't receive state response:", err) - } - - if got, want := resp.GetId(), "first"; got != want { - t.Fatalf("didn't receive expected state response: got %v, want %v", got, want) - } - - if got, want := resp.GetGet().GetData(), []byte{42}; !bytes.Equal(got, want) { - t.Fatalf("didn't receive expected state response data: got %v, want %v", got, want) - } - - if err := stateStream.CloseSend(); err != nil { - t.Errorf("stateStream.CloseSend() = %v", err) - } -} - -func TestWorker_State_MultimapKeysSideInput(t *testing.T) { - for _, tt := range []struct { - name string - w typex.Window - }{ - { - name: "global window", - w: window.GlobalWindow{}, - }, - { - name: "interval window", - w: window.IntervalWindow{ - Start: 1000, - End: 2000, - }, - }, - } { - t.Run(tt.name, func(t *testing.T) { - var encW []byte - if !tt.w.Equals(window.GlobalWindow{}) { - buf := bytes.Buffer{} - if err := exec.MakeWindowEncoder(coder.NewIntervalWindow()).EncodeSingle(tt.w, &buf); err != nil { - t.Fatalf("error encoding window: %v, err: %v", tt.w, err) - } - encW = buf.Bytes() - } - wk, stateStream, done := serveTestWorkerStateStream(t) - defer done() - instID := wk.NextInst() - wk.activeInstructions[instID] = &B{ - MultiMapSideInputData: map[SideInputKey]map[typex.Window]map[string][][]byte{ - SideInputKey{ - TransformID: "transformID", - Local: "i1", - }: { - tt.w: map[string][][]byte{"a": {{1}}, "b": {{2}}}, - }, - }, - } - - stateStream.Send(&fnpb.StateRequest{ - Id: "first", - InstructionId: instID, - Request: &fnpb.StateRequest_Get{ - Get: &fnpb.StateGetRequest{}, - }, - StateKey: &fnpb.StateKey{Type: &fnpb.StateKey_MultimapKeysSideInput_{ - MultimapKeysSideInput: &fnpb.StateKey_MultimapKeysSideInput{ - TransformId: "transformID", - SideInputId: "i1", - Window: encW, - }, - }}, - }) - - resp, err := stateStream.Recv() - if err != nil { - t.Fatal("couldn't receive state response:", err) - } - - want := []int{97, 98} - var got []int - for _, b := range resp.GetGet().GetData() { - got = append(got, int(b)) - } - sort.Ints(got) - - if !cmp.Equal(got, want) { - t.Errorf("didn't receive expected state response data: got %v, want %v", got, want) - } - }) - } -} - -func TestWorker_State_MultimapSideInput(t *testing.T) { - for _, tt := range []struct { - name string - w typex.Window - }{ - { - name: "global window", - w: window.GlobalWindow{}, - }, - { - name: "interval window", - w: window.IntervalWindow{ - Start: 1000, - End: 2000, - }, - }, - } { - t.Run(tt.name, func(t *testing.T) { - var encW []byte - if !tt.w.Equals(window.GlobalWindow{}) { - buf := bytes.Buffer{} - if err := exec.MakeWindowEncoder(coder.NewIntervalWindow()).EncodeSingle(tt.w, &buf); err != nil { - t.Fatalf("error 
encoding window: %v, err: %v", tt.w, err) - } - encW = buf.Bytes() - } - wk, stateStream, done := serveTestWorkerStateStream(t) - defer done() - instID := wk.NextInst() - wk.activeInstructions[instID] = &B{ - MultiMapSideInputData: map[SideInputKey]map[typex.Window]map[string][][]byte{ - SideInputKey{ - TransformID: "transformID", - Local: "i1", - }: { - tt.w: map[string][][]byte{"a": {{5}}, "b": {{12}}}, - }, - }, - } - var testKey = []string{"a", "b", "x"} - expectedResult := map[string][]int{ - "a": {5}, - "b": {12}, - } - for _, key := range testKey { - stateStream.Send(&fnpb.StateRequest{ - Id: "first", - InstructionId: instID, - Request: &fnpb.StateRequest_Get{ - Get: &fnpb.StateGetRequest{}, - }, - StateKey: &fnpb.StateKey{Type: &fnpb.StateKey_MultimapSideInput_{ - MultimapSideInput: &fnpb.StateKey_MultimapSideInput{ - TransformId: "transformID", - SideInputId: "i1", - Window: encW, - Key: []byte(key), - }, - }}, - }) - - resp, err := stateStream.Recv() - if err != nil { - t.Fatal("Couldn't receive state response:", err) - } - - var got []int - for _, b := range resp.GetGet().GetData() { - got = append(got, int(b)) - } - if !cmp.Equal(got, expectedResult[key]) { - t.Errorf("For test key: %v, didn't receive expected state response data: got %v, want %v", key, got, expectedResult[key]) - } - } - }) - } -} diff --git a/sdks/go/pkg/beam/runners/prism/prism.go b/sdks/go/pkg/beam/runners/prism/prism.go index e260a7bb7ecd..f694159cf380 100644 --- a/sdks/go/pkg/beam/runners/prism/prism.go +++ b/sdks/go/pkg/beam/runners/prism/prism.go @@ -19,6 +19,8 @@ package prism import ( "context" + "fmt" + "net" "time" "github.com/apache/beam/sdks/v2/go/pkg/beam" @@ -27,6 +29,7 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/web" + "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/universal" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -47,8 +50,17 @@ func Execute(ctx context.Context, p *beam.Pipeline) (beam.PipelineResult, error) // One hasn't been selected, so lets start one up and set the address. // Conveniently, this means that if multiple pipelines are executed against // the local runner, they will all use the same server. + lis, err := net.Listen("tcp", ":0") + if err != nil { + return nil, err + } + _, port, _ := net.SplitHostPort(lis.Addr().String()) + addr := fmt.Sprintf("localhost:%v", port) + g := worker.NewMultiplexW() + go g.Serve(lis) s := jobservices.NewServer(0, internal.RunPipeline) *jobopts.Endpoint = s.Endpoint() + s.WorkerPoolEndpoint = addr go s.Serve() if !jobopts.IsLoopback() { *jobopts.EnvironmentType = "loopback" @@ -62,6 +74,9 @@ type Options struct { // Port the Job Management Server should start on. Port int + // WorkerPoolEndpoint is the endpoint to connect with the worker pool service. + WorkerPoolEndpoint string + // The time prism will wait for new jobs before shuting itself down. IdleShutdownTimeout time.Duration // CancelFn allows Prism to terminate the program due to it's internal state, such as via the idle shutdown timeout. @@ -73,6 +88,7 @@ type Options struct { // This call is non-blocking. 
func CreateJobServer(ctx context.Context, opts Options) (jobpb.JobServiceClient, error) { s := jobservices.NewServer(opts.Port, internal.RunPipeline) + s.WorkerPoolEndpoint = opts.WorkerPoolEndpoint if opts.IdleShutdownTimeout > 0 { s.IdleShutdown(opts.IdleShutdownTimeout, opts.CancelFn) @@ -90,3 +106,8 @@ func CreateJobServer(ctx context.Context, opts Options) (jobpb.JobServiceClient, func CreateWebServer(ctx context.Context, cli jobpb.JobServiceClient, opts Options) error { return web.Initialize(ctx, opts.Port, cli) } + +// CreateWorkerPoolServer initializes the worker pool server that multiplexes worker.W gRPC requests. +func CreateWorkerPoolServer(ctx context.Context) *grpc.Server { + return worker.NewMultiplexW() +} diff --git a/sdks/python/test-suites/portable/common.gradle b/sdks/python/test-suites/portable/common.gradle index be87be749862..dfa243bb6957 100644 --- a/sdks/python/test-suites/portable/common.gradle +++ b/sdks/python/test-suites/portable/common.gradle @@ -206,7 +206,10 @@ def createPrismRunnerTestTask(String workerType) { def taskName = "prismCompatibilityMatrix${workerType}" def prismBin = "${rootDir}/runners/prism/build/tmp/prism" - def options = "--prism_bin=${prismBin} --environment_type=${workerType}" + def options = "--prism_bin=${prismBin}" + if (workerType != 'LOOPBACK') { + options += " --environment_type=${workerType}" + } if (workerType == 'PROCESS') { options += " --environment_options=process_command=${buildDir.absolutePath}/sdk_worker.sh" }
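Note on the worker pool contract introduced above: the new `MultiplexW` server routes every FnAPI stream by the `worker_id` gRPC metadata key (read via `grpcx.ReadWorkerID` in `MapW.workerFromMetadataCtx`), replacing the per-worker gRPC server that `worker.New` used to start. Below is a minimal, illustrative Go sketch of the client side of that contract; the endpoint, worker id, and program structure are placeholder assumptions for illustration, not part of this change.

```go
// Illustrative client: dial prism's shared worker pool endpoint and open a
// control stream, identifying the worker via the worker_id gRPC metadata key.
package main

import (
	"context"
	"log"

	fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
	"github.com/apache/beam/sdks/v2/go/pkg/beam/util/grpcx"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func main() {
	// Endpoint and worker id are placeholders; prism's default worker pool
	// port is 8075, and the id must match one registered via Pool.NewWorker.
	conn, err := grpc.NewClient("localhost:8075",
		grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatalf("dial worker pool: %v", err)
	}
	defer conn.Close()

	// Attach worker_id so MultiplexW can route the stream to the matching *W.
	ctx := grpcx.WriteWorkerID(context.Background(), "myjob-0_go")

	ctrl, err := fnpb.NewBeamFnControlClient(conn).Control(ctx)
	if err != nil {
		log.Fatalf("open control stream: %v", err)
	}
	// A real SDK harness would loop, receiving InstructionRequests and
	// sending InstructionResponses; a single Recv is shown for brevity.
	req, err := ctrl.Recv()
	if err != nil {
		log.Fatalf("recv: %v", err)
	}
	log.Printf("received instruction %v", req.GetInstructionId())
}
```

With a single shared server, all environments for all jobs multiplex over one port instead of each `*W` owning its own listener, which is why jobs now carry a `WorkerPoolEndpoint` and `W.Endpoint()` simply forwards it.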