Skip to content

Commit

Permalink
chore: port to new graphql-go-tools Context API (#912)
Browse files Browse the repository at this point in the history
Fixes several race conditions too
  • Loading branch information
fiam authored May 5, 2023
1 parent 71302da commit 18c373a
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 33 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ require (
github.com/tidwall/gjson v1.11.0
github.com/tidwall/sjson v1.1.5
github.com/valyala/fasthttp v1.44.0
github.com/wundergraph/graphql-go-tools v1.62.3
github.com/wundergraph/graphql-go-tools v1.63.0
go.uber.org/zap v1.24.0
golang.org/x/net v0.7.0
golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -512,8 +512,8 @@ github.com/vmihailenco/msgpack/v5 v5.1.0 h1:+od5YbEXxW95SPlW6beocmt8nOtlh83zqat5
github.com/vmihailenco/msgpack/v5 v5.1.0/go.mod h1:C5gboKD0TJPqWDTVTtrQNfRbiBwHZGo8UTqP/9/XvLI=
github.com/vmihailenco/tagparser v0.1.2 h1:gnjoVuB/kljJ5wICEEOpx98oXMWPLj22G67Vbd1qPqc=
github.com/vmihailenco/tagparser v0.1.2/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI=
github.com/wundergraph/graphql-go-tools v1.62.3 h1:WKpkhKqWoTq/AE0AbrZ+66ezUe3l/wAqZZurZeL5qqQ=
github.com/wundergraph/graphql-go-tools v1.62.3/go.mod h1:Lsg/b4nVfNQLyJE1mjPV73O/JuhhCxH5qmaWQjitVHM=
github.com/wundergraph/graphql-go-tools v1.63.0 h1:iOFIlKitbzXAdncPhQJ3xQUO1WBbI6w5r+Wktwh2/90=
github.com/wundergraph/graphql-go-tools v1.63.0/go.mod h1:Lsg/b4nVfNQLyJE1mjPV73O/JuhhCxH5qmaWQjitVHM=
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c=
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0=
Expand Down
37 changes: 21 additions & 16 deletions pkg/apihandler/apihandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,7 @@ func (h *GraphQLHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
flushWriter *httpFlushWriter
ok bool
)
shared.Ctx.Context, flushWriter, ok = getFlushWriter(shared.Ctx.Context, shared.Ctx.Variables, r, w)
shared.Ctx, flushWriter, ok = getFlushWriter(shared.Ctx, shared.Ctx.Variables, r, w)
if !ok {
requestLogger.Error("connection not flushable")
http.Error(w, "Connection not flushable", http.StatusBadRequest)
Expand Down Expand Up @@ -1250,7 +1250,7 @@ func (h *QueryHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer func() {
if cacheIsStale {
buf.Reset()
ctx.Context = context.WithValue(context.Background(), "user", authentication.UserFromContext(r.Context()))
ctx = ctx.WithContext(context.WithValue(context.Background(), "user", authentication.UserFromContext(r.Context())))
err := h.resolver.ResolveGraphQLResponse(ctx, h.preparedPlan.Response, nil, buf)
if err == nil {
bufferedData := buf.Bytes()
Expand All @@ -1267,7 +1267,7 @@ func (h *QueryHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx.Variables = parseQueryVariables(r, h.queryParamsAllowList)
ctx.Variables = h.stringInterpolator.Interpolate(ctx.Variables)

if !validateInputVariables(ctx, requestLogger, ctx.Variables, h.variablesValidator, w) {
if !validateInputVariables(ctx.Context(), requestLogger, ctx.Variables, h.variablesValidator, w) {
return
}

Expand Down Expand Up @@ -1315,7 +1315,7 @@ func (h *QueryHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

if h.cacheConfig.enable {
cacheKey = string(h.configHash) + r.RequestURI
item, hit := h.cache.Get(ctx.Context, cacheKey)
item, hit := h.cache.Get(ctx.Context(), cacheKey)
if hit {

w.Header().Set(WgCacheHeader, "HIT")
Expand Down Expand Up @@ -1421,7 +1421,7 @@ func (h *QueryHandler) handleLiveQueryEvent(ctx *resolve.Context, w http.Respons
func (h *QueryHandler) handleLiveQuery(r *http.Request, w http.ResponseWriter, ctx *resolve.Context, requestBuf *bytes.Buffer, flusher http.Flusher, requestLogger *zap.Logger) {
wgParams := NewWgRequestParams(r)

done := ctx.Context.Done()
done := ctx.Context().Done()

hookBuf := pool.GetBytesBuffer()
defer pool.PutBytesBuffer(hookBuf)
Expand Down Expand Up @@ -1612,7 +1612,7 @@ func (h *MutationHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
ctx.Variables = h.stringInterpolator.Interpolate(ctx.Variables)

if !validateInputVariables(ctx, requestLogger, ctx.Variables, h.variablesValidator, w) {
if !validateInputVariables(ctx.Context(), requestLogger, ctx.Variables, h.variablesValidator, w) {
return
}

Expand Down Expand Up @@ -1701,7 +1701,7 @@ func (h *SubscriptionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
ctx.Variables = parseQueryVariables(r, h.queryParamsAllowList)
ctx.Variables = h.stringInterpolator.Interpolate(ctx.Variables)

if !validateInputVariables(ctx, requestLogger, ctx.Variables, h.variablesValidator, w) {
if !validateInputVariables(ctx.Context(), requestLogger, ctx.Variables, h.variablesValidator, w) {
return
}

Expand Down Expand Up @@ -2243,7 +2243,7 @@ func (h *FunctionsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
ctx.Variables = variablesBuf.Bytes()

if !validateInputVariables(ctx, requestLogger, ctx.Variables, h.variablesValidator, w) {
if !validateInputVariables(ctx.Context(), requestLogger, ctx.Variables, h.variablesValidator, w) {
return
}

Expand All @@ -2267,7 +2267,7 @@ func (h *FunctionsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}

func (h *FunctionsHandler) handleLiveQuery(ctx context.Context, w http.ResponseWriter, r *http.Request, input []byte, requestLogger *zap.Logger) {
func (h *FunctionsHandler) handleLiveQuery(resolveCtx *resolve.Context, w http.ResponseWriter, r *http.Request, input []byte, requestLogger *zap.Logger) {

var (
err error
Expand All @@ -2276,7 +2276,7 @@ func (h *FunctionsHandler) handleLiveQuery(ctx context.Context, w http.ResponseW
out *hooks.MiddlewareHookResponse
)

ctx, fw, ok = getFlushWriter(ctx, input, r, w)
resolveCtx, fw, ok = getFlushWriter(resolveCtx, input, r, w)
if !ok {
requestLogger.Error("request doesn't support flushing")
w.WriteHeader(http.StatusBadRequest)
Expand All @@ -2292,6 +2292,7 @@ func (h *FunctionsHandler) handleLiveQuery(ctx context.Context, w http.ResponseW

defer fw.Close()

ctx := resolveCtx.Context()
for {
select {
case <-ctx.Done():
Expand Down Expand Up @@ -2321,11 +2322,12 @@ func (h *FunctionsHandler) handleLiveQuery(ctx context.Context, w http.ResponseW
}
}

func (h *FunctionsHandler) handleRequest(ctx context.Context, w http.ResponseWriter, input []byte, requestLogger *zap.Logger) {
func (h *FunctionsHandler) handleRequest(resolveCtx *resolve.Context, w http.ResponseWriter, input []byte, requestLogger *zap.Logger) {

buf := pool.GetBytesBuffer()
defer pool.PutBytesBuffer(buf)

ctx := resolveCtx.Context()
out, err := h.hooksClient.DoFunctionRequest(ctx, h.operation.Path, input, buf)
if err != nil {
if ctx.Err() != nil {
Expand All @@ -2344,12 +2346,13 @@ func (h *FunctionsHandler) handleRequest(ctx context.Context, w http.ResponseWri
}
}

func (h *FunctionsHandler) handleSubscriptionRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, input []byte, requestLogger *zap.Logger) {
func (h *FunctionsHandler) handleSubscriptionRequest(resolveCtx *resolve.Context, w http.ResponseWriter, r *http.Request, input []byte, requestLogger *zap.Logger) {
wgParams := NewWgRequestParams(r)

setSubscriptionHeaders(w)
buf := pool.GetBytesBuffer()
defer pool.PutBytesBuffer(buf)
ctx := resolveCtx.Context()
err := h.hooksClient.DoFunctionSubscriptionRequest(ctx, h.operation.Path, input, wgParams.SubsribeOnce, wgParams.UseSse, wgParams.UseJsonPatch, w, buf)
if err != nil {
if ctx.Err() != nil {
Expand Down Expand Up @@ -2443,7 +2446,7 @@ func setSubscriptionHeaders(w http.ResponseWriter) {
func getHooksFlushWriter(ctx *resolve.Context, r *http.Request, w http.ResponseWriter, pipeline *hooks.SubscriptionOperationPipeline, logger *zap.Logger) (*httpFlushWriter, bool) {
var flushWriter *httpFlushWriter
var ok bool
ctx.Context, flushWriter, ok = getFlushWriter(ctx.Context, ctx.Variables, r, w)
ctx, flushWriter, ok = getFlushWriter(ctx, ctx.Variables, r, w)
if !ok {
return nil, false
}
Expand All @@ -2455,7 +2458,7 @@ func getHooksFlushWriter(ctx *resolve.Context, r *http.Request, w http.ResponseW
return flushWriter, true
}

func getFlushWriter(ctx context.Context, variables []byte, r *http.Request, w http.ResponseWriter) (context.Context, *httpFlushWriter, bool) {
func getFlushWriter(ctx *resolve.Context, variables []byte, r *http.Request, w http.ResponseWriter) (*resolve.Context, *httpFlushWriter, bool) {
wgParams := NewWgRequestParams(r)

flusher, ok := w.(http.Flusher)
Expand All @@ -2476,13 +2479,15 @@ func getFlushWriter(ctx context.Context, variables []byte, r *http.Request, w ht
useJsonPatch: wgParams.UseJsonPatch,
buf: &bytes.Buffer{},
lastMessage: &bytes.Buffer{},
ctx: ctx,
ctx: ctx.Context(),
variables: variables,
}

if wgParams.SubsribeOnce {
flushWriter.subscribeOnce = true
ctx, flushWriter.close = context.WithCancel(ctx)
var cancellableCtx context.Context
cancellableCtx, flushWriter.close = context.WithCancel(ctx.Context())
ctx = ctx.WithContext(cancellableCtx)
}

return ctx, flushWriter, true
Expand Down
6 changes: 3 additions & 3 deletions pkg/datasources/database/datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -1093,21 +1093,21 @@ func (s *Source) Load(ctx context.Context, input []byte, w io.Writer) (err error
request = s.unNullRequest(request)
buf := pool.GetBytesBuffer()
defer pool.PutBytesBuffer(buf)
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
cancellableCtx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
for {
s.log.Debug("database.Source.Execute",
zap.ByteString("request", request),
)

err = s.engine.Execute(ctx, request, buf)
err = s.engine.Execute(cancellableCtx, request, buf)
if err != nil {
s.log.Debug("database.Source.Execute.Error",
zap.ByteString("request", request),
zap.Error(err),
)

if ctx.Err() != nil {
if cancellableCtx.Err() != nil {
s.log.Debug("database.Source.Execute.Deadline Exceeded")

return err
Expand Down
14 changes: 7 additions & 7 deletions pkg/hooks/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func (p *pipeline) updateContextHeaders(ctx *resolve.Context, headers map[string
httpHeader.Set(name, headers[name])
}
ctx.Request.Header = httpHeader
clientRequest := ctx.Value(pool.ClientRequestKey)
clientRequest := ctx.Context().Value(pool.ClientRequestKey)
if clientRequest == nil {
return
}
Expand Down Expand Up @@ -173,7 +173,7 @@ func (p *pipeline) PreResolve(ctx *resolve.Context, w http.ResponseWriter, r *ht
if err != nil {
return nil, err
}
resp, err = p.client.DoOperationRequest(ctx.Context, p.operation.Name, PreResolve, hookData, payloadBuf)
resp, err = p.client.DoOperationRequest(ctx.Context(), p.operation.Name, PreResolve, hookData, payloadBuf)
if err != nil {
return nil, err
}
Expand All @@ -189,7 +189,7 @@ func (p *pipeline) PreResolve(ctx *resolve.Context, w http.ResponseWriter, r *ht
if err != nil {
return nil, err
}
resp, err = p.client.DoOperationRequest(ctx.Context, p.operation.Name, MutatingPreResolve, hookData, payloadBuf)
resp, err = p.client.DoOperationRequest(ctx.Context(), p.operation.Name, MutatingPreResolve, hookData, payloadBuf)
if err != nil {
return nil, err
}
Expand All @@ -208,7 +208,7 @@ func (p *pipeline) PreResolve(ctx *resolve.Context, w http.ResponseWriter, r *ht
if err != nil {
return nil, err
}
resp, err := p.client.DoOperationRequest(ctx.Context, p.operation.Name, MockResolve, hookData, payloadBuf)
resp, err := p.client.DoOperationRequest(ctx.Context(), p.operation.Name, MockResolve, hookData, payloadBuf)
if err != nil {
return nil, err
}
Expand All @@ -229,7 +229,7 @@ func (p *pipeline) PreResolve(ctx *resolve.Context, w http.ResponseWriter, r *ht
if err != nil {
return nil, err
}
resp, err = p.client.DoOperationRequest(ctx.Context, p.operation.Name, CustomResolve, hookData, payloadBuf)
resp, err = p.client.DoOperationRequest(ctx.Context(), p.operation.Name, CustomResolve, hookData, payloadBuf)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -267,7 +267,7 @@ func (p *pipeline) PostResolve(ctx *resolve.Context, w http.ResponseWriter, r *h
if err != nil {
return nil, err
}
resp, err := p.client.DoOperationRequest(ctx.Context, p.operation.Name, PostResolve, postResolveData, payloadBuf)
resp, err := p.client.DoOperationRequest(ctx.Context(), p.operation.Name, PostResolve, postResolveData, payloadBuf)
if err != nil {
return nil, err
}
Expand All @@ -283,7 +283,7 @@ func (p *pipeline) PostResolve(ctx *resolve.Context, w http.ResponseWriter, r *h
if err != nil {
return nil, err
}
resp, err := p.client.DoOperationRequest(ctx.Context, p.operation.Name, MutatingPostResolve, mutatingPostData, payloadBuf)
resp, err := p.client.DoOperationRequest(ctx.Context(), p.operation.Name, MutatingPostResolve, mutatingPostData, payloadBuf)
if err != nil {
return nil, err
}
Expand Down
7 changes: 3 additions & 4 deletions pkg/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ func GetCtx(r, clientRequest *http.Request, cfg Config) *resolve.Context {
resolveCtx.RenameTypeNames = cfg.RenameTypeNames
return resolveCtx
}
resolveContext := next.(*resolve.Context)
resolveContext.Context = ctx
resolveContext := next.(*resolve.Context).WithContext(ctx)
resolveContext.Request.Header = r.Header
resolveContext.RenameTypeNames = cfg.RenameTypeNames
return resolveContext
Expand Down Expand Up @@ -97,7 +96,7 @@ func (p *Pool) GetShared(ctx context.Context, planConfig plan.Configuration, cfg
if shared != nil {
s := shared.(*Shared)
s.Planner.SetConfig(planConfig)
s.Ctx.Context = ctx
s.Ctx = s.Ctx.WithContext(ctx)
s.Ctx.RenameTypeNames = cfg.RenameTypeNames
return s
}
Expand All @@ -123,7 +122,7 @@ func (p *Pool) GetSharedFromRequest(ctx context.Context, r *http.Request, planCo
if shared != nil {
s := shared.(*Shared)
s.Planner.SetConfig(planConfig)
s.Ctx.Context = c
s.Ctx = s.Ctx.WithContext(ctx)
s.Ctx.Request.Header = r.Header
s.Ctx.RenameTypeNames = cfg.RenameTypeNames
return s
Expand Down

0 comments on commit 18c373a

Please sign in to comment.