Skip to content

Commit

Permalink
feat: Add websockets to update ui on change
Browse files Browse the repository at this point in the history
  • Loading branch information
NoUseFreak committed Jan 3, 2023
1 parent fd21373 commit 91603a9
Show file tree
Hide file tree
Showing 12 changed files with 274 additions and 83 deletions.
1 change: 1 addition & 0 deletions .env
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
OIDC_ISSUER_URL=https://auth.dev.stenic.io/auth/realms/dev
OIDC_CLIENT_ID=ledger
OIDC_AUDIENCE=account
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ require (
github.com/go-openapi/jsonpointer v0.19.5 // indirect
github.com/go-openapi/jsonreference v0.20.0 // indirect
github.com/go-openapi/swag v0.19.14 // indirect
github.com/gobwas/httphead v0.1.0 // indirect
github.com/gobwas/pool v0.2.1 // indirect
github.com/gobwas/ws v1.1.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.2 // indirect
Expand Down
7 changes: 7 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,12 @@ github.com/gobuffalo/packd v0.1.0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWe
github.com/gobuffalo/packr/v2 v2.0.9/go.mod h1:emmyGweYTm6Kdper+iywB6YK5YzuKchGtJQZ0Odn4pQ=
github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/VCm/3ptBN+0=
github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw=
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
github.com/gobwas/ws v1.1.0 h1:7RFti/xnNkMJnrK7D1yQ/iCIB5OrrY/54/H930kIbHA=
github.com/gobwas/ws v1.1.0/go.mod h1:nzvNcVha5eUziGrbxFCo6qFIojQHjJV5cLYIbezhfL0=
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goccy/go-json v0.9.11 h1:/pAaQDLHEoCq/5FFmSKBswWmK6H0e8g4159Kc/X/nqk=
github.com/goccy/go-json v0.9.11/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
Expand Down Expand Up @@ -1548,6 +1554,7 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20201126233918-771906719818/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201201145000-ef89a241ccb3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201202213521-69691e467435/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201207223542-d4d67f95c62d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210104204734-6f8348627aad/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
Expand Down
57 changes: 0 additions & 57 deletions internal/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,9 @@ package auth
import (
"context"
"net/http"
"os"
"strings"

"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"

jwtmiddleware "github.com/auth0/go-jwt-middleware/v2"
"github.com/auth0/go-jwt-middleware/v2/validator"
)

type localJwtClaims struct{}
Expand All @@ -30,32 +25,6 @@ func (c *CustomClaims) Validate(ctx context.Context) error {
return nil
}

func JwtHandler(opts ApiSecurityOptions) gin.HandlerFunc {

var jwtOidcMiddleware *jwtmiddleware.JWTMiddleware
if opts.IssuerURL != "" {
jwtOidcMiddleware = getOidcValidator(opts.IssuerURL, opts.Audience)
}

return func(c *gin.Context) {
var handler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
c.Request = r
c.Next()
}

if jwtOidcMiddleware != nil {
logrus.Debug("Checking oidc token")
jwtOidcMiddleware.CheckJWT(handler).ServeHTTP(c.Writer, c.Request)
}

// Continue with local auth if OIDC is not detected
if c.Request.Context().Value(jwtmiddleware.ContextKey{}) == nil {
logrus.Debug("Checking local token")
checkLocalJWT(handler).ServeHTTP(c.Writer, c.Request)
}
}
}

