diff --git a/serv/auth/auth.go b/serv/auth/auth.go index bfdf1f6c..02c63586 100644 --- a/serv/auth/auth.go +++ b/serv/auth/auth.go @@ -48,8 +48,6 @@ import ( "github.com/dosco/graphjin/v2/serv/auth/provider" ) -var ErrNoAuthDefined = errors.New("no auth defined") - type JWTConfig = provider.JWTConfig // Auth struct contains authentication related config values used by the GraphJin service @@ -159,8 +157,9 @@ func NewAuthHandlerFunc(ac Auth) (HandlerFunc, error) { // case "magiclink": // h, err = MagicLinkHandler(ac, next) + case "", "none": - return nil, ErrNoAuthDefined + h, err = NoAuth() default: return nil, fmt.Errorf("auth: unknown auth type: %s", ac.Type) @@ -173,13 +172,20 @@ func NewAuthHandlerFunc(ac Auth) (HandlerFunc, error) { return h, err } +func NoAuth() (HandlerFunc, error) { + return func(w http.ResponseWriter, r *http.Request) (context.Context, error) { + return r.Context(), nil + }, nil +} + // NewAuth returns a new auth handler. It will create a HandlerFunc based on the // provided config. // // Optionally an existing HandlerFunc can be provided. This is required to // support auth in WS subscriptions. func NewAuth(ac Auth, log *zap.Logger, opt Options, hFn ...HandlerFunc) ( - func(next http.Handler) http.Handler, error) { + func(next http.Handler) http.Handler, error, +) { var err error var h HandlerFunc var wsAuthSupported bool diff --git a/serv/routes.go b/serv/routes.go index 8a173ced..6a802117 100644 --- a/serv/routes.go +++ b/serv/routes.go @@ -50,7 +50,7 @@ func routesHandler(s1 *Service, mux Mux, ns *string) (http.Handler, error) { } ah, err := auth.NewAuthHandlerFunc(s.conf.Auth) - if err != nil && err != auth.ErrNoAuthDefined { + if err != nil { s.log.Fatalf("api: error initializing auth handler: %s", err) } diff --git a/serv/ws.go b/serv/ws.go index d736ba7a..ebcd0c3c 100644 --- a/serv/ws.go +++ b/serv/ws.go @@ -55,179 +55,146 @@ func init() { } } +type wsState struct { + c context.Context + conn *websocket.Conn + req wsReq + ah auth.HandlerFunc + exit bool + done chan bool + + w http.ResponseWriter + r *http.Request +} + func (s *service) apiV1Ws(w http.ResponseWriter, r *http.Request, ah auth.HandlerFunc) { var m *core.Member - var ready bool var err error - ct := r.Context() - c, err := upgrader.Upgrade(w, r, nil) + conn, err := upgrader.Upgrade(w, r, nil) if err != nil { renderErr(w, err) return } - defer c.Close() - c.SetReadLimit(2048) - - var v wsReq + defer conn.Close() + conn.SetReadLimit(2048) + + st := wsState{ + c: r.Context(), + done: make(chan bool), + conn: conn, + ah: ah, + w: w, + r: r, + } - done := make(chan bool) for { var b []byte - if _, b, err = c.ReadMessage(); err != nil { + if _, b, err = conn.ReadMessage(); err != nil { break } - if err = json.Unmarshal(b, &v); err != nil { + if err = json.Unmarshal(b, &st.req); err != nil { break } - if ready { - if v.Type != "connection_terminate" && - v.Type != "stop" && - v.Type != "complete" { - err = fmt.Errorf("unknown message type: %s", v.Type) - } + if err = s.subSwitch(&st); err != nil { break } - if ct, ready, err = s.subSwitch(ct, c, v, done, ah, w, r); err != nil { - if err1 := sendError(ct, c, err, v.ID); err1 != nil { - err = err1 - } + if st.exit { break } } if err != nil { s.zlog.Error("Subscription", []zapcore.Field{zap.Error(err)}...) + sendError(&st, err) //nolint:errcheck } m.Unsubscribe() - done <- true + st.done <- true } -func (s *service) subSwitch( - ct context.Context, - c *websocket.Conn, - v wsReq, - done chan bool, - ah auth.HandlerFunc, - w http.ResponseWriter, - r *http.Request) (context.Context, bool, error) { +type authHeaders struct { + UserIDProvider string `json:"X-User-ID-Provider"` + UserRole string `json:"X-User-Role"` + UserID interface{} `json:"X-User-ID"` +} - switch v.Type { +func (s *service) subSwitch(st *wsState) (err error) { + switch st.req.Type { case "connection_init": - if len(v.Payload) != 0 { - var p map[string]interface{} - if err := json.Unmarshal(v.Payload, &p); err != nil { - s.zlog.Error("Websockets", []zapcore.Field{zap.Error(err)}...) - break - } - for k, v := range p { - switch v1 := v.(type) { - case string: - r.Header.Set(k, v1) - case json.Number: - r.Header.Set(k, v1.String()) - } - } + if err = setHeaders(st); err != nil { + return } - - if ah != nil { - c, err := ah(w, r) - if err != nil { - s.zlog.Error("Auth", []zapcore.Field{zap.Error(err)}...) - } - if err == auth.Err401 { - http.Error(w, "401 unauthorized", http.StatusUnauthorized) - break - } - if s.conf.Serv.AuthFailBlock && !auth.IsAuth(c) { - http.Error(w, "401 unauthorized", http.StatusUnauthorized) - break - } - if c != nil { - if v := c.Value(core.UserIDProviderKey); v != nil { - ct = context.WithValue(ct, core.UserIDProviderKey, v) - } - if v := c.Value(core.UserRoleKey); v != nil { - ct = context.WithValue(ct, core.UserRoleKey, v) - } - if v := c.Value(core.UserIDKey); v != nil { - ct = context.WithValue(ct, core.UserIDKey, v) - } - } + if st.c, err = st.ah(st.w, st.r); err != nil { + return + } + if s.conf.Serv.AuthFailBlock && !auth.IsAuth(st.c) { + err = auth.Err401 + return } - if err := c.WritePreparedMessage(initMsg); err != nil { - return ct, false, err + if err = st.conn.WritePreparedMessage(initMsg); err != nil { + return } case "start", "subscribe": var p gqlReq - if err := json.Unmarshal(v.Payload, &p); err != nil { - return ct, false, err + if err = json.Unmarshal(st.req.Payload, &p); err != nil { + return } if s.conf.Serv.Auth.Development { - type authHeaders struct { - UserIDProvider string `json:"X-User-ID-Provider"` - UserRole string `json:"X-User-Role"` - UserID interface{} `json:"X-User-ID"` - } - var x authHeaders - if err := json.Unmarshal(p.Vars, &x); err == nil { - if x.UserIDProvider != "" { - ct = context.WithValue(ct, core.UserIDProviderKey, x.UserIDProvider) - } - if x.UserRole != "" { - ct = context.WithValue(ct, core.UserRoleKey, x.UserRole) - } - if x.UserID != nil { - ct = context.WithValue(ct, core.UserIDKey, x.UserID) - } - } else { - return ct, false, err + if err = json.Unmarshal(p.Vars, &x); err != nil { + return + } + if x.UserIDProvider != "" { + st.c = context.WithValue(st.c, core.UserIDProviderKey, x.UserIDProvider) + } + if x.UserRole != "" { + st.c = context.WithValue(st.c, core.UserRoleKey, x.UserRole) + } + if x.UserID != nil { + st.c = context.WithValue(st.c, core.UserIDKey, x.UserID) } } - m, err := s.gj.Subscribe(ct, p.Query, p.Vars, nil) - if err != nil { - return ct, false, err + var m *core.Member + if m, err = s.gj.Subscribe(st.c, p.Query, p.Vars, nil); err != nil { + return } + go s.waitForData(st, m) + return - go s.waitForData(ct, done, c, m, v) - return ct, true, nil + case "complete", "connection_terminate", "stop": + st.exit = true default: - return ct, false, fmt.Errorf("unknown message type: %s", v.Type) + err = fmt.Errorf("unknown message type: %s", st.req.Type) } - - return ct, false, nil + return } -func (s *service) waitForData( - ct context.Context, done chan bool, c *websocket.Conn, - m *core.Member, req wsReq) { +func (s *service) waitForData(st *wsState, m *core.Member) { var buf bytes.Buffer var ptype string var err error - if req.Type == "subscribe" { + if st.req.Type == "subscribe" { ptype = "next" } else { ptype = "data" } enc := json.NewEncoder(&buf) - for { select { case v := <-m.Result: - m := wsRes{ID: req.ID, Type: ptype} + m := wsRes{ID: st.req.ID, Type: ptype} m.Payload.Data = v.Data m.Payload.Errors = v.Errors @@ -236,34 +203,49 @@ func (s *service) waitForData( } msg := buf.Bytes() buf.Reset() + err = st.conn.WriteMessage(websocket.TextMessage, msg) - err = c.WriteMessage(websocket.TextMessage, msg) - case v := <-done: + case v := <-st.done: if v { return } } if err != nil { - if err1 := sendError(ct, c, err, req.ID); err != nil { - err = err1 - } - s.zlog.Error("Websockets", []zapcore.Field{zap.Error(err)}...) + s.zlog.Error("Subscription", []zapcore.Field{zap.Error(err)}...) + sendError(st, err) //nolint:errcheck break } } } -func sendError(ct context.Context, c *websocket.Conn, err error, id string) error { - m := wsRes{ID: id, Type: "error"} - m.Payload.Errors = []core.Error{{Message: err.Error()}} +func setHeaders(st *wsState) (err error) { + if len(st.req.Payload) == 0 { + return + } + var p map[string]interface{} + if err = json.Unmarshal(st.req.Payload, &p); err != nil { + return + } + for k, v := range p { + switch v1 := v.(type) { + case string: + st.r.Header.Set(k, v1) + case json.Number: + st.r.Header.Set(k, v1.String()) + } + } + return +} + +func sendError(st *wsState, cerr error) (err error) { + m := wsRes{ID: st.req.ID, Type: "error"} + m.Payload.Errors = []core.Error{{Message: cerr.Error()}} msg, err := json.Marshal(m) if err != nil { - return err - } - if err := c.WriteMessage(websocket.TextMessage, msg); err != nil { - return err + return } - return nil + err = st.conn.WriteMessage(websocket.TextMessage, msg) + return } diff --git a/wasm/graphjin.wasm b/wasm/graphjin.wasm index b8a93176..0682cc1d 100755 Binary files a/wasm/graphjin.wasm and b/wasm/graphjin.wasm differ