Skip to content

Commit

Permalink
Switch async workflow request encoding from json to thrift (#5907)
Browse files Browse the repository at this point in the history
  • Loading branch information
taylanisikdemir authored Apr 12, 2024
1 parent b258e62 commit ba39678
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 107 deletions.
28 changes: 16 additions & 12 deletions common/asyncworkflow/queue/consumer/default_consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ package consumer

import (
"context"
"encoding/json"
"fmt"
"sort"
"sync"
Expand All @@ -43,6 +42,7 @@ import (
"github.com/uber/cadence/common/messaging"
"github.com/uber/cadence/common/metrics"
"github.com/uber/cadence/common/types"
"github.com/uber/cadence/common/types/mapper/thrift"
)

const (
Expand Down Expand Up @@ -185,7 +185,7 @@ func (c *DefaultConsumer) processRequest(logger log.Logger, request *sqlblobs.As
scope := c.scope.Tagged(metrics.AsyncWFRequestTypeTag(request.GetType().String()))
switch request.GetType() {
case sqlblobs.AsyncRequestTypeStartWorkflowExecutionAsyncRequest:
startWFReq, err := decodeStartWorkflowRequest(request.GetPayload(), request.GetEncoding())
startWFReq, err := c.decodeStartWorkflowRequest(request.GetPayload(), request.GetEncoding())
if err != nil {
scope.IncCounter(metrics.AsyncWorkflowFailureCorruptMsgCount)
return err
Expand All @@ -210,7 +210,7 @@ func (c *DefaultConsumer) processRequest(logger log.Logger, request *sqlblobs.As
scope.IncCounter(metrics.AsyncWorkflowSuccessCount)
logger.Info("StartWorkflowExecution succeeded", tag.WorkflowID(startWFReq.GetWorkflowID()), tag.WorkflowRunID(resp.GetRunID()))
case sqlblobs.AsyncRequestTypeSignalWithStartWorkflowExecutionAsyncRequest:
startWFReq, err := decodeSignalWithStartWorkflowRequest(request.GetPayload(), request.GetEncoding())
startWFReq, err := c.decodeSignalWithStartWorkflowRequest(request.GetPayload(), request.GetEncoding())
if err != nil {
c.scope.IncCounter(metrics.AsyncWorkflowFailureCorruptMsgCount)
return err
Expand Down Expand Up @@ -270,26 +270,30 @@ func getYARPCOptions(header *shared.Header) []yarpc.CallOption {
return opts
}

func decodeStartWorkflowRequest(payload []byte, encoding string) (*types.StartWorkflowExecutionRequest, error) {
if encoding != string(common.EncodingTypeJSON) {
func (c *DefaultConsumer) decodeStartWorkflowRequest(payload []byte, encoding string) (*types.StartWorkflowExecutionRequest, error) {
if encoding != string(common.EncodingTypeThriftRW) {
return nil, &UnsupportedEncoding{EncodingType: encoding}
}

var startRequest types.StartWorkflowExecutionAsyncRequest
if err := json.Unmarshal(payload, &startRequest); err != nil {
var thriftObj shared.StartWorkflowExecutionAsyncRequest
if err := c.msgDecoder.Decode(payload, &thriftObj); err != nil {
return nil, err
}

startRequest := thrift.ToStartWorkflowExecutionAsyncRequest(&thriftObj)
return startRequest.StartWorkflowExecutionRequest, nil
}

func decodeSignalWithStartWorkflowRequest(payload []byte, encoding string) (*types.SignalWithStartWorkflowExecutionRequest, error) {
if encoding != string(common.EncodingTypeJSON) {
func (c *DefaultConsumer) decodeSignalWithStartWorkflowRequest(payload []byte, encoding string) (*types.SignalWithStartWorkflowExecutionRequest, error) {
if encoding != string(common.EncodingTypeThriftRW) {
return nil, &UnsupportedEncoding{EncodingType: encoding}
}

var startRequest types.SignalWithStartWorkflowExecutionAsyncRequest
if err := json.Unmarshal(payload, &startRequest); err != nil {
var thriftObj shared.SignalWithStartWorkflowExecutionAsyncRequest
if err := c.msgDecoder.Decode(payload, &thriftObj); err != nil {
return nil, err
}
return startRequest.SignalWithStartWorkflowExecutionRequest, nil

signalWithStartRequest := thrift.ToSignalWithStartWorkflowExecutionAsyncRequest(&thriftObj)
return signalWithStartRequest.SignalWithStartWorkflowExecutionRequest, nil
}
98 changes: 53 additions & 45 deletions common/asyncworkflow/queue/consumer/default_consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
package consumer

import (
"encoding/json"
"errors"
"testing"

"github.com/golang/mock/gomock"
"github.com/google/go-cmp/cmp"
"go.uber.org/yarpc"

"github.com/uber/cadence/.gen/go/shared"
"github.com/uber/cadence/.gen/go/sqlblobs"
Expand All @@ -38,6 +39,28 @@ import (
"github.com/uber/cadence/common/messaging"
"github.com/uber/cadence/common/metrics"
"github.com/uber/cadence/common/types"
"github.com/uber/cadence/common/types/mapper/thrift"
)

var (
testSignalWithStartAsyncReq = &types.SignalWithStartWorkflowExecutionAsyncRequest{
SignalWithStartWorkflowExecutionRequest: &types.SignalWithStartWorkflowExecutionRequest{
Domain: "test-domain",
WorkflowID: "test-workflow-id",
WorkflowType: &types.WorkflowType{Name: "test-workflow-type"},
Input: []byte("test-input"),
SignalName: "test-signal-name",
},
}

testStartReq = &types.StartWorkflowExecutionAsyncRequest{
StartWorkflowExecutionRequest: &types.StartWorkflowExecutionRequest{
Domain: "test-domain",
WorkflowID: "test-workflow-id",
WorkflowType: &types.WorkflowType{Name: "test-workflow-type"},
Input: []byte("test-input"),
},
}
)

type fakeMessageConsumer struct {
Expand Down Expand Up @@ -127,60 +150,60 @@ func TestDefaultConsumer(t *testing.T) {
name: "startworkflow request with invalid payload content",
frontendFails: true,
msgs: []*fakeMessage{
{val: mustGenerateStartWorkflowExecutionRequestMsg(t, common.EncodingTypeJSON, false), wantAck: false},
{val: mustGenerateStartWorkflowExecutionRequestMsg(t, common.EncodingTypeThriftRW, false), wantAck: false},
},
},
{
name: "startworkflowfrontend fails to respond",
frontendFails: true,
msgs: []*fakeMessage{
{val: mustGenerateStartWorkflowExecutionRequestMsg(t, common.EncodingTypeJSON, true), wantAck: false},
{val: mustGenerateStartWorkflowExecutionRequestMsg(t, common.EncodingTypeThriftRW, true), wantAck: false},
},
},
{
name: "startworkflow unsupported encoding type",
name: "startworkflow unsupported encoding type. json encoding of requests are lossy due to PII masking so it shouldn't be used for async requests",
msgs: []*fakeMessage{
{val: mustGenerateStartWorkflowExecutionRequestMsg(t, common.EncodingTypeProto, true), wantAck: false},
{val: mustGenerateStartWorkflowExecutionRequestMsg(t, common.EncodingTypeJSON, true), wantAck: false},
},
},
{
name: "startworkflow ok",
msgs: []*fakeMessage{
{val: mustGenerateStartWorkflowExecutionRequestMsg(t, common.EncodingTypeJSON, true), wantAck: true},
{val: mustGenerateStartWorkflowExecutionRequestMsg(t, common.EncodingTypeThriftRW, true), wantAck: true},
},
},
{
name: "startworkflow ok with chan closed before stopping",
closeChanBeforeStop: true,
msgs: []*fakeMessage{
{val: mustGenerateStartWorkflowExecutionRequestMsg(t, common.EncodingTypeJSON, true), wantAck: true},
{val: mustGenerateStartWorkflowExecutionRequestMsg(t, common.EncodingTypeThriftRW, true), wantAck: true},
},
},
// signal with start test cases
{
name: "signalwithstartworkflow request with invalid payload content",
frontendFails: true,
msgs: []*fakeMessage{
{val: mustGenerateSignalWithStartWorkflowExecutionRequestMsg(t, common.EncodingTypeJSON, false), wantAck: false},
{val: mustGenerateSignalWithStartWorkflowExecutionRequestMsg(t, common.EncodingTypeThriftRW, false), wantAck: false},
},
},
{
name: "signalwithstartworkflow frontend fails to respond",
frontendFails: true,
msgs: []*fakeMessage{
{val: mustGenerateSignalWithStartWorkflowExecutionRequestMsg(t, common.EncodingTypeJSON, true), wantAck: false},
{val: mustGenerateSignalWithStartWorkflowExecutionRequestMsg(t, common.EncodingTypeThriftRW, true), wantAck: false},
},
},
{
name: "signalwithstartworkflow unsupported encoding type",
name: "signalwithstartworkflow unsupported encoding type. json encoding of requests are lossy due to PII masking so it shouldn't be used for async requests",
msgs: []*fakeMessage{
{val: mustGenerateSignalWithStartWorkflowExecutionRequestMsg(t, common.EncodingTypeProto, true), wantAck: false},
{val: mustGenerateSignalWithStartWorkflowExecutionRequestMsg(t, common.EncodingTypeJSON, true), wantAck: false},
},
},
{
name: "signalwithstartworkflow ok",
msgs: []*fakeMessage{
{val: mustGenerateSignalWithStartWorkflowExecutionRequestMsg(t, common.EncodingTypeJSON, true), wantAck: true},
{val: mustGenerateSignalWithStartWorkflowExecutionRequestMsg(t, common.EncodingTypeThriftRW, true), wantAck: true},
},
},
}
Expand All @@ -206,10 +229,20 @@ func TestDefaultConsumer(t *testing.T) {
resp := &types.StartWorkflowExecutionResponse{RunID: "test-run-id"}
mockFrontend.EXPECT().
StartWorkflowExecution(gomock.Any(), gomock.Any(), opts[0], opts[1]).
Return(resp, nil).AnyTimes()
DoAndReturn(func(ctx interface{}, req *types.StartWorkflowExecutionRequest, opts ...yarpc.CallOption) (*types.StartWorkflowExecutionResponse, error) {
if diff := cmp.Diff(testStartReq.StartWorkflowExecutionRequest, req); diff != "" {
t.Fatalf("Request mismatch (-want +got):\n%s", diff)
}
return resp, nil
}).AnyTimes()
mockFrontend.EXPECT().
SignalWithStartWorkflowExecution(gomock.Any(), gomock.Any(), opts[0], opts[1]).
Return(resp, nil).AnyTimes()
DoAndReturn(func(ctx interface{}, req *types.SignalWithStartWorkflowExecutionRequest, opts ...yarpc.CallOption) (*types.StartWorkflowExecutionResponse, error) {
if diff := cmp.Diff(testSignalWithStartAsyncReq.SignalWithStartWorkflowExecutionRequest, req); diff != "" {
t.Fatalf("Request mismatch (-want +got):\n%s", diff)
}
return resp, nil
}).AnyTimes()
}

c := New("queueid1", fakeConsumer, testlogger.New(t), metrics.NewNoopMetricsClient(), mockFrontend, WithConcurrency(2))
Expand Down Expand Up @@ -247,16 +280,8 @@ func TestDefaultConsumer(t *testing.T) {
}

func mustGenerateStartWorkflowExecutionRequestMsg(t *testing.T, encodingType common.EncodingType, validPayload bool) []byte {
startRequest := &types.StartWorkflowExecutionAsyncRequest{
StartWorkflowExecutionRequest: &types.StartWorkflowExecutionRequest{
Domain: "test-domain",
WorkflowID: "test-workflow-id",
WorkflowType: &types.WorkflowType{Name: "test-workflow-type"},
Input: []byte("test-input"),
},
}

payload, err := json.Marshal(startRequest)
encoder := codec.NewThriftRWEncoder()
payload, err := encoder.Encode(thrift.FromStartWorkflowExecutionAsyncRequest(testStartReq))
if err != nil {
t.Fatal(err)
}
Expand All @@ -281,17 +306,8 @@ func mustGenerateStartWorkflowExecutionRequestMsg(t *testing.T, encodingType com
}

func mustGenerateSignalWithStartWorkflowExecutionRequestMsg(t *testing.T, encodingType common.EncodingType, validPayload bool) []byte {
signalWithStartRequest := &types.SignalWithStartWorkflowExecutionAsyncRequest{
SignalWithStartWorkflowExecutionRequest: &types.SignalWithStartWorkflowExecutionRequest{
Domain: "test-domain",
WorkflowID: "test-workflow-id",
WorkflowType: &types.WorkflowType{Name: "test-workflow-type"},
Input: []byte("test-input"),
SignalName: "test-signal-name",
},
}

payload, err := json.Marshal(signalWithStartRequest)
encoder := codec.NewThriftRWEncoder()
payload, err := encoder.Encode(thrift.FromSignalWithStartWorkflowExecutionAsyncRequest(testSignalWithStartAsyncReq))
if err != nil {
t.Fatal(err)
}
Expand All @@ -316,16 +332,8 @@ func mustGenerateSignalWithStartWorkflowExecutionRequestMsg(t *testing.T, encodi
}

func mustGenerateUnsupportedRequestMsg(t *testing.T) []byte {
startRequest := &types.StartWorkflowExecutionAsyncRequest{
StartWorkflowExecutionRequest: &types.StartWorkflowExecutionRequest{
Domain: "test-domain",
WorkflowID: "test-workflow-id",
WorkflowType: &types.WorkflowType{Name: "test-workflow-type"},
Input: []byte("test-input"),
},
}

payload, err := json.Marshal(startRequest)
encoder := codec.NewThriftRWEncoder()
payload, err := encoder.Encode(thrift.FromStartWorkflowExecutionAsyncRequest(testStartReq))
if err != nil {
t.Fatal(err)
}
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@
package codec

import (
"fmt"
"log"
"os"
"sync"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/suite"
"go.uber.org/multierr"

workflow "github.com/uber/cadence/.gen/go/shared"
)
Expand Down Expand Up @@ -84,6 +88,30 @@ func (s *thriftRWEncoderSuite) TestEncode() {
s.Equal(thriftEncodedBinary, binary)
}

func (s *thriftRWEncoderSuite) TestEncodeConcurrent() {
var wg sync.WaitGroup
count := 200
errs := make([]error, count)
wg.Add(count)
for i := 0; i < count; i++ {
go func(idx int) {
defer wg.Done()
binary, err := s.encoder.Encode(thriftObject)
if err != nil {
errs[idx] = err
return
}

if diff := cmp.Diff(thriftEncodedBinary, binary); diff != "" {
errs[idx] = fmt.Errorf("Mismatch (-want +got):\n%s", diff)
return
}
}(i)
}
wg.Wait()
s.NoError(multierr.Combine(errs...))
}

func (s *thriftRWEncoderSuite) TestDecode() {
var val workflow.HistoryEvent
err := s.encoder.Decode(thriftEncodedBinary, &val)
Expand Down
30 changes: 0 additions & 30 deletions common/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
package common

import (
"context"

"go.uber.org/yarpc"
)

Expand Down Expand Up @@ -69,31 +67,3 @@ type (
GetMaxMessageSize() int
}
)

// GetClientHeaders returns the headers that should be sent from client to server
func GetClientHeaders(ctx context.Context) map[string]string {
call := yarpc.CallFromContext(ctx)
headerNames := call.HeaderNames()
headerExists := map[string]struct{}{}
for _, h := range headerNames {
headerExists[h] = struct{}{}
}

headers := make(map[string]string)
if _, ok := headerExists[LibraryVersionHeaderName]; !ok {
headers[LibraryVersionHeaderName] = call.Header(LibraryVersionHeaderName)
}
if _, ok := headerExists[FeatureVersionHeaderName]; !ok {
headers[FeatureVersionHeaderName] = call.Header(FeatureVersionHeaderName)
}
if _, ok := headerExists[ClientImplHeaderName]; !ok {
headers[ClientImplHeaderName] = call.Header(ClientImplHeaderName)
}
if _, ok := headerExists[ClientFeatureFlagsHeaderName]; !ok {
headers[ClientFeatureFlagsHeaderName] = call.Header(ClientFeatureFlagsHeaderName)
}
if _, ok := headerExists[ClientIsolationGroupHeaderName]; !ok {
headers[ClientIsolationGroupHeaderName] = call.Header(ClientIsolationGroupHeaderName)
}
return headers
}
Loading

0 comments on commit ba39678

Please sign in to comment.