From 074c7daa2872b29a8997b19812ce36e6f4ac8829 Mon Sep 17 00:00:00 2001 From: Zach Reyes Date: Wed, 28 Sep 2022 20:09:53 -0400 Subject: [PATCH 1/2] Add binary logger option for client and server --- default_dial_option_server_option_test.go | 8 +- dialoptions.go | 14 ++- gcp/observability/opencensus.go | 8 +- internal/internal.go | 16 +-- server.go | 127 +++++++++++++++------- stream.go | 123 ++++++++++++++------- test/end2end_test.go | 97 +++++++++++++++++ 7 files changed, 301 insertions(+), 92 deletions(-) diff --git a/default_dial_option_server_option_test.go b/default_dial_option_server_option_test.go index 3dc446f58b5a..eecd6b846f28 100644 --- a/default_dial_option_server_option_test.go +++ b/default_dial_option_server_option_test.go @@ -38,7 +38,7 @@ func (s) TestAddExtraDialOptions(t *testing.T) { // Set and check the DialOptions opts := []DialOption{WithTransportCredentials(insecure.NewCredentials()), WithTransportCredentials(insecure.NewCredentials()), WithTransportCredentials(insecure.NewCredentials())} - internal.AddExtraDialOptions.(func(opt ...DialOption))(opts...) + internal.AddGlobalDialOptions.(func(opt ...DialOption))(opts...) for i, opt := range opts { if extraDialOptions[i] != opt { t.Fatalf("Unexpected extra dial option at index %d: %v != %v", i, extraDialOptions[i], opt) @@ -52,7 +52,7 @@ func (s) TestAddExtraDialOptions(t *testing.T) { cc.Close() } - internal.ClearExtraDialOptions() + internal.ClearGlobalDialOptions() if len(extraDialOptions) != 0 { t.Fatalf("Unexpected len of extraDialOptions: %d != 0", len(extraDialOptions)) } @@ -62,7 +62,7 @@ func (s) TestAddExtraServerOptions(t *testing.T) { const maxRecvSize = 998765 // Set and check the ServerOptions opts := []ServerOption{Creds(insecure.NewCredentials()), MaxRecvMsgSize(maxRecvSize)} - internal.AddExtraServerOptions.(func(opt ...ServerOption))(opts...) + internal.AddGlobalServerOptions.(func(opt ...ServerOption))(opts...) for i, opt := range opts { if extraServerOptions[i] != opt { t.Fatalf("Unexpected extra server option at index %d: %v != %v", i, extraServerOptions[i], opt) @@ -75,7 +75,7 @@ func (s) TestAddExtraServerOptions(t *testing.T) { t.Fatalf("Unexpected s.opts.maxReceiveMessageSize: %d != %d", s.opts.maxReceiveMessageSize, maxRecvSize) } - internal.ClearExtraServerOptions() + internal.ClearGlobalServerOptions() if len(extraServerOptions) != 0 { t.Fatalf("Unexpected len of extraServerOptions: %d != 0", len(extraServerOptions)) } diff --git a/dialoptions.go b/dialoptions.go index 60403bc160ec..0dab38d4d9b8 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -29,6 +29,7 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal" internalbackoff "google.golang.org/grpc/internal/backoff" + "google.golang.org/grpc/internal/binarylog" "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/resolver" @@ -36,10 +37,10 @@ import ( ) func init() { - internal.AddExtraDialOptions = func(opt ...DialOption) { + internal.AddGlobalDialOptions = func(opt ...DialOption) { extraDialOptions = append(extraDialOptions, opt...) } - internal.ClearExtraDialOptions = func() { + internal.ClearGlobalDialOptions = func() { extraDialOptions = nil } } @@ -61,6 +62,7 @@ type dialOptions struct { timeout time.Duration scChan <-chan ServiceConfig authority string + binaryLogger binarylog.Logger copts transport.ConnectOptions callOptions []CallOption channelzParentID *channelz.Identifier @@ -401,6 +403,14 @@ func WithStatsHandler(h stats.Handler) DialOption { }) } +// WithBinaryLogger returns a DialOption that specifies the binary logger for +// this ClientConn. +func WithBinaryLogger(bl binarylog.Logger) DialOption { + return newFuncDialOption(func(o *dialOptions) { + o.binaryLogger = bl + }) +} + // FailOnNonTempDialError returns a DialOption that specifies if gRPC fails on // non-temporary dial errors. If f is true, and dialer returns a non-temporary // error, gRPC will fail the connection to the network address and won't try to diff --git a/gcp/observability/opencensus.go b/gcp/observability/opencensus.go index ccaa6a98a42c..1bca322cc52c 100644 --- a/gcp/observability/opencensus.go +++ b/gcp/observability/opencensus.go @@ -117,8 +117,8 @@ func startOpenCensus(config *config) error { } // Only register default StatsHandlers if other things are setup correctly. - internal.AddExtraServerOptions.(func(opt ...grpc.ServerOption))(grpc.StatsHandler(&ocgrpc.ServerHandler{StartOptions: so})) - internal.AddExtraDialOptions.(func(opt ...grpc.DialOption))(grpc.WithStatsHandler(&ocgrpc.ClientHandler{StartOptions: so})) + internal.AddGlobalServerOptions.(func(opt ...grpc.ServerOption))(grpc.StatsHandler(&ocgrpc.ServerHandler{StartOptions: so})) + internal.AddGlobalDialOptions.(func(opt ...grpc.DialOption))(grpc.WithStatsHandler(&ocgrpc.ClientHandler{StartOptions: so})) logger.Infof("Enabled OpenCensus StatsHandlers for clients and servers") return nil @@ -128,8 +128,8 @@ func startOpenCensus(config *config) error { // packages if exporter was created. func stopOpenCensus() { if exporter != nil { - internal.ClearExtraDialOptions() - internal.ClearExtraServerOptions() + internal.ClearGlobalDialOptions() + internal.ClearGlobalServerOptions() // Call these unconditionally, doesn't matter if not registered, will be // a noop if not registered. diff --git a/internal/internal.go b/internal/internal.go index 9ce1f18ae9d6..2a24b00f7fc9 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -63,20 +63,20 @@ var ( // xDS-enabled server invokes this method on a grpc.Server when a particular // listener moves to "not-serving" mode. DrainServerTransports interface{} // func(*grpc.Server, string) - // AddExtraServerOptions adds an array of ServerOption that will be + // AddGlobalServerOptions adds an array of ServerOption that will be // effective globally for newly created servers. The priority will be: 1. // user-provided; 2. this method; 3. default values. - AddExtraServerOptions interface{} // func(opt ...ServerOption) - // ClearExtraServerOptions clears the array of extra ServerOption. This + AddGlobalServerOptions interface{} // func(opt ...ServerOption) + // ClearGlobalServerOptions clears the array of extra ServerOption. This // method is useful in testing and benchmarking. - ClearExtraServerOptions func() - // AddExtraDialOptions adds an array of DialOption that will be effective + ClearGlobalServerOptions func() + // AddGlobalDialOptions adds an array of DialOption that will be effective // globally for newly created client channels. The priority will be: 1. // user-provided; 2. this method; 3. default values. - AddExtraDialOptions interface{} // func(opt ...DialOption) - // ClearExtraDialOptions clears the array of extra DialOption. This + AddGlobalDialOptions interface{} // func(opt ...DialOption) + // ClearGlobalDialOptions clears the array of extra DialOption. This // method is useful in testing and benchmarking. - ClearExtraDialOptions func() + ClearGlobalDialOptions func() // JoinServerOptions combines the server options passed as arguments into a // single server option. JoinServerOptions interface{} // func(...grpc.ServerOption) grpc.ServerOption diff --git a/server.go b/server.go index 6ef3df67d9e5..4b9cf89c0439 100644 --- a/server.go +++ b/server.go @@ -73,10 +73,10 @@ func init() { internal.DrainServerTransports = func(srv *Server, addr string) { srv.drainServerTransports(addr) } - internal.AddExtraServerOptions = func(opt ...ServerOption) { + internal.AddGlobalServerOptions = func(opt ...ServerOption) { extraServerOptions = opt } - internal.ClearExtraServerOptions = func() { + internal.ClearGlobalServerOptions = func() { extraServerOptions = nil } internal.JoinServerOptions = newJoinServerOption @@ -156,6 +156,7 @@ type serverOptions struct { streamInt StreamServerInterceptor chainUnaryInts []UnaryServerInterceptor chainStreamInts []StreamServerInterceptor + binaryLogger binarylog.Logger inTapHandle tap.ServerInHandle statsHandlers []stats.Handler maxConcurrentStreams uint32 @@ -469,6 +470,14 @@ func StatsHandler(h stats.Handler) ServerOption { }) } +// BinaryLogger returns a ServerOption that can set the binary logger for the +// server. +func BinaryLogger(bl binarylog.Logger) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.binaryLogger = bl + }) +} + // UnknownServiceHandler returns a ServerOption that allows for adding a custom // unknown service handler. The provided method is a bidi-streaming RPC service // handler that will be invoked instead of returning the "unimplemented" gRPC @@ -1216,9 +1225,16 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } }() } - - binlog := binarylog.GetMethodLogger(stream.Method()) - if binlog != nil { + var binlogs []binarylog.MethodLogger + if ml := binarylog.GetMethodLogger(stream.Method()); ml != nil { + binlogs = append(binlogs, ml) + } + if s.opts.binaryLogger != nil { + if ml := s.opts.binaryLogger.GetMethodLogger(stream.Method()); ml != nil { + binlogs = append(binlogs, ml) + } + } + if len(binlogs) != 0 { ctx := stream.Context() md, _ := metadata.FromIncomingContext(ctx) logEntry := &binarylog.ClientHeader{ @@ -1238,7 +1254,9 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if peer, ok := peer.FromContext(ctx); ok { logEntry.PeerAddr = peer.Addr } - binlog.Log(logEntry) + for _, binlog := range binlogs { + binlog.Log(logEntry) + } } // comp and cp are used for compression. decomp and dc are used for @@ -1278,7 +1296,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } var payInfo *payloadInfo - if len(shs) != 0 || binlog != nil { + if len(shs) != 0 || len(binlogs) != 0 { payInfo = &payloadInfo{} } d, err := recvAndDecompress(&parser{r: stream}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp) @@ -1304,10 +1322,13 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. Length: len(d), }) } - if binlog != nil { - binlog.Log(&binarylog.ClientMessage{ + if len(binlogs) != 0 { + cm := &binarylog.ClientMessage{ Message: d, - }) + } + for _, binlog := range binlogs { + binlog.Log(cm) + } } if trInfo != nil { trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true) @@ -1331,18 +1352,24 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if e := t.WriteStatus(stream, appStatus); e != nil { channelz.Warningf(logger, s.channelzID, "grpc: Server.processUnaryRPC failed to write status: %v", e) } - if binlog != nil { + if len(binlogs) != 0 { if h, _ := stream.Header(); h.Len() > 0 { // Only log serverHeader if there was header. Otherwise it can // be trailer only. - binlog.Log(&binarylog.ServerHeader{ + sh := &binarylog.ServerHeader{ Header: h, - }) + } + for _, binlog := range binlogs { + binlog.Log(sh) + } } - binlog.Log(&binarylog.ServerTrailer{ + st := &binarylog.ServerTrailer{ Trailer: stream.Trailer(), Err: appErr, - }) + } + for _, binlog := range binlogs { + binlog.Log(st) + } } return appErr } @@ -1368,26 +1395,34 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. panic(fmt.Sprintf("grpc: Unexpected error (%T) from sendResponse: %v", st, st)) } } - if binlog != nil { + if len(binlogs) != 0 { h, _ := stream.Header() - binlog.Log(&binarylog.ServerHeader{ + sh := &binarylog.ServerHeader{ Header: h, - }) - binlog.Log(&binarylog.ServerTrailer{ + } + st := &binarylog.ServerTrailer{ Trailer: stream.Trailer(), Err: appErr, - }) + } + for _, binlog := range binlogs { + binlog.Log(sh) + binlog.Log(st) + } } return err } - if binlog != nil { + if len(binlogs) != 0 { h, _ := stream.Header() - binlog.Log(&binarylog.ServerHeader{ + sh := &binarylog.ServerHeader{ Header: h, - }) - binlog.Log(&binarylog.ServerMessage{ + } + sm := &binarylog.ServerMessage{ Message: reply, - }) + } + for _, binlog := range binlogs { + binlog.Log(sh) + binlog.Log(sm) + } } if channelz.IsOn() { t.IncrMsgSent() @@ -1399,11 +1434,14 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. // Should the logging be in WriteStatus? Should we ignore the WriteStatus // error or allow the stats handler to see it? err = t.WriteStatus(stream, statusOK) - if binlog != nil { - binlog.Log(&binarylog.ServerTrailer{ + if len(binlogs) != 0 { + st := &binarylog.ServerTrailer{ Trailer: stream.Trailer(), Err: appErr, - }) + } + for _, binlog := range binlogs { + binlog.Log(st) + } } return err } @@ -1516,8 +1554,15 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp }() } - ss.binlog = binarylog.GetMethodLogger(stream.Method()) - if ss.binlog != nil { + if ml := binarylog.GetMethodLogger(stream.Method()); ml != nil { + ss.binlogs = append(ss.binlogs, ml) + } + if s.opts.binaryLogger != nil { + if ml := s.opts.binaryLogger.GetMethodLogger(stream.Method()); ml != nil { + ss.binlogs = append(ss.binlogs, ml) + } + } + if len(ss.binlogs) != 0 { md, _ := metadata.FromIncomingContext(ctx) logEntry := &binarylog.ClientHeader{ Header: md, @@ -1536,7 +1581,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp if peer, ok := peer.FromContext(ss.Context()); ok { logEntry.PeerAddr = peer.Addr } - ss.binlog.Log(logEntry) + for _, binlog := range ss.binlogs { + binlog.Log(logEntry) + } } // If dc is set and matches the stream's compression, use it. Otherwise, try @@ -1602,11 +1649,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp ss.mu.Unlock() } t.WriteStatus(ss.s, appStatus) - if ss.binlog != nil { - ss.binlog.Log(&binarylog.ServerTrailer{ + if len(ss.binlogs) != 0 { + st := &binarylog.ServerTrailer{ Trailer: ss.s.Trailer(), Err: appErr, - }) + } + for _, binlog := range ss.binlogs { + binlog.Log(st) + } } // TODO: Should we log an error from WriteStatus here and below? return appErr @@ -1617,11 +1667,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp ss.mu.Unlock() } err = t.WriteStatus(ss.s, statusOK) - if ss.binlog != nil { - ss.binlog.Log(&binarylog.ServerTrailer{ + if len(ss.binlogs) != 0 { + st := &binarylog.ServerTrailer{ Trailer: ss.s.Trailer(), Err: appErr, - }) + } + for _, binlog := range ss.binlogs { + binlog.Log(st) + } } return err } diff --git a/stream.go b/stream.go index 446a91e323ee..f18f264d09d5 100644 --- a/stream.go +++ b/stream.go @@ -301,7 +301,14 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client if !cc.dopts.disableRetry { cs.retryThrottler = cc.retryThrottler.Load().(*retryThrottler) } - cs.binlog = binarylog.GetMethodLogger(method) + if ml := binarylog.GetMethodLogger(method); ml != nil { + cs.binlogs = append(cs.binlogs, ml) + } + if cc.dopts.binaryLogger != nil { + if ml := cc.dopts.binaryLogger.GetMethodLogger(method); ml != nil { + cs.binlogs = append(cs.binlogs, ml) + } + } // Pick the transport to use and create a new stream on the transport. // Assign cs.attempt upon success. @@ -322,7 +329,7 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client return nil, err } - if cs.binlog != nil { + if len(cs.binlogs) != 0 { md, _ := metadata.FromOutgoingContext(ctx) logEntry := &binarylog.ClientHeader{ OnClientSide: true, @@ -336,7 +343,9 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client logEntry.Timeout = 0 } } - cs.binlog.Log(logEntry) + for _, binlog := range cs.binlogs { + binlog.Log(logEntry) + } } if desc != unaryStreamDesc { @@ -480,7 +489,7 @@ type clientStream struct { retryThrottler *retryThrottler // The throttler active when the RPC began. - binlog binarylog.MethodLogger // Binary logger, can be nil. + binlogs []binarylog.MethodLogger // serverHeaderBinlogged is a boolean for whether server header has been // logged. Server header will be logged when the first time one of those // happens: stream.Header(), stream.Recv(). @@ -744,7 +753,7 @@ func (cs *clientStream) Header() (metadata.MD, error) { cs.finish(err) return nil, err } - if cs.binlog != nil && !cs.serverHeaderBinlogged { + if len(cs.binlogs) != 0 && !cs.serverHeaderBinlogged { // Only log if binary log is on and header has not been logged. logEntry := &binarylog.ServerHeader{ OnClientSide: true, @@ -754,8 +763,10 @@ func (cs *clientStream) Header() (metadata.MD, error) { if peer, ok := peer.FromContext(cs.Context()); ok { logEntry.PeerAddr = peer.Addr } - cs.binlog.Log(logEntry) cs.serverHeaderBinlogged = true + for _, binlog := range cs.binlogs { + binlog.Log(logEntry) + } } return m, nil } @@ -829,38 +840,44 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { return a.sendMsg(m, hdr, payload, data) } err = cs.withRetry(op, func() { cs.bufferForRetryLocked(len(hdr)+len(payload), op) }) - if cs.binlog != nil && err == nil { - cs.binlog.Log(&binarylog.ClientMessage{ + if len(cs.binlogs) != 0 && err == nil { + cm := &binarylog.ClientMessage{ OnClientSide: true, Message: data, - }) + } + for _, binlog := range cs.binlogs { + binlog.Log(cm) + } } return err } func (cs *clientStream) RecvMsg(m interface{}) error { - if cs.binlog != nil && !cs.serverHeaderBinlogged { + if len(cs.binlogs) != 0 && !cs.serverHeaderBinlogged { // Call Header() to binary log header if it's not already logged. cs.Header() } var recvInfo *payloadInfo - if cs.binlog != nil { + if len(cs.binlogs) != 0 { recvInfo = &payloadInfo{} } err := cs.withRetry(func(a *csAttempt) error { return a.recvMsg(m, recvInfo) }, cs.commitAttemptLocked) - if cs.binlog != nil && err == nil { - cs.binlog.Log(&binarylog.ServerMessage{ + if len(cs.binlogs) != 0 && err == nil { + sm := &binarylog.ServerMessage{ OnClientSide: true, Message: recvInfo.uncompressedBytes, - }) + } + for _, binlog := range cs.binlogs { + binlog.Log(sm) + } } if err != nil || !cs.desc.ServerStreams { // err != nil or non-server-streaming indicates end of stream. cs.finish(err) - if cs.binlog != nil { + if len(cs.binlogs) != 0 { // finish will not log Trailer. Log Trailer here. logEntry := &binarylog.ServerTrailer{ OnClientSide: true, @@ -873,7 +890,9 @@ func (cs *clientStream) RecvMsg(m interface{}) error { if peer, ok := peer.FromContext(cs.Context()); ok { logEntry.PeerAddr = peer.Addr } - cs.binlog.Log(logEntry) + for _, binlog := range cs.binlogs { + binlog.Log(logEntry) + } } } return err @@ -894,10 +913,13 @@ func (cs *clientStream) CloseSend() error { return nil } cs.withRetry(op, func() { cs.bufferForRetryLocked(0, op) }) - if cs.binlog != nil { - cs.binlog.Log(&binarylog.ClientHalfClose{ + if len(cs.binlogs) != 0 { + chc := &binarylog.ClientHalfClose{ OnClientSide: true, - }) + } + for _, binlog := range cs.binlogs { + binlog.Log(chc) + } } // We never returned an error here for reasons. return nil @@ -930,10 +952,13 @@ func (cs *clientStream) finish(err error) { // // Only one of cancel or trailer needs to be logged. In the cases where // users don't call RecvMsg, users must have already canceled the RPC. - if cs.binlog != nil && status.Code(err) == codes.Canceled { - cs.binlog.Log(&binarylog.Cancel{ + if len(cs.binlogs) != 0 && status.Code(err) == codes.Canceled { + c := &binarylog.Cancel{ OnClientSide: true, - }) + } + for _, binlog := range cs.binlogs { + binlog.Log(c) + } } if err == nil { cs.retryThrottler.successfulRPC() @@ -1005,6 +1030,7 @@ func (a *csAttempt) recvMsg(m interface{}, payInfo *payloadInfo) (err error) { } return io.EOF // indicates successful end of stream. } + return toRPCErr(err) } if a.trInfo != nil { @@ -1453,7 +1479,7 @@ type serverStream struct { statsHandler []stats.Handler - binlog binarylog.MethodLogger + binlogs []binarylog.MethodLogger // serverHeaderBinlogged indicates whether server header has been logged. It // will happen when one of the following two happens: stream.SendHeader(), // stream.Send(). @@ -1487,12 +1513,25 @@ func (ss *serverStream) SendHeader(md metadata.MD) error { } err = ss.t.WriteHeader(ss.s, md) - if ss.binlog != nil && !ss.serverHeaderBinlogged { + if len(ss.binlogs) != 0 && !ss.serverHeaderBinlogged { h, _ := ss.s.Header() - ss.binlog.Log(&binarylog.ServerHeader{ + sh := &binarylog.ServerHeader{ Header: h, - }) + } ss.serverHeaderBinlogged = true + for _, binlog := range ss.binlogs { + binlog.Log(sh) + } + } + if len(ss.binlogs) != 0 && !ss.serverHeaderBinlogged { + h, _ := ss.s.Header() + sh := &binarylog.ServerHeader{ + Header: h, + } + ss.serverHeaderBinlogged = true + for _, binlog := range ss.binlogs { + binlog.Log(sh) + } } return err } @@ -1549,17 +1588,23 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { if err := ss.t.Write(ss.s, hdr, payload, &transport.Options{Last: false}); err != nil { return toRPCErr(err) } - if ss.binlog != nil { + if len(ss.binlogs) != 0 { if !ss.serverHeaderBinlogged { h, _ := ss.s.Header() - ss.binlog.Log(&binarylog.ServerHeader{ + sh := &binarylog.ServerHeader{ Header: h, - }) + } ss.serverHeaderBinlogged = true + for _, binlog := range ss.binlogs { + binlog.Log(sh) + } } - ss.binlog.Log(&binarylog.ServerMessage{ + sm := &binarylog.ServerMessage{ Message: data, - }) + } + for _, binlog := range ss.binlogs { + binlog.Log(sm) + } } if len(ss.statsHandler) != 0 { for _, sh := range ss.statsHandler { @@ -1598,13 +1643,14 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { } }() var payInfo *payloadInfo - if len(ss.statsHandler) != 0 || ss.binlog != nil { + if len(ss.statsHandler) != 0 || len(ss.binlogs) != 0 { payInfo = &payloadInfo{} } if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp); err != nil { if err == io.EOF { - if ss.binlog != nil { - ss.binlog.Log(&binarylog.ClientHalfClose{}) + chc := &binarylog.ClientHalfClose{} + for _, binlog := range ss.binlogs { + binlog.Log(chc) } return err } @@ -1625,10 +1671,13 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { }) } } - if ss.binlog != nil { - ss.binlog.Log(&binarylog.ClientMessage{ + if len(ss.binlogs) != 0 { + cm := &binarylog.ClientMessage{ Message: payInfo.uncompressedBytes, - }) + } + for _, binlog := range ss.binlogs { + binlog.Log(cm) + } } return nil } diff --git a/test/end2end_test.go b/test/end2end_test.go index 725bcdb641eb..9c35b74fff56 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -57,6 +57,7 @@ import ( healthgrpc "google.golang.org/grpc/health/grpc_health_v1" healthpb "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/binarylog" "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpctest" @@ -8184,3 +8185,99 @@ func (s) TestGoAwayStreamIDSmallerThanCreatedStreams(t *testing.T) { ct.writeGoAway(1, http2.ErrCodeNo, []byte{}) goAwayWritten.Fire() } + +type mockBinaryLogger struct { + mml *mockMethodLogger +} + +func newMockBinaryLogger() *mockBinaryLogger { + return &mockBinaryLogger{ + mml: &mockMethodLogger{}, + } +} + +func (mbl *mockBinaryLogger) GetMethodLogger(string) binarylog.MethodLogger { + return mbl.mml +} + +type mockMethodLogger struct { + events uint64 +} + +func (mml *mockMethodLogger) Log(binarylog.LogEntryConfig) { + atomic.AddUint64(&mml.events, 1) +} + +// TestGlobalBinaryLoggingOptions tests the binary logging options for client +// and server side. The test configures a binary logger to be plumbed into every +// created ClientConn and server. It then makes a unary RPC call, and a +// streaming RPC call. A certain amount of logging calls should happen as a +// result of the stream operations on each of these calls. +func (s) TestGlobalBinaryLoggingOptions(t *testing.T) { + csbl := newMockBinaryLogger() + ssbl := newMockBinaryLogger() + + internal.AddGlobalDialOptions.(func(opt ...grpc.DialOption))(grpc.WithBinaryLogger(csbl)) + internal.AddGlobalServerOptions.(func(opt ...grpc.ServerOption))(grpc.BinaryLogger(ssbl)) + defer func() { + internal.ClearGlobalDialOptions() + internal.ClearGlobalServerOptions() + }() + ss := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil + }, + FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + for { + _, err := stream.Recv() + if err == io.EOF { + return nil + } + } + }, + } + + // No client or server options specified, because should pick up configured + // global options. + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + // Make a Unary RPC. This should cause Log calls on the MethodLogger. + if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil { + t.Fatalf("Unexpected error from UnaryCall: %v", err) + } + if csbl.mml.events != 5 { + t.Fatalf("want 5 client side binary logging events, got %v", csbl.mml.events) + } + if ssbl.mml.events != 5 { + t.Fatalf("want 5 server side binary logging events, got %v", ssbl.mml.events) + } + + // Make a streaming RPC. This should cause Log calls on the MethodLogger. + stream, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("ss.Client.FullDuplexCall failed: %f", err) + } + rpcDone := make(chan struct{}) + go func() { + for { + _, err := stream.Recv() + if err == io.EOF { + close(rpcDone) + return + } + } + }() + stream.CloseSend() + <-rpcDone + if csbl.mml.events != 9 { + t.Fatalf("want 9 client side binary logging events, got %v", csbl.mml.events) + } + if ssbl.mml.events != 8 { + t.Fatalf("want 8 server side binary logging events, got %v", ssbl.mml.events) + } +} From eb4db59e24951bffe8e85ef0acf45e8f37042b8e Mon Sep 17 00:00:00 2001 From: Zach Reyes Date: Wed, 5 Oct 2022 00:27:35 -0400 Subject: [PATCH 2/2] Responded to Doug's comments --- dialoptions.go | 5 +++-- internal/internal.go | 7 +++++++ server.go | 5 +++-- stream.go | 18 +++++------------- test/end2end_test.go | 20 +++++++------------- 5 files changed, 25 insertions(+), 30 deletions(-) diff --git a/dialoptions.go b/dialoptions.go index 0dab38d4d9b8..9372dc322e80 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -43,6 +43,7 @@ func init() { internal.ClearGlobalDialOptions = func() { extraDialOptions = nil } + internal.WithBinaryLogger = withBinaryLogger } // dialOptions configure a Dial call. dialOptions are set by the DialOption @@ -403,9 +404,9 @@ func WithStatsHandler(h stats.Handler) DialOption { }) } -// WithBinaryLogger returns a DialOption that specifies the binary logger for +// withBinaryLogger returns a DialOption that specifies the binary logger for // this ClientConn. -func WithBinaryLogger(bl binarylog.Logger) DialOption { +func withBinaryLogger(bl binarylog.Logger) DialOption { return newFuncDialOption(func(o *dialOptions) { o.binaryLogger = bl }) diff --git a/internal/internal.go b/internal/internal.go index 2a24b00f7fc9..891e3444046a 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -81,6 +81,13 @@ var ( // single server option. JoinServerOptions interface{} // func(...grpc.ServerOption) grpc.ServerOption + // WithBinaryLogger returns a DialOption that specifies the binary logger + // for a ClientConn. + WithBinaryLogger interface{} // func(binarylog.Logger) grpc.DialOption + // BinaryLogger returns a ServerOption that can set the binary logger for a + // server. + BinaryLogger interface{} // func(binarylog.Logger) grpc.ServerOption + // NewXDSResolverWithConfigForTesting creates a new xds resolver builder using // the provided xds bootstrap config instead of the global configuration from // the supported environment variables. The resolver.Builder is meant to be diff --git a/server.go b/server.go index 4b9cf89c0439..e8143492c25e 100644 --- a/server.go +++ b/server.go @@ -79,6 +79,7 @@ func init() { internal.ClearGlobalServerOptions = func() { extraServerOptions = nil } + internal.BinaryLogger = binaryLogger internal.JoinServerOptions = newJoinServerOption } @@ -470,9 +471,9 @@ func StatsHandler(h stats.Handler) ServerOption { }) } -// BinaryLogger returns a ServerOption that can set the binary logger for the +// binaryLogger returns a ServerOption that can set the binary logger for the // server. -func BinaryLogger(bl binarylog.Logger) ServerOption { +func binaryLogger(bl binarylog.Logger) ServerOption { return newFuncServerOption(func(o *serverOptions) { o.binaryLogger = bl }) diff --git a/stream.go b/stream.go index f18f264d09d5..0c16cfb2ea80 100644 --- a/stream.go +++ b/stream.go @@ -1523,16 +1523,6 @@ func (ss *serverStream) SendHeader(md metadata.MD) error { binlog.Log(sh) } } - if len(ss.binlogs) != 0 && !ss.serverHeaderBinlogged { - h, _ := ss.s.Header() - sh := &binarylog.ServerHeader{ - Header: h, - } - ss.serverHeaderBinlogged = true - for _, binlog := range ss.binlogs { - binlog.Log(sh) - } - } return err } @@ -1648,9 +1638,11 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { } if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp); err != nil { if err == io.EOF { - chc := &binarylog.ClientHalfClose{} - for _, binlog := range ss.binlogs { - binlog.Log(chc) + if len(ss.binlogs) != 0 { + chc := &binarylog.ClientHalfClose{} + for _, binlog := range ss.binlogs { + binlog.Log(chc) + } } return err } diff --git a/test/end2end_test.go b/test/end2end_test.go index 9c35b74fff56..ecf5b5e303ed 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -8217,8 +8217,8 @@ func (s) TestGlobalBinaryLoggingOptions(t *testing.T) { csbl := newMockBinaryLogger() ssbl := newMockBinaryLogger() - internal.AddGlobalDialOptions.(func(opt ...grpc.DialOption))(grpc.WithBinaryLogger(csbl)) - internal.AddGlobalServerOptions.(func(opt ...grpc.ServerOption))(grpc.BinaryLogger(ssbl)) + internal.AddGlobalDialOptions.(func(opt ...grpc.DialOption))(internal.WithBinaryLogger.(func(bl binarylog.Logger) grpc.DialOption)(csbl)) + internal.AddGlobalServerOptions.(func(opt ...grpc.ServerOption))(internal.BinaryLogger.(func(bl binarylog.Logger) grpc.ServerOption)(ssbl)) defer func() { internal.ClearGlobalDialOptions() internal.ClearGlobalServerOptions() @@ -8262,18 +8262,12 @@ func (s) TestGlobalBinaryLoggingOptions(t *testing.T) { if err != nil { t.Fatalf("ss.Client.FullDuplexCall failed: %f", err) } - rpcDone := make(chan struct{}) - go func() { - for { - _, err := stream.Recv() - if err == io.EOF { - close(rpcDone) - return - } - } - }() + stream.CloseSend() - <-rpcDone + if _, err = stream.Recv(); err != io.EOF { + t.Fatalf("unexpected error: %v, expected an EOF error", err) + } + if csbl.mml.events != 9 { t.Fatalf("want 9 client side binary logging events, got %v", csbl.mml.events) }