func checkLocalJWT(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
Expand All @@ -75,29 +44,3 @@ func checkLocalJWT(next http.Handler) http.Handler {
type UserInfo struct {
Username string
}

func TokenFromContext(ctx context.Context) *UserInfo {
if os.Getenv("AUTH_INSECURE") == "yes" {
logrus.Error("Skipping auth, hope you are in dev mode.")
return &UserInfo{
Username: "Insecure user",
}
}
if raw, valid := ctx.Value(localJwtClaims{}).(*Claims); valid {
return &UserInfo{
Username: raw.Username,
}
}

if raw, valid := ctx.Value(jwtmiddleware.ContextKey{}).(*validator.ValidatedClaims); valid {
cc, ok := raw.CustomClaims.(*CustomClaims)
if !ok {
return nil
}
return &UserInfo{
Username: cc.PreferredUsername,
}
}

return nil
}
43 changes: 19 additions & 24 deletions internal/auth/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,24 @@ import (
)

func getOidcValidator(issuerURLString string, audience []string) *jwtmiddleware.JWTMiddleware {
jwtValidator := getOidcValidatorFunc(issuerURLString, audience)

errorHandler := func(w http.ResponseWriter, r *http.Request, err error) {
if errors.Is(err, jwtmiddleware.ErrJWTInvalid) {
log.Debug(err)
} else {
jwtmiddleware.DefaultErrorHandler(w, r, err)
}
}

return jwtmiddleware.New(
jwtValidator.ValidateToken,
jwtmiddleware.WithErrorHandler(errorHandler),
jwtmiddleware.WithCredentialsOptional(true),
)
}

func getOidcValidatorFunc(issuerURLString string, audience []string) *validator.Validator {
issuerURL, err := url.Parse(issuerURLString)
if err != nil {
log.Fatal(err)
Expand All @@ -21,15 +39,6 @@ func getOidcValidator(issuerURLString string, audience []string) *jwtmiddleware.
provider := jwks.NewCachingProvider(
issuerURL,
time.Duration(5*time.Minute),
// func(p *jwks.Provider) {
// if false {
// return
// }
// tr := &http.Transport{
// TLSClientConfig: &tls.Config{InsecureSkipVerify: false},
// }
// p.Client.Transport = tr
// },
)

customClaims := func() validator.CustomClaims {
Expand All @@ -41,22 +50,8 @@ func getOidcValidator(issuerURLString string, audience []string) *jwtmiddleware.
validator.RS256,
issuerURL.String(),
audience,
// validator.WithAllowedClockSkew(30*time.Second),

validator.WithCustomClaims(customClaims),
)

errorHandler := func(w http.ResponseWriter, r *http.Request, err error) {
if errors.Is(err, jwtmiddleware.ErrJWTInvalid) {
log.Debug(err)
} else {
jwtmiddleware.DefaultErrorHandler(w, r, err)
}
}

return jwtmiddleware.New(
jwtValidator.ValidateToken,
jwtmiddleware.WithErrorHandler(errorHandler),
jwtmiddleware.WithCredentialsOptional(true),
)
return jwtValidator
}
107 changes: 107 additions & 0 deletions internal/auth/validate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package auth

import (
"context"
"fmt"
"net/http"
"os"
"strings"

jwtmiddleware "github.com/auth0/go-jwt-middleware/v2"
"github.com/auth0/go-jwt-middleware/v2/validator"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)

type LedgerValidator interface {
ValidateToken(authString string) error
GetJWTMiddleware() gin.HandlerFunc
}

type ledgerValidator struct {
IssuerURL string
Audience []string
}

func New(cfg ApiSecurityOptions) LedgerValidator {
l := ledgerValidator{
IssuerURL: cfg.IssuerURL,
Audience: cfg.Audience,
}

return l
}

func (l ledgerValidator) ValidateToken(authString string) error {
if len(authString) < 8 && strings.ToLower(authString[:6]) != "bearer" {
return fmt.Errorf("not a bearer auth string")
}

// Validate local token
logrus.Trace("Checking Local")
_, err := ValidateToken(authString[7:])
if err == nil {
return nil
}

// OIDC validation
jwtOIDCValidator := getOidcValidatorFunc(l.IssuerURL, l.Audience)
if jwtOIDCValidator == nil {
return fmt.Errorf("token is invalid")
}

logrus.Trace("Checking OIDC")
_, err = jwtOIDCValidator.ValidateToken(context.Background(), authString[7:])
return err
}

func (l ledgerValidator) GetJWTMiddleware() gin.HandlerFunc {
var jwtOidcMiddleware *jwtmiddleware.JWTMiddleware
if l.IssuerURL != "" {
jwtOidcMiddleware = getOidcValidator(l.IssuerURL, l.Audience)
}

return func(c *gin.Context) {
var handler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
c.Request = r
c.Next()
}

if jwtOidcMiddleware != nil {
logrus.Debug("Checking oidc token")
jwtOidcMiddleware.CheckJWT(handler).ServeHTTP(c.Writer, c.Request)
}

// Continue with local auth if OIDC is not detected
if c.Request.Context().Value(jwtmiddleware.ContextKey{}) == nil {
logrus.Debug("Checking local token")
checkLocalJWT(handler).ServeHTTP(c.Writer, c.Request)
}
}
}

func TokenFromContext(ctx context.Context) *UserInfo {
if os.Getenv("AUTH_INSECURE") == "yes" {
logrus.Error("Skipping auth, hope you are in dev mode.")
return &UserInfo{
Username: "Insecure user",
}
}
if raw, valid := ctx.Value(localJwtClaims{}).(*Claims); valid {
return &UserInfo{
Username: raw.Username,
}
}

if raw, valid := ctx.Value(jwtmiddleware.ContextKey{}).(*validator.ValidatedClaims); valid {
cc, ok := raw.CustomClaims.(*CustomClaims)
if !ok {
return nil
}
return &UserInfo{
Username: cc.PreferredUsername,
}
}

return nil
}
7 changes: 6 additions & 1 deletion internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,12 @@ func (s *Server) Listen(addr string) error {
if os.Getenv("DEBUG") == "true" {
s.server.Any("/playground", gin.WrapF(playground.Handler("GraphQL playground", "/query")))
}

ledgerAuth := auth.New(oidcOptions)

s.server.Any(
"/query",
auth.JwtHandler(oidcOptions),
ledgerAuth.GetJWTMiddleware(),
gin.WrapH(handler.NewDefaultServer(generated.NewExecutableSchema(generated.Config{Resolvers: &graph.Resolver{}}))),
)

Expand All @@ -82,6 +85,8 @@ func (s *Server) Listen(addr string) error {
c.File(bin)
})

s.server.GET("/socket", wsHandler(ledgerAuth))

logrus.Info("Starting webserver")
return s.server.Run(addr)
}
70 changes: 70 additions & 0 deletions internal/server/ws.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package server

import (
"time"

"github.com/gin-gonic/gin"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
"github.com/sirupsen/logrus"
"github.com/stenic/ledger/internal/auth"
"github.com/stenic/ledger/internal/pkg/versions"
)

func wsHandler(authValidator auth.LedgerValidator) gin.HandlerFunc {
logger := logrus.WithFields(logrus.Fields{
"scope": "websockets",
})

var count = *versions.CountTotal()

ticker := time.NewTicker(5 * time.Second)
go func() {
for {
select {
case <-ticker.C:
count = *versions.CountTotal()
logger.WithField("count", count).Trace("Refreshed version count")
}
}
}()

return func(c *gin.Context) {
conn, _, _, err := ws.UpgradeHTTP(c.Request, c.Writer)
if err != nil {
logger.Warn(err)
}
logger.Debug("Client connected")

go func() {
defer conn.Close()
var lastSend = count

msg, _, err := wsutil.ReadClientData(conn)
if err != nil {
logger.Error(err)
}
if err := authValidator.ValidateToken(string(msg)); err != nil {
logger.Error(err)
return
}

for {
if lastSend != count {
logger.Debug("Sending refreshVersions")
err = wsutil.WriteServerMessage(conn, ws.OpText, []byte("refreshVersions"))
if err != nil {
if _, ok := err.(wsutil.ClosedError); ok {
logger.Debug("Client disconnected")
} else {
logger.Error(err)
}
return
}
lastSend = count
}
time.Sleep(1 * time.Second)
}
}()
}
}
Loading

0 comments on commit 91603a9

Please sign in to comment.