From 2f413c454850afdb19a685fdee95c5d7701891ab Mon Sep 17 00:00:00 2001 From: Sean Barag Date: Tue, 13 Dec 2022 11:31:23 -0800 Subject: [PATCH] transport/http2: use HTTP 400 for bad requests instead of 500 (#5804) --- internal/transport/handler_server.go | 30 ++++++++++++++++------- internal/transport/handler_server_test.go | 4 +-- server.go | 3 ++- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/internal/transport/handler_server.go b/internal/transport/handler_server.go index 0c6ada99274d..ebe8bfe330a5 100644 --- a/internal/transport/handler_server.go +++ b/internal/transport/handler_server.go @@ -46,24 +46,32 @@ import ( "google.golang.org/grpc/status" ) -// NewServerHandlerTransport returns a ServerTransport handling gRPC -// from inside an http.Handler. It requires that the http Server -// supports HTTP/2. +// NewServerHandlerTransport returns a ServerTransport handling gRPC from +// inside an http.Handler, or writes an HTTP error to w and returns an error. +// It requires that the http Server supports HTTP/2. func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []stats.Handler) (ServerTransport, error) { if r.ProtoMajor != 2 { - return nil, errors.New("gRPC requires HTTP/2") + msg := "gRPC requires HTTP/2" + http.Error(w, msg, http.StatusBadRequest) + return nil, errors.New(msg) } if r.Method != "POST" { - return nil, errors.New("invalid gRPC request method") + msg := fmt.Sprintf("invalid gRPC request method %q", r.Method) + http.Error(w, msg, http.StatusBadRequest) + return nil, errors.New(msg) } contentType := r.Header.Get("Content-Type") // TODO: do we assume contentType is lowercase? we did before contentSubtype, validContentType := grpcutil.ContentSubtype(contentType) if !validContentType { - return nil, errors.New("invalid gRPC request content-type") + msg := fmt.Sprintf("invalid gRPC request content-type %q", contentType) + http.Error(w, msg, http.StatusBadRequest) + return nil, errors.New(msg) } if _, ok := w.(http.Flusher); !ok { - return nil, errors.New("gRPC requires a ResponseWriter supporting http.Flusher") + msg := "gRPC requires a ResponseWriter supporting http.Flusher" + http.Error(w, msg, http.StatusInternalServerError) + return nil, errors.New(msg) } st := &serverHandlerTransport{ @@ -79,7 +87,9 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s if v := r.Header.Get("grpc-timeout"); v != "" { to, err := decodeTimeout(v) if err != nil { - return nil, status.Errorf(codes.Internal, "malformed time-out: %v", err) + msg := fmt.Sprintf("malformed time-out: %v", err) + http.Error(w, msg, http.StatusBadRequest) + return nil, status.Error(codes.Internal, msg) } st.timeoutSet = true st.timeout = to @@ -97,7 +107,9 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s for _, v := range vv { v, err := decodeMetadataHeader(k, v) if err != nil { - return nil, status.Errorf(codes.Internal, "malformed binary metadata: %v", err) + msg := fmt.Sprintf("malformed binary metadata %q in header %q: %v", v, k, err) + http.Error(w, msg, http.StatusBadRequest) + return nil, status.Error(codes.Internal, msg) } metakv = append(metakv, k, v) } diff --git a/internal/transport/handler_server_test.go b/internal/transport/handler_server_test.go index b08dcaaf3c4b..82b4baca58b6 100644 --- a/internal/transport/handler_server_test.go +++ b/internal/transport/handler_server_test.go @@ -63,7 +63,7 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { Method: "GET", Header: http.Header{}, }, - wantErr: "invalid gRPC request method", + wantErr: `invalid gRPC request method "GET"`, }, { name: "bad content type", @@ -74,7 +74,7 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { "Content-Type": {"application/foo"}, }, }, - wantErr: "invalid gRPC request content-type", + wantErr: `invalid gRPC request content-type "application/foo"`, }, { name: "not flusher", diff --git a/server.go b/server.go index 7456d6d32bc3..2808b7c83e80 100644 --- a/server.go +++ b/server.go @@ -1008,7 +1008,8 @@ var _ http.Handler = (*Server)(nil) func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandlers) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + // Errors returned from transport.NewServerHandlerTransport have + // already been written to w. return } if !s.addConn(listenerAddressForServeHTTP, st) {