Skip to content

Commit

Permalink
fix: replace RunContext and ProjectContext contextValues with c…
Browse files Browse the repository at this point in the history
…oncurrency safe `ContextValues` (infracost#2544)

This change introduces the ContextValues struct to both RunContext and ProjectContext, replacing the previously used map[string]interface{}. This change is introduced to enhance thread safety in our codebase by preventing concurrent read/write operations to these shared resources, which could lead to unpredictable behavior or data races.

Prior to this change the `ContextValues()` could hit a conccurent write panic because it was returning the actual map and not a copy of it, so when we read from the map it could potentially be written to in a separate goroutine.

To fix this `Values` now returns a copy of the underlying map, and I've introduced a `GetValue` method to safely read values.
  • Loading branch information
hugorut authored Jul 6, 2023
1 parent 5583a07 commit 48d8ba5
Show file tree
Hide file tree
Showing 17 changed files with 156 additions and 131 deletions.
2 changes: 1 addition & 1 deletion cmd/infracost/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func authLoginCmd(ctx *config.RunContext) *cobra.Command {
cmd.Println("We're redirecting you to our log in page, please complete that,\nand return here to continue using Infracost.")

auth := apiclient.AuthClient{Host: ctx.Config.DashboardEndpoint}
apiKey, info, err := auth.Login(ctx.ContextValues())
apiKey, info, err := auth.Login(ctx.ContextValues.Values())
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/infracost/breakdown.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func breakdownCmd(ctx *config.RunContext) *cobra.Command {
return err
}

ctx.SetContextValue("outputFormat", ctx.Config.Format)
ctx.ContextValues.SetValue("outputFormat", ctx.Config.Format)

err = checkRunConfig(cmd.ErrOrStderr(), ctx.Config)
if err != nil {
Expand Down
10 changes: 6 additions & 4 deletions cmd/infracost/comment.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import (
"context"
"errors"
"fmt"
"github.com/infracost/infracost/internal/apiclient"
log "github.com/sirupsen/logrus"
"os"
"strconv"

log "github.com/sirupsen/logrus"

"github.com/infracost/infracost/internal/apiclient"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -117,8 +119,8 @@ func buildCommentBody(cmd *cobra.Command, ctx *config.RunContext, paths []string
return nil, hasDiff, err
}

ctx.SetContextValue("passedPolicyCount", len(policyChecks.Passed))
ctx.SetContextValue("failedPolicyCount", len(policyChecks.Failures))
ctx.ContextValues.SetValue("passedPolicyCount", len(policyChecks.Passed))
ctx.ContextValues.SetValue("failedPolicyCount", len(policyChecks.Failures))
}

opts := output.Options{
Expand Down
6 changes: 3 additions & 3 deletions cmd/infracost/comment_azure_repos.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func commentAzureReposCmd(ctx *config.RunContext) *cobra.Command {
infracost comment azure-repos --repo-url https://dev.azure.com/my-org/my-project/_git/my-repo --pull-request 3 --path infracost.json --azure-access-token $AZURE_ACCESS_TOKEN`,
ValidArgs: []string{"--", "-"},
RunE: func(cmd *cobra.Command, args []string) error {
ctx.SetContextValue("platform", "azure-repos")
ctx.ContextValues.SetValue("platform", "azure-repos")

var err error

Expand All @@ -43,7 +43,7 @@ func commentAzureReposCmd(ctx *config.RunContext) *cobra.Command {

var commentHandler *comment.CommentHandler
if prNumber != 0 {
ctx.SetContextValue("targetType", "pull-request")
ctx.ContextValues.SetValue("targetType", "pull-request")

commentHandler, err = comment.NewAzureReposPRHandler(ctx.Context(), repoURL, strconv.Itoa(prNumber), extra)
if err != nil {
Expand All @@ -59,7 +59,7 @@ func commentAzureReposCmd(ctx *config.RunContext) *cobra.Command {
ui.PrintUsage(cmd)
return fmt.Errorf("--behavior only supports %s", strings.Join(validCommentAzureReposBehaviors, ", "))
}
ctx.SetContextValue("behavior", behavior)
ctx.ContextValues.SetValue("behavior", behavior)

paths, _ := cmd.Flags().GetStringArray("path")

Expand Down
8 changes: 4 additions & 4 deletions cmd/infracost/comment_bitbucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func commentBitbucketCmd(ctx *config.RunContext) *cobra.Command {
infracost comment bitbucket --repo my-org/my-repo --commit 2ca7182 --path infracost.json --behavior delete-and-new --bitbucket-token $BITBUCKET_TOKEN`,
ValidArgs: []string{"--", "-"},
RunE: func(cmd *cobra.Command, args []string) error {
ctx.SetContextValue("platform", "bitbucket")
ctx.ContextValues.SetValue("platform", "bitbucket")

var err error

Expand All @@ -52,14 +52,14 @@ func commentBitbucketCmd(ctx *config.RunContext) *cobra.Command {

var commentHandler *comment.CommentHandler
if prNumber != 0 {
ctx.SetContextValue("targetType", "pull-request")
ctx.ContextValues.SetValue("targetType", "pull-request")

commentHandler, err = comment.NewBitbucketPRHandler(ctx.Context(), repo, strconv.Itoa(prNumber), extra)
if err != nil {
return err
}
} else if commit != "" {
ctx.SetContextValue("targetType", "commit")
ctx.ContextValues.SetValue("targetType", "commit")

commentHandler, err = comment.NewBitbucketCommitHandler(ctx.Context(), repo, commit, extra)
if err != nil {
Expand All @@ -75,7 +75,7 @@ func commentBitbucketCmd(ctx *config.RunContext) *cobra.Command {
ui.PrintUsage(cmd)
return fmt.Errorf("--behavior only supports %s", strings.Join(validCommentBitbucketBehaviors, ", "))
}
ctx.SetContextValue("behavior", behavior)
ctx.ContextValues.SetValue("behavior", behavior)

paths, _ := cmd.Flags().GetStringArray("path")

Expand Down
8 changes: 4 additions & 4 deletions cmd/infracost/comment_github.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func commentGitHubCmd(ctx *config.RunContext) *cobra.Command {
infracost comment github --repo my-org/my-repo --commit 2ca7182 --path infracost.json --behavior hide-and-new --github-token $GITHUB_TOKEN`,
ValidArgs: []string{"--", "-"},
RunE: func(cmd *cobra.Command, args []string) error {
ctx.SetContextValue("platform", "github")
ctx.ContextValues.SetValue("platform", "github")

var err error

Expand Down Expand Up @@ -77,14 +77,14 @@ func commentGitHubCmd(ctx *config.RunContext) *cobra.Command {

var commentHandler *comment.CommentHandler
if prNumber != 0 {
ctx.SetContextValue("targetType", "pull-request")
ctx.ContextValues.SetValue("targetType", "pull-request")

commentHandler, err = comment.NewGitHubPRHandler(ctx.Context(), repo, strconv.Itoa(prNumber), extra)
if err != nil {
return err
}
} else if commit != "" {
ctx.SetContextValue("targetType", "commit")
ctx.ContextValues.SetValue("targetType", "commit")

commentHandler, err = comment.NewGitHubCommitHandler(ctx.Context(), repo, commit, extra)
if err != nil {
Expand All @@ -100,7 +100,7 @@ func commentGitHubCmd(ctx *config.RunContext) *cobra.Command {
ui.PrintUsage(cmd)
return fmt.Errorf("--behavior only supports %s", strings.Join(validCommentGitHubBehaviors, ", "))
}
ctx.SetContextValue("behavior", behavior)
ctx.ContextValues.SetValue("behavior", behavior)

paths, _ := cmd.Flags().GetStringArray("path")

Expand Down
8 changes: 4 additions & 4 deletions cmd/infracost/comment_gitlab.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func commentGitLabCmd(ctx *config.RunContext) *cobra.Command {
infracost comment gitlab --repo my-org/my-repo --commit 2ca7182 --path infracost.json --behavior delete-and-new --gitlab-token $GITLAB_TOKEN`,
ValidArgs: []string{"--", "-"},
RunE: func(cmd *cobra.Command, args []string) error {
ctx.SetContextValue("platform", "gitlab")
ctx.ContextValues.SetValue("platform", "gitlab")

var err error

Expand All @@ -50,14 +50,14 @@ func commentGitLabCmd(ctx *config.RunContext) *cobra.Command {

var commentHandler *comment.CommentHandler
if mrNumber != 0 {
ctx.SetContextValue("targetType", "merge-request")
ctx.ContextValues.SetValue("targetType", "merge-request")

commentHandler, err = comment.NewGitLabPRHandler(ctx.Context(), repo, strconv.Itoa(mrNumber), extra)
if err != nil {
return err
}
} else if commit != "" {
ctx.SetContextValue("targetType", "commit")
ctx.ContextValues.SetValue("targetType", "commit")

commentHandler, err = comment.NewGitLabCommitHandler(ctx.Context(), repo, commit, extra)
if err != nil {
Expand All @@ -73,7 +73,7 @@ func commentGitLabCmd(ctx *config.RunContext) *cobra.Command {
ui.PrintUsage(cmd)
return fmt.Errorf("--behavior only supports %s", strings.Join(validCommentGitLabBehaviors, ", "))
}
ctx.SetContextValue("behavior", behavior)
ctx.ContextValues.SetValue("behavior", behavior)

paths, _ := cmd.Flags().GetStringArray("path")

Expand Down
10 changes: 5 additions & 5 deletions cmd/infracost/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func newRootCmd(ctx *config.RunContext) *cobra.Command {
SilenceErrors: true,
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
cmd.SilenceUsage = true
ctx.SetContextValue("command", cmd.Name())
ctx.ContextValues.SetValue("command", cmd.Name())
ctx.CMD = cmd.Name()
if cmd.Name() == "comment" || (cmd.Parent() != nil && cmd.Parent().Name() == "comment") {
ctx.SetIsInfracostComment()
Expand Down Expand Up @@ -318,17 +318,17 @@ func loadGlobalFlags(ctx *config.RunContext, cmd *cobra.Command) error {
return err
}

ctx.SetContextValue("dashboardEnabled", ctx.Config.EnableDashboard)
ctx.SetContextValue("cloudEnabled", ctx.IsCloudEnabled())
ctx.SetContextValue("isDefaultPricingAPIEndpoint", ctx.Config.PricingAPIEndpoint == ctx.Config.DefaultPricingAPIEndpoint)
ctx.ContextValues.SetValue("dashboardEnabled", ctx.Config.EnableDashboard)
ctx.ContextValues.SetValue("cloudEnabled", ctx.IsCloudEnabled())
ctx.ContextValues.SetValue("isDefaultPricingAPIEndpoint", ctx.Config.PricingAPIEndpoint == ctx.Config.DefaultPricingAPIEndpoint)

flagNames := make([]string, 0)

cmd.Flags().Visit(func(f *pflag.Flag) {
flagNames = append(flagNames, f.Name)
})

ctx.SetContextValue("flags", flagNames)
ctx.ContextValues.SetValue("flags", flagNames)

return nil
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/infracost/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func outputCmd(ctx *config.RunContext) *cobra.Command {

format, _ := cmd.Flags().GetString("format")
format = strings.ToLower(format)
ctx.SetContextValue("outputFormat", format)
ctx.ContextValues.SetValue("outputFormat", format)

if format != "" && !contains(validOutputFormats, format) {
ui.PrintUsage(cmd)
Expand Down Expand Up @@ -200,7 +200,7 @@ func shareCombinedRun(ctx *config.RunContext, combined output.Root, inputs []out
combinedRunIds = append(combinedRunIds, id)
}
}
ctx.SetContextValue("runIds", combinedRunIds)
ctx.ContextValues.SetValue("runIds", combinedRunIds)

dashboardClient := apiclient.NewDashboardAPIClient(ctx)
result, err := dashboardClient.AddRun(ctx, combined)
Expand Down
18 changes: 9 additions & 9 deletions cmd/infracost/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func runMain(cmd *cobra.Command, runCtx *config.RunContext) error {

if format == "diff" || format == "table" {
lines := bytes.Count(b, []byte("\n")) + 1
runCtx.SetContextValue("lineCount", lines)
runCtx.ContextValues.SetValue("lineCount", lines)
}

env := buildRunEnv(runCtx, projectContexts, r)
Expand Down Expand Up @@ -223,7 +223,7 @@ func newParallelRunner(cmd *cobra.Command, runCtx *config.RunContext) (*parallel
if err != nil {
return nil, err
}
runCtx.SetContextValue("parallelism", parallelism)
runCtx.ContextValues.SetValue("parallelism", parallelism)

numJobs := len(runCtx.Config.Projects)

Expand Down Expand Up @@ -337,14 +337,14 @@ func (r *parallelRunner) runProjectConfig(ctx *config.ProjectContext) (*projectO
return nil, clierror.NewCLIError(errors.New(m), "Could not detect path type")
}

ctx.SetContextValue("projectType", provider.Type())
ctx.ContextValues.SetValue("projectType", provider.Type())

projectTypes := []interface{}{}
if t, ok := ctx.RunContext.ContextValues()["projectTypes"]; ok {
if t, ok := ctx.RunContext.ContextValues.GetValue("projectTypes"); ok {
projectTypes = t.([]interface{})
}
projectTypes = append(projectTypes, provider.Type())
ctx.RunContext.SetContextValue("projectTypes", projectTypes)
ctx.RunContext.ContextValues.SetValue("projectTypes", projectTypes)

if r.cmd.Name() == "diff" && provider.Type() == "terraform_state_json" {
m := "Cannot use Terraform state JSON with the infracost diff command.\n\n"
Expand Down Expand Up @@ -396,7 +396,7 @@ func (r *parallelRunner) runProjectConfig(ctx *config.ProjectContext) (*projectO
)
}

ctx.SetContextValue("hasUsageFile", true)
ctx.ContextValues.SetValue("hasUsageFile", true)
} else {
usageFile = usage.NewBlankUsageFile()
}
Expand Down Expand Up @@ -481,7 +481,7 @@ func (r *parallelRunner) runProjectConfig(ctx *config.ProjectContext) (*projectO

t2 := time.Now()
taken := t2.Sub(t1).Milliseconds()
ctx.SetContextValue("tfProjectRunTimeMs", taken)
ctx.ContextValues.SetValue("tfProjectRunTimeMs", taken)

spinner.Success()

Expand All @@ -503,7 +503,7 @@ func (r *parallelRunner) uploadCloudResourceIDs(projects []*schema.Project) erro
return nil
}

r.runCtx.SetContextValue("uploadedResourceIds", true)
r.runCtx.ContextValues.SetValue("uploadedResourceIds", true)

spinnerOpts := ui.SpinnerOptions{
EnableLogging: r.runCtx.Config.IsLogging(),
Expand Down Expand Up @@ -580,7 +580,7 @@ func (r *parallelRunner) fetchProjectUsage(projects []*schema.Project) map[*sche
logging.Logger.WithError(err).Debugf("failed to retrieve usage data for project %s", project.Name)
return nil
}
r.runCtx.SetContextValue("fetchedUsageData", true)
r.runCtx.ContextValues.SetValue("fetchedUsageData", true)
projectPtrToUsageMap[project] = usageMap
}

Expand Down
4 changes: 2 additions & 2 deletions internal/apiclient/dashboard.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func newRunInput(ctx *config.RunContext, out output.Root) (*runInput, error) {
}
}

ctxValues := ctx.ContextValues()
ctxValues := ctx.ContextValues.Values()

var metadata map[string]interface{}
b, err := json.Marshal(out.Metadata)
Expand All @@ -94,7 +94,7 @@ func newRunInput(ctx *config.RunContext, out output.Root) (*runInput, error) {
// Clone the map to cleanup up the "command" key to show "comment". It is
// currently set to the sub comment (e.g. "github")
ctxValues = make(map[string]interface{}, len(ctxValues))
for k, v := range ctx.ContextValues() {
for k, v := range ctx.ContextValues.Values() {
ctxValues[k] = v
}
ctxValues["command"] = "comment"
Expand Down
32 changes: 10 additions & 22 deletions internal/config/project_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type ProjectContext struct {
RunContext *RunContext
ProjectConfig *Project
logger *logrus.Entry
contextVals map[string]interface{}
ContextValues *ContextValues
mu *sync.RWMutex

UsingCache bool
Expand All @@ -35,42 +35,30 @@ func NewProjectContext(runCtx *RunContext, projectCfg *Project, fields logrus.Fi
RunContext: runCtx,
ProjectConfig: projectCfg,
logger: contextLogger,
contextVals: map[string]interface{}{},
ContextValues: NewContextValues(map[string]interface{}{}),
mu: &sync.RWMutex{},
}
}

func (p *ProjectContext) Logger() *logrus.Entry {
if p.logger == nil {
return logging.Logger.WithFields(p.logFields())
func (c *ProjectContext) Logger() *logrus.Entry {
if c.logger == nil {
return logging.Logger.WithFields(c.logFields())
}

return p.logger.WithFields(p.logFields())
return c.logger.WithFields(c.logFields())
}

func (p *ProjectContext) logFields() logrus.Fields {
func (c *ProjectContext) logFields() logrus.Fields {
return logrus.Fields{
"project_name": p.ProjectConfig.Name,
"project_path": p.ProjectConfig.Path,
"project_name": c.ProjectConfig.Name,
"project_path": c.ProjectConfig.Path,
}
}

func (c *ProjectContext) SetContextValue(key string, value interface{}) {
c.mu.Lock()
defer c.mu.Unlock()
c.contextVals[key] = value
}

func (c *ProjectContext) ContextValues() map[string]interface{} {
c.mu.RLock()
defer c.mu.RUnlock()
return c.contextVals
}

func (c *ProjectContext) SetFrom(d ProjectContexter) {
m := d.ProjectContext()
for k, v := range m {
c.SetContextValue(k, v)
c.ContextValues.SetValue(k, v)
}
}

Expand Down
Loading

0 comments on commit 48d8ba5

Please sign in to comment.