Skip to content

Commit

Permalink
fix: code cleanup websocket handler
Browse files Browse the repository at this point in the history
  • Loading branch information
dosco committed Jan 20, 2023
1 parent e04acb1 commit e4b158f
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 124 deletions.
14 changes: 10 additions & 4 deletions serv/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ import (
"github.com/dosco/graphjin/v2/serv/auth/provider"
)

var ErrNoAuthDefined = errors.New("no auth defined")

type JWTConfig = provider.JWTConfig

// Auth struct contains authentication related config values used by the GraphJin service
Expand Down Expand Up @@ -159,8 +157,9 @@ func NewAuthHandlerFunc(ac Auth) (HandlerFunc, error) {

// case "magiclink":
// h, err = MagicLinkHandler(ac, next)

case "", "none":
return nil, ErrNoAuthDefined
h, err = NoAuth()

default:
return nil, fmt.Errorf("auth: unknown auth type: %s", ac.Type)
Expand All @@ -173,13 +172,20 @@ func NewAuthHandlerFunc(ac Auth) (HandlerFunc, error) {
return h, err
}

func NoAuth() (HandlerFunc, error) {
return func(w http.ResponseWriter, r *http.Request) (context.Context, error) {
return r.Context(), nil
}, nil
}

// NewAuth returns a new auth handler. It will create a HandlerFunc based on the
// provided config.
//
// Optionally an existing HandlerFunc can be provided. This is required to
// support auth in WS subscriptions.
func NewAuth(ac Auth, log *zap.Logger, opt Options, hFn ...HandlerFunc) (
func(next http.Handler) http.Handler, error) {
func(next http.Handler) http.Handler, error,
) {
var err error
var h HandlerFunc
var wsAuthSupported bool
Expand Down
2 changes: 1 addition & 1 deletion serv/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func routesHandler(s1 *Service, mux Mux, ns *string) (http.Handler, error) {
}

ah, err := auth.NewAuthHandlerFunc(s.conf.Auth)
if err != nil && err != auth.ErrNoAuthDefined {
if err != nil {
s.log.Fatalf("api: error initializing auth handler: %s", err)
}

Expand Down
220 changes: 101 additions & 119 deletions serv/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,179 +55,146 @@ func init() {
}
}

type wsState struct {
c context.Context
conn *websocket.Conn
req wsReq
ah auth.HandlerFunc
exit bool
done chan bool

w http.ResponseWriter
r *http.Request
}

func (s *service) apiV1Ws(w http.ResponseWriter, r *http.Request, ah auth.HandlerFunc) {
var m *core.Member
var ready bool
var err error

ct := r.Context()
c, err := upgrader.Upgrade(w, r, nil)
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
renderErr(w, err)
return
}
defer c.Close()
c.SetReadLimit(2048)

var v wsReq
defer conn.Close()
conn.SetReadLimit(2048)

st := wsState{
c: r.Context(),
done: make(chan bool),
conn: conn,
ah: ah,
w: w,
r: r,
}

done := make(chan bool)
for {
var b []byte

if _, b, err = c.ReadMessage(); err != nil {
if _, b, err = conn.ReadMessage(); err != nil {
break
}

if err = json.Unmarshal(b, &v); err != nil {
if err = json.Unmarshal(b, &st.req); err != nil {
break
}

if ready {
if v.Type != "connection_terminate" &&
v.Type != "stop" &&
v.Type != "complete" {
err = fmt.Errorf("unknown message type: %s", v.Type)
}
if err = s.subSwitch(&st); err != nil {
break
}

if ct, ready, err = s.subSwitch(ct, c, v, done, ah, w, r); err != nil {
if err1 := sendError(ct, c, err, v.ID); err1 != nil {
err = err1
}
if st.exit {
break
}
}

if err != nil {
s.zlog.Error("Subscription", []zapcore.Field{zap.Error(err)}...)
sendError(&st, err) //nolint:errcheck
}

m.Unsubscribe()
done <- true
st.done <- true
}

func (s *service) subSwitch(
ct context.Context,
c *websocket.Conn,
v wsReq,
done chan bool,
ah auth.HandlerFunc,
w http.ResponseWriter,
r *http.Request) (context.Context, bool, error) {
type authHeaders struct {
UserIDProvider string `json:"X-User-ID-Provider"`
UserRole string `json:"X-User-Role"`
UserID interface{} `json:"X-User-ID"`
}

switch v.Type {
func (s *service) subSwitch(st *wsState) (err error) {
switch st.req.Type {
case "connection_init":
if len(v.Payload) != 0 {
var p map[string]interface{}
if err := json.Unmarshal(v.Payload, &p); err != nil {
s.zlog.Error("Websockets", []zapcore.Field{zap.Error(err)}...)
break
}
for k, v := range p {
switch v1 := v.(type) {
case string:
r.Header.Set(k, v1)
case json.Number:
r.Header.Set(k, v1.String())
}
}
if err = setHeaders(st); err != nil {
return
}

if ah != nil {
c, err := ah(w, r)
if err != nil {
s.zlog.Error("Auth", []zapcore.Field{zap.Error(err)}...)
}
if err == auth.Err401 {
http.Error(w, "401 unauthorized", http.StatusUnauthorized)
break
}
if s.conf.Serv.AuthFailBlock && !auth.IsAuth(c) {
http.Error(w, "401 unauthorized", http.StatusUnauthorized)
break
}
if c != nil {
if v := c.Value(core.UserIDProviderKey); v != nil {
ct = context.WithValue(ct, core.UserIDProviderKey, v)
}
if v := c.Value(core.UserRoleKey); v != nil {
ct = context.WithValue(ct, core.UserRoleKey, v)
}
if v := c.Value(core.UserIDKey); v != nil {
ct = context.WithValue(ct, core.UserIDKey, v)
}
}
if st.c, err = st.ah(st.w, st.r); err != nil {
return
}
if s.conf.Serv.AuthFailBlock && !auth.IsAuth(st.c) {
err = auth.Err401
return
}
if err := c.WritePreparedMessage(initMsg); err != nil {
return ct, false, err
if err = st.conn.WritePreparedMessage(initMsg); err != nil {
return
}

case "start", "subscribe":
var p gqlReq
if err := json.Unmarshal(v.Payload, &p); err != nil {
return ct, false, err
if err = json.Unmarshal(st.req.Payload, &p); err != nil {
return
}

if s.conf.Serv.Auth.Development {
type authHeaders struct {
UserIDProvider string `json:"X-User-ID-Provider"`
UserRole string `json:"X-User-Role"`
UserID interface{} `json:"X-User-ID"`
}

var x authHeaders
if err := json.Unmarshal(p.Vars, &x); err == nil {
if x.UserIDProvider != "" {
ct = context.WithValue(ct, core.UserIDProviderKey, x.UserIDProvider)
}
if x.UserRole != "" {
ct = context.WithValue(ct, core.UserRoleKey, x.UserRole)
}
if x.UserID != nil {
ct = context.WithValue(ct, core.UserIDKey, x.UserID)
}
} else {
return ct, false, err
if err = json.Unmarshal(p.Vars, &x); err != nil {
return
}
if x.UserIDProvider != "" {
st.c = context.WithValue(st.c, core.UserIDProviderKey, x.UserIDProvider)
}
if x.UserRole != "" {
st.c = context.WithValue(st.c, core.UserRoleKey, x.UserRole)
}
if x.UserID != nil {
st.c = context.WithValue(st.c, core.UserIDKey, x.UserID)
}
}

m, err := s.gj.Subscribe(ct, p.Query, p.Vars, nil)
if err != nil {
return ct, false, err
var m *core.Member
if m, err = s.gj.Subscribe(st.c, p.Query, p.Vars, nil); err != nil {
return
}
go s.waitForData(st, m)
return

go s.waitForData(ct, done, c, m, v)
return ct, true, nil
case "complete", "connection_terminate", "stop":
st.exit = true

default:
return ct, false, fmt.Errorf("unknown message type: %s", v.Type)
err = fmt.Errorf("unknown message type: %s", st.req.Type)
}

return ct, false, nil
return
}

func (s *service) waitForData(
ct context.Context, done chan bool, c *websocket.Conn,
m *core.Member, req wsReq) {
func (s *service) waitForData(st *wsState, m *core.Member) {
var buf bytes.Buffer

var ptype string
var err error

if req.Type == "subscribe" {
if st.req.Type == "subscribe" {
ptype = "next"
} else {
ptype = "data"
}

enc := json.NewEncoder(&buf)

for {
select {
case v := <-m.Result:
m := wsRes{ID: req.ID, Type: ptype}
m := wsRes{ID: st.req.ID, Type: ptype}
m.Payload.Data = v.Data
m.Payload.Errors = v.Errors

Expand All @@ -236,34 +203,49 @@ func (s *service) waitForData(
}
msg := buf.Bytes()
buf.Reset()
err = st.conn.WriteMessage(websocket.TextMessage, msg)

err = c.WriteMessage(websocket.TextMessage, msg)
case v := <-done:
case v := <-st.done:
if v {
return
}
}

if err != nil {
if err1 := sendError(ct, c, err, req.ID); err != nil {
err = err1
}
s.zlog.Error("Websockets", []zapcore.Field{zap.Error(err)}...)
s.zlog.Error("Subscription", []zapcore.Field{zap.Error(err)}...)
sendError(st, err) //nolint:errcheck
break
}
}
}

func sendError(ct context.Context, c *websocket.Conn, err error, id string) error {
m := wsRes{ID: id, Type: "error"}
m.Payload.Errors = []core.Error{{Message: err.Error()}}
func setHeaders(st *wsState) (err error) {
if len(st.req.Payload) == 0 {
return
}
var p map[string]interface{}
if err = json.Unmarshal(st.req.Payload, &p); err != nil {
return
}
for k, v := range p {
switch v1 := v.(type) {
case string:
st.r.Header.Set(k, v1)
case json.Number:
st.r.Header.Set(k, v1.String())
}
}
return
}

func sendError(st *wsState, cerr error) (err error) {
m := wsRes{ID: st.req.ID, Type: "error"}
m.Payload.Errors = []core.Error{{Message: cerr.Error()}}

msg, err := json.Marshal(m)
if err != nil {
return err
}
if err := c.WriteMessage(websocket.TextMessage, msg); err != nil {
return err
return
}
return nil
err = st.conn.WriteMessage(websocket.TextMessage, msg)
return
}
Binary file modified wasm/graphjin.wasm
Binary file not shown.

1 comment on commit e4b158f

@vercel
Copy link

@vercel vercel bot commented on e4b158f Jan 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.