Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enrich policy input #540

Merged
merged 7 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions act/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

type IRegistry interface {
Add(policy *sdkAct.Policy)
Apply(signals []sdkAct.Signal) []*sdkAct.Output
Apply(signals []sdkAct.Signal, hook sdkAct.Hook) []*sdkAct.Output
Run(output *sdkAct.Output, params ...sdkAct.Parameter) (any, *gerr.GatewayDError)
}

Expand Down Expand Up @@ -107,11 +107,11 @@ func (r *Registry) Add(policy *sdkAct.Policy) {
}

// Apply applies the signals to the registry and returns the outputs.
func (r *Registry) Apply(signals []sdkAct.Signal) []*sdkAct.Output {
func (r *Registry) Apply(signals []sdkAct.Signal, hook sdkAct.Hook) []*sdkAct.Output {
// If there are no signals, apply the default policy.
if len(signals) == 0 {
r.Logger.Debug().Msg("No signals provided, applying default signal")
return r.Apply([]sdkAct.Signal{*r.DefaultSignal})
return r.Apply([]sdkAct.Signal{*r.DefaultSignal}, hook)
}

// Separate terminal and non-terminal signals to find contradictions.
Expand Down Expand Up @@ -139,7 +139,7 @@ func (r *Registry) Apply(signals []sdkAct.Signal) []*sdkAct.Output {
}

// Apply the signal and append the output to the list of outputs.
output, err := r.apply(signal)
output, err := r.apply(signal, hook)
if err != nil {
r.Logger.Error().Err(err).Str("name", signal.Name).Msg("Error applying signal")
// If there is an error evaluating the policy, continue to the next signal.
Expand All @@ -155,14 +155,16 @@ func (r *Registry) Apply(signals []sdkAct.Signal) []*sdkAct.Output {
}

if len(outputs) == 0 && !evalErr {
return r.Apply([]sdkAct.Signal{*r.DefaultSignal})
return r.Apply([]sdkAct.Signal{*r.DefaultSignal}, hook)
}

return outputs
}

// apply applies the signal to the registry and returns the output.
func (r *Registry) apply(signal sdkAct.Signal) (*sdkAct.Output, *gerr.GatewayDError) {
func (r *Registry) apply(
signal sdkAct.Signal, hook sdkAct.Hook,
) (*sdkAct.Output, *gerr.GatewayDError) {
action, exists := r.Actions[signal.Name]
if !exists {
return nil, gerr.ErrActionNotMatched
Expand All @@ -178,12 +180,12 @@ func (r *Registry) apply(signal sdkAct.Signal) (*sdkAct.Output, *gerr.GatewayDEr
defer cancel()

// Evaluate the policy.
// TODO: Policy should be able to receive other parameters like server and client IPs, etc.
verdict, err := policy.Eval(
ctx, sdkAct.Input{
Name: signal.Name,
Policy: policy.Metadata,
Signal: signal.Metadata,
Hook: hook,
// Action dictates the sync mode, not the signal.
Sync: action.Sync,
},
Expand Down
71 changes: 65 additions & 6 deletions act/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,17 @@ func Test_Apply(t *testing.T) {
})
assert.NotNil(t, actRegistry)

outputs := actRegistry.Apply([]sdkAct.Signal{
*sdkAct.Passthrough(),
})
outputs := actRegistry.Apply(
[]sdkAct.Signal{
*sdkAct.Passthrough(),
},
sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
},
)
assert.NotNil(t, outputs)
assert.Len(t, outputs, 1)
assert.Equal(t, "passthrough", outputs[0].MatchedPolicy)
Expand All @@ -225,7 +233,15 @@ func Test_Apply_NoSignals(t *testing.T) {
})
assert.NotNil(t, actRegistry)

outputs := actRegistry.Apply([]sdkAct.Signal{})
outputs := actRegistry.Apply(
[]sdkAct.Signal{},
sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
},
)
assert.NotNil(t, outputs)
assert.Len(t, outputs, 1)
assert.Equal(t, "passthrough", outputs[0].MatchedPolicy)
Expand Down Expand Up @@ -272,7 +288,12 @@ func Test_Apply_ContradictorySignals(t *testing.T) {
assert.NotNil(t, actRegistry)

for _, s := range signals {
outputs := actRegistry.Apply(s)
outputs := actRegistry.Apply(s, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)
assert.Len(t, outputs, 2)
assert.Equal(t, "terminate", outputs[0].MatchedPolicy)
Expand Down Expand Up @@ -318,6 +339,11 @@ func Test_Apply_ActionNotMatched(t *testing.T) {

outputs := actRegistry.Apply([]sdkAct.Signal{
{Name: "non-existent"},
}, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)
assert.Len(t, outputs, 1)
Expand Down Expand Up @@ -351,6 +377,11 @@ func Test_Apply_PolicyNotMatched(t *testing.T) {

outputs := actRegistry.Apply([]sdkAct.Signal{
*sdkAct.Terminate(),
}, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)
assert.Len(t, outputs, 1)
Expand Down Expand Up @@ -399,6 +430,11 @@ func Test_Apply_NonBoolPolicy(t *testing.T) {

outputs := actRegistry.Apply([]sdkAct.Signal{
*sdkAct.Passthrough(),
}, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)
assert.Len(t, outputs, 1)
Expand Down Expand Up @@ -464,6 +500,11 @@ func Test_Run(t *testing.T) {

outputs := actRegistry.Apply([]sdkAct.Signal{
*sdkAct.Passthrough(),
}, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)

Expand All @@ -489,6 +530,11 @@ func Test_Run_Terminate(t *testing.T) {

outputs := actRegistry.Apply([]sdkAct.Signal{
*sdkAct.Terminate(),
}, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)
assert.Equal(t, "terminate", outputs[0].MatchedPolicy)
Expand Down Expand Up @@ -522,6 +568,11 @@ func Test_Run_Async(t *testing.T) {

outputs := actRegistry.Apply([]sdkAct.Signal{
*sdkAct.Log("info", "test", map[string]any{"async": true}),
}, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)
assert.Equal(t, "log", outputs[0].MatchedPolicy)
Expand Down Expand Up @@ -647,7 +698,15 @@ func Test_Run_Timeout(t *testing.T) {
})
assert.NotNil(t, actRegistry)

outputs := actRegistry.Apply([]sdkAct.Signal{*signals[name]})
outputs := actRegistry.Apply(
[]sdkAct.Signal{*signals[name]},
sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
},
)
assert.NotNil(t, outputs)
assert.Equal(t, name, outputs[0].MatchedPolicy)
assert.Equal(t,
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ require (
github.com/codingsince1985/checksum v1.3.0
github.com/cybercyst/go-scaffold v0.0.0-20240404115540-744e601147cd
github.com/envoyproxy/protoc-gen-validate v1.0.4
github.com/gatewayd-io/gatewayd-plugin-sdk v0.2.13
github.com/gatewayd-io/gatewayd-plugin-sdk v0.2.14
github.com/getsentry/sentry-go v0.27.0
github.com/go-co-op/gocron v1.37.0
github.com/google/go-github/v53 v53.2.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 21 additions & 17 deletions network/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -856,31 +856,35 @@ func (pr *Proxy) shouldTerminate(result map[string]interface{}) (bool, map[strin
// The Terminal field is only present if the action wants to terminate the request,
// that is the `__terminal__` field is set in one of the outputs.
keys := maps.Keys(result)
if slices.Contains(keys, sdkAct.Terminal) {
var actionResult map[string]interface{}
for _, output := range outputs {
actRes, err := pr.PluginRegistry.ActRegistry.Run(
output, act.WithResult(result))
// If the action is async and we received a sentinel error,
// don't log the error.
if err != nil && !errors.Is(err, gerr.ErrAsyncAction) {
pr.Logger.Error().Err(err).Msg("Error running policy")
}
// The terminate action should return a map.
if v, ok := actRes.(map[string]interface{}); ok {
actionResult = v
}
terminate := slices.Contains(keys, sdkAct.Terminal) && cast.ToBool(result[sdkAct.Terminal])
actionResult := make(map[string]interface{})
for _, output := range outputs {
if !cast.ToBool(output.Verdict) {
pr.Logger.Debug().Msg(
"Skipping the action, because the verdict of the policy execution is false")
continue
}
actRes, err := pr.PluginRegistry.ActRegistry.Run(
output, act.WithResult(result))
// If the action is async and we received a sentinel error,
// don't log the error.
if err != nil && !errors.Is(err, gerr.ErrAsyncAction) {
pr.Logger.Error().Err(err).Msg("Error running policy")
}
// The terminate action should return a map.
if v, ok := actRes.(map[string]interface{}); ok {
actionResult = v
}
}
if terminate {
pr.Logger.Debug().Fields(
map[string]interface{}{
"function": "proxy.passthrough",
"reason": "terminate",
},
).Msg("Terminating request")
return cast.ToBool(result[sdkAct.Terminal]), actionResult
}

return false, result
return terminate, actionResult
}

// getPluginModifiedRequest is a function that retrieves the modified request
Expand Down
20 changes: 14 additions & 6 deletions plugin/plugin_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
goplugin "github.com/hashicorp/go-plugin"
"github.com/mitchellh/mapstructure"
"github.com/rs/zerolog"
"github.com/spf13/cast"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"google.golang.org/grpc"
Expand Down Expand Up @@ -49,7 +50,7 @@ type IRegistry interface {
Shutdown()
LoadPlugins(ctx context.Context, plugins []config.Plugin, startTimeout time.Duration)
RegisterHooks(ctx context.Context, pluginID sdkPlugin.Identifier)
Apply(hookName string, result *v1.Struct) ([]*sdkAct.Output, bool)
Apply(hook sdkAct.Hook) ([]*sdkAct.Output, bool)

// Hook management
IHook
Expand Down Expand Up @@ -329,7 +330,14 @@ func (reg *Registry) Run(
continue
}

out, terminal := reg.Apply(hookName.String(), result)
out, terminal := reg.Apply(
sdkAct.Hook{
Name: hookName.String(),
Priority: uint(priority),
Params: params.AsMap(),
Result: result.AsMap(),
},
)
outputs = append(outputs, out...)

if terminal {
Expand All @@ -352,16 +360,16 @@ func (reg *Registry) Run(
}

// Apply applies policies to the result.
func (reg *Registry) Apply(hookName string, result *v1.Struct) ([]*sdkAct.Output, bool) {
func (reg *Registry) Apply(hook sdkAct.Hook) ([]*sdkAct.Output, bool) {
_, span := otel.Tracer(config.TracerName).Start(reg.ctx, "Apply")
defer span.End()

// Get signals from the result.
signals := getSignals(result.AsMap())
signals := getSignals(hook.Result)
// Apply policies to the signals.
// The outputs contain the verdicts of the policies and their metadata.
// And using this list, the caller can take further actions.
outputs := applyPolicies(hookName, signals, reg.Logger, reg.ActRegistry)
outputs := applyPolicies(hook, signals, reg.Logger, reg.ActRegistry)

// If no policies are found, return a default output.
// Note: this should never happen, as the default policy is always loaded.
Expand All @@ -373,7 +381,7 @@ func (reg *Registry) Apply(hookName string, result *v1.Struct) ([]*sdkAct.Output
// Check if any of the policies have a terminal action.
var terminal bool
for _, output := range outputs {
if output.Verdict != nil && output.Terminal {
if output.Verdict != nil && cast.ToBool(output.Verdict) && output.Terminal {
terminal = true
break
}
Expand Down
11 changes: 7 additions & 4 deletions plugin/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,10 @@ func getSignals(result map[string]any) []sdkAct.Signal {

// applyPolicies applies the policies to the signals and returns the outputs.
func applyPolicies(
hookName string, signals []sdkAct.Signal, logger zerolog.Logger, reg act.IRegistry,
hook sdkAct.Hook,
signals []sdkAct.Signal,
logger zerolog.Logger,
reg act.IRegistry,
) []*sdkAct.Output {
signalNames := []string{}
for _, signal := range signals {
Expand All @@ -85,15 +88,15 @@ func applyPolicies(

logger.Debug().Fields(
map[string]interface{}{
"hook": hookName,
"hook": hook.Name,
"signals": signalNames,
},
).Msg("Detected signals from the plugin hook")

outputs := reg.Apply(signals)
outputs := reg.Apply(signals, hook)
logger.Debug().Fields(
map[string]interface{}{
"hook": hookName,
"hook": hook.Name,
"outputs": outputs,
},
).Msg("Applied policies to signals")
Expand Down
11 changes: 10 additions & 1 deletion plugin/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,16 @@ func Test_applyPolicies(t *testing.T) {
})

output := applyPolicies(
"onTrafficFromClient", []sdkAct.Signal{*sdkAct.Passthrough()}, logger, actRegistry)
sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
},
[]sdkAct.Signal{*sdkAct.Passthrough()},
logger,
actRegistry,
)
assert.Len(t, output, 1)
assert.Equal(t, "passthrough", output[0].MatchedPolicy)
assert.Nil(t, output[0].Metadata)
Expand Down