diff --git a/serv/auth/auth_test.go b/serv/auth/auth_test.go new file mode 100644 index 00000000..8e677123 --- /dev/null +++ b/serv/auth/auth_test.go @@ -0,0 +1,39 @@ +package auth_test + +import ( + "net/http" + "testing" + + "github.com/dosco/graphjin/serv/auth" + "github.com/stretchr/testify/assert" +) + +func TestJWTTokenInAuthorizationHeader(t *testing.T) { + ah, err := auth.JwtHandler(auth.Auth{ + Cookie: "Boo", + JWT: auth.JWTConfig{ + Secret: "casper", + }, + }) + assert.NoError(t, err) + + // { + // "sub": "1234567890", + // "name": "John Doe", + // "iat": 1516239022 + // } + tok := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.VZ01qXI7Whbuj8X3FZw0mLyZT7iMKMCDl_rtzdNpjAg" + + req, err := http.NewRequest( + http.MethodGet, + "https://test.com", + nil) + assert.NoError(t, err) + + req.Header.Set("Authorization", "Bearer "+tok) + + c, err := ah(nil, req) + assert.NoError(t, err) + + assert.Equal(t, 1234567890, auth.UserIDInt(c)) +} diff --git a/serv/auth/jwt.go b/serv/auth/jwt.go index d00fc727..07005330 100644 --- a/serv/auth/jwt.go +++ b/serv/auth/jwt.go @@ -26,28 +26,22 @@ func JwtHandler(ac Auth) (HandlerFunc, error) { var tok string if cookie != "" { - ck, err := r.Cookie(cookie) - if err == http.ErrNoCookie { - return nil, nil + if ck, err := r.Cookie(cookie); err == nil && len(ck.Value) != 0 { + tok = ck.Value } - if err != nil { - return nil, err - } - tok = ck.Value - } else { - ah := r.Header.Get(authHeader) - if len(ah) < 10 { - return nil, fmt.Errorf("invalid or missing header: %s", authHeader) + } + + if tok == "" { + if ah := r.Header.Get(authHeader); len(ah) > 10 { + tok = ah[7:] } - tok = ah[7:] } if tok == "" { - return nil, fmt.Errorf("jwt not found") + return nil, fmt.Errorf("no jwt token found in cookie or authorization header") } keyFunc := jwtProvider.KeyFunc() - token, err := jwt.ParseWithClaims(tok, jwt.MapClaims{}, keyFunc) //jwt.MapClaims is already passed by reference if err != nil { diff --git a/serv/ws.go b/serv/ws.go index bdec2540..0cb857bb 100644 --- a/serv/ws.go +++ b/serv/ws.go @@ -119,11 +119,7 @@ func (s *service) subSwitch( switch v.Type { case "connection_init": - if err := c.WritePreparedMessage(initMsg); err != nil { - return ct, false, err - } - - if len(v.Payload) > 0 { + if len(v.Payload) != 0 { var p map[string]interface{} if err := json.Unmarshal(v.Payload, &p); err != nil { s.zlog.Error("Websockets", []zapcore.Field{zap.Error(err)}...) @@ -164,6 +160,9 @@ func (s *service) subSwitch( } } } + if err := c.WritePreparedMessage(initMsg); err != nil { + return ct, false, err + } case "start", "subscribe": var p gqlReq