diff --git a/.golangci.yml b/.golangci.yml index 079e952252ba..374c9204ed1b 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -25,3 +25,9 @@ run: skip-files: - ".+_test.go" - "corpx/faker.go" + +issues: + exclude: + - "Set is deprecated: use context-based WithConfigValue instead" + - "SetDefaultIdentitySchemaFromRaw is deprecated: Use context-based WithDefaultIdentitySchemaFromRaw instead" + - "SetDefaultIdentitySchema is deprecated: Use context-based WithDefaultIdentitySchema instead" diff --git a/cipher/cipher_test.go b/cipher/cipher_test.go index 8cdb0ed0e2ac..eb8ba7e1ba7b 100644 --- a/cipher/cipher_test.go +++ b/cipher/cipher_test.go @@ -9,6 +9,8 @@ import ( "fmt" "testing" + "github.com/ory/x/configx" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -18,10 +20,11 @@ import ( "github.com/ory/kratos/internal" ) +var goodSecret = []string{"secret-thirty-two-character-long"} + func TestCipher(t *testing.T) { ctx := context.Background() - cfg, reg := internal.NewFastRegistryWithMocks(t) - goodSecret := []string{"secret-thirty-two-character-long"} + _, reg := internal.NewFastRegistryWithMocks(t, configx.WithValue(config.ViperKeySecretsDefault, goodSecret)) ciphers := []cipher.Cipher{ cipher.NewCryptAES(reg), @@ -30,82 +33,71 @@ func TestCipher(t *testing.T) { for _, c := range ciphers { t.Run(fmt.Sprintf("cipher=%T", c), func(t *testing.T) { + t.Parallel() t.Run("case=all_work", func(t *testing.T) { - cfg.MustSet(ctx, config.ViperKeySecretsCipher, goodSecret) - testAllWork(t, c, cfg) + t.Parallel() + + testAllWork(ctx, t, c) }) t.Run("case=encryption_failed", func(t *testing.T) { - // unset secret - err := cfg.Set(ctx, config.ViperKeySecretsCipher, []string{}) - require.NoError(t, err) + t.Parallel() + + ctx := config.WithConfigValue(ctx, config.ViperKeySecretsCipher, []string{""}) // secret have to be set - _, err = c.Encrypt(context.Background(), []byte("not-empty")) + _, err := c.Encrypt(ctx, []byte("not-empty")) require.Error(t, err) + var hErr *herodot.DefaultError + require.ErrorAs(t, err, &hErr) + assert.Equal(t, "Unable to encrypt message because no cipher secrets were configured.", hErr.Reason()) - // unset secret - err = cfg.Set(ctx, config.ViperKeySecretsCipher, []string{"bad-length"}) - require.NoError(t, err) + ctx = config.WithConfigValue(ctx, config.ViperKeySecretsCipher, []string{"bad-length"}) // bad secret length - _, err = c.Encrypt(context.Background(), []byte("not-empty")) - if e, ok := err.(*herodot.DefaultError); ok { - t.Logf("reason contains: %s", e.Reason()) - } - t.Logf("err type %T contains: %s", err, err.Error()) - require.Error(t, err) + _, err = c.Encrypt(ctx, []byte("not-empty")) + require.ErrorAs(t, err, &hErr) + assert.Equal(t, "Unable to encrypt message because no cipher secrets were configured.", hErr.Reason()) }) t.Run("case=decryption_failed", func(t *testing.T) { - // set secret - err := cfg.Set(ctx, config.ViperKeySecretsCipher, goodSecret) - require.NoError(t, err) + t.Parallel() - // - _, err = c.Decrypt(context.Background(), hex.EncodeToString([]byte("bad-data"))) + _, err := c.Decrypt(ctx, hex.EncodeToString([]byte("bad-data"))) require.Error(t, err) - _, err = c.Decrypt(context.Background(), "not-empty") + _, err = c.Decrypt(ctx, "not-empty") require.Error(t, err) - // unset secret - err = cfg.Set(ctx, config.ViperKeySecretsCipher, []string{}) - require.NoError(t, err) - - _, err = c.Decrypt(context.Background(), "not-empty") + _, err = c.Decrypt(config.WithConfigValue(ctx, config.ViperKeySecretsCipher, []string{""}), "not-empty") require.Error(t, err) }) }) } + c := cipher.NewNoop(reg) t.Run(fmt.Sprintf("cipher=%T", c), func(t *testing.T) { - cfg.MustSet(ctx, config.ViperKeySecretsCipher, goodSecret) - testAllWork(t, c, cfg) + t.Parallel() + testAllWork(ctx, t, c) }) } -func testAllWork(t *testing.T, c cipher.Cipher, cfg *config.Config) { - ctx := context.Background() - - goodSecret := []string{"secret-thirty-two-character-long"} - cfg.MustSet(ctx, config.ViperKeySecretsCipher, goodSecret) - +func testAllWork(ctx context.Context, t *testing.T, c cipher.Cipher) { message := "my secret message!" - encryptedSecret, err := c.Encrypt(context.Background(), []byte(message)) + encryptedSecret, err := c.Encrypt(ctx, []byte(message)) require.NoError(t, err) - decryptedSecret, err := c.Decrypt(context.Background(), encryptedSecret) + decryptedSecret, err := c.Decrypt(ctx, encryptedSecret) require.NoError(t, err, "encrypted", encryptedSecret) assert.Equal(t, message, string(decryptedSecret)) // data to encrypt return blank result - _, err = c.Encrypt(context.Background(), []byte("")) + _, err = c.Encrypt(ctx, []byte("")) require.NoError(t, err) // empty encrypted data return blank - _, err = c.Decrypt(context.Background(), "") + _, err = c.Decrypt(ctx, "") require.NoError(t, err) } diff --git a/cmd/courier/watch_test.go b/cmd/courier/watch_test.go index 48fd6515f53b..b521e9119a97 100644 --- a/cmd/courier/watch_test.go +++ b/cmd/courier/watch_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/require" "github.com/ory/kratos/internal" + "github.com/ory/x/configx" ) func TestStartCourier(t *testing.T) { @@ -27,10 +28,9 @@ func TestStartCourier(t *testing.T) { t.Run("case=with metrics", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - _, r := internal.NewFastRegistryWithMocks(t) port, err := freeport.GetFreePort() require.NoError(t, err) - r.Config().Set(ctx, "expose-metrics-port", port) + _, r := internal.NewFastRegistryWithMocks(t, configx.WithValue("expose-metrics-port", port)) go StartCourier(ctx, r) time.Sleep(time.Second) res, err := http.Get("http://" + r.Config().MetricsListenOn(ctx) + "/metrics/prometheus") diff --git a/cmd/hashers/argon2/root.go b/cmd/hashers/argon2/root.go index c5cb76581590..2282f0404d4a 100644 --- a/cmd/hashers/argon2/root.go +++ b/cmd/hashers/argon2/root.go @@ -9,6 +9,8 @@ import ( "reflect" "strings" + "github.com/ory/x/contextx" + "github.com/spf13/cobra" "github.com/spf13/pflag" @@ -70,6 +72,7 @@ func configProvider(cmd *cobra.Command, flagConf *argon2Config) (*argon2Config, cmd.Context(), l, cmd.ErrOrStderr(), + &contextx.Default{}, configx.WithFlags(cmd.Flags()), configx.SkipValidation(), configx.WithContext(cmd.Context()), diff --git a/courier/template/load_template_test.go b/courier/template/load_template_test.go index e6b043aa0c54..1fd245497ca9 100644 --- a/courier/template/load_template_test.go +++ b/courier/template/load_template_test.go @@ -182,7 +182,7 @@ func TestLoadTextTemplate(t *testing.T) { }) t.Run("case=disallowed resources", func(t *testing.T) { - require.NoError(t, reg.Config().GetProvider(ctx).Set(config.ViperKeyClientHTTPNoPrivateIPRanges, true)) + require.NoError(t, reg.Config().Set(ctx, config.ViperKeyClientHTTPNoPrivateIPRanges, true)) reg.HTTPClient(ctx).RetryMax = 1 reg.HTTPClient(ctx).RetryWaitMax = time.Millisecond diff --git a/driver/config/config.go b/driver/config/config.go index 05d7ddef52a7..9f3c1b38938b 100644 --- a/driver/config/config.go +++ b/driver/config/config.go @@ -367,13 +367,13 @@ func (s Schemas) FindSchemaByID(id string) (*Schema, error) { return nil, errors.Errorf("unable to find identity schema with id: %s", id) } -func MustNew(t testing.TB, l *logrusx.Logger, stdOutOrErr io.Writer, opts ...configx.OptionModifier) *Config { - p, err := New(context.TODO(), l, stdOutOrErr, opts...) +func MustNew(t testing.TB, l *logrusx.Logger, stdOutOrErr io.Writer, ctxer contextx.Contextualizer, opts ...configx.OptionModifier) *Config { + p, err := New(context.TODO(), l, stdOutOrErr, ctxer, opts...) require.NoError(t, err) return p } -func New(ctx context.Context, l *logrusx.Logger, stdOutOrErr io.Writer, opts ...configx.OptionModifier) (*Config, error) { +func New(ctx context.Context, l *logrusx.Logger, stdOutOrErr io.Writer, ctxer contextx.Contextualizer, opts ...configx.OptionModifier) (*Config, error) { var c *Config opts = append([]configx.OptionModifier{ @@ -402,7 +402,7 @@ func New(ctx context.Context, l *logrusx.Logger, stdOutOrErr io.Writer, opts ... l.UseConfig(p) - c = NewCustom(l, p, stdOutOrErr, &contextx.Default{}) + c = NewCustom(l, p, stdOutOrErr, ctxer) if !p.SkipValidation() { if err := c.validateIdentitySchemas(ctx); err != nil { @@ -518,12 +518,14 @@ func (p *Config) cors(ctx context.Context, prefix string) (cors.Options, bool) { }) } +// Deprecatd: use context-based WithConfigValue instead func (p *Config) Set(ctx context.Context, key string, value interface{}) error { - return p.GetProvider(ctx).Set(key, value) + return p.p.Set(key, value) } +// Deprecated: use context-based WithConfigValue instead func (p *Config) MustSet(ctx context.Context, key string, value interface{}) { - if err := p.GetProvider(ctx).Set(key, value); err != nil { + if err := p.p.Set(key, value); err != nil { p.l.WithError(err).Fatalf("Unable to set \"%s\" to \"%s\".", key, value) } } @@ -859,7 +861,7 @@ func (p *Config) SecretsCipher(ctx context.Context) [][32]byte { result := make([][32]byte, len(cleanSecrets)) for n, s := range secrets { for k, v := range []byte(s) { - result[n][k] = byte(v) + result[n][k] = v } } return result diff --git a/driver/config/config_test.go b/driver/config/config_test.go index 6cb37f100850..dc276eb3a171 100644 --- a/driver/config/config_test.go +++ b/driver/config/config_test.go @@ -18,6 +18,8 @@ import ( "testing" "time" + "github.com/ory/x/contextx" + "github.com/ory/x/httpx" "github.com/ory/x/randx" @@ -51,6 +53,7 @@ func TestViperProvider(t *testing.T) { t.Run("suite=loaders", func(t *testing.T) { p := config.MustNew(t, logrusx.New("", ""), os.Stderr, + &contextx.Default{}, configx.WithConfigFiles("stub/.kratos.yaml"), configx.WithContext(ctx), ) @@ -89,6 +92,7 @@ func TestViperProvider(t *testing.T) { pWithFragments := config.MustNew(t, logrusx.New("", ""), os.Stderr, + &contextx.Default{}, configx.WithValues(map[string]interface{}{ config.ViperKeySelfServiceLoginUI: "http://test.kratos.ory.sh/#/login", config.ViperKeySelfServiceSettingsURL: "http://test.kratos.ory.sh/#/settings", @@ -105,6 +109,7 @@ func TestViperProvider(t *testing.T) { pWithRelativeFragments := config.MustNew(t, logrusx.New("", ""), os.Stderr, + &contextx.Default{}, configx.WithValues(map[string]interface{}{ config.ViperKeySelfServiceLoginUI: "/login", config.ViperKeySelfServiceSettingsURL: "/settings", @@ -130,6 +135,7 @@ func TestViperProvider(t *testing.T) { pWithIncorrectUrls := config.MustNew(t, logger, os.Stderr, + &contextx.Default{}, configx.WithValues(map[string]interface{}{ config.ViperKeySelfServiceLoginUI: v, }), @@ -161,6 +167,7 @@ func TestViperProvider(t *testing.T) { t.Run("group=identity", func(t *testing.T) { c := config.MustNew(t, logrusx.New("", ""), os.Stderr, + &contextx.Default{}, configx.WithConfigFiles("stub/.kratos.mock.identities.yaml"), configx.SkipValidation()) @@ -198,7 +205,7 @@ func TestViperProvider(t *testing.T) { }, p.SecretsSession(ctx)) var cipherExpected [32]byte for k, v := range []byte("secret-thirty-two-character-long") { - cipherExpected[k] = byte(v) + cipherExpected[k] = v } assert.Equal(t, [][32]byte{ cipherExpected, @@ -400,7 +407,7 @@ func TestViperProvider(t *testing.T) { func TestBcrypt(t *testing.T) { t.Parallel() ctx := context.Background() - p := config.MustNew(t, logrusx.New("", ""), os.Stderr, configx.SkipValidation()) + p := config.MustNew(t, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.SkipValidation()) require.NoError(t, p.Set(ctx, config.ViperKeyHasherBcryptCost, 4)) require.NoError(t, p.Set(ctx, "dev", false)) @@ -418,7 +425,7 @@ func TestProviderBaseURLs(t *testing.T) { machineHostname = "127.0.0.1" } - p := config.MustNew(t, logrusx.New("", ""), os.Stderr, configx.SkipValidation()) + p := config.MustNew(t, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.SkipValidation()) assert.Equal(t, "https://"+machineHostname+":4433/", p.SelfPublicURL(ctx).String()) assert.Equal(t, "https://"+machineHostname+":4434/", p.SelfAdminURL(ctx).String()) @@ -446,7 +453,7 @@ func TestProviderSelfServiceLinkMethodBaseURL(t *testing.T) { machineHostname = "127.0.0.1" } - p := config.MustNew(t, logrusx.New("", ""), os.Stderr, configx.SkipValidation()) + p := config.MustNew(t, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.SkipValidation()) assert.Equal(t, "https://"+machineHostname+":4433/", p.SelfServiceLinkMethodBaseURL(ctx).String()) p.MustSet(ctx, config.ViperKeyLinkBaseURL, "https://example.org/bar") @@ -456,7 +463,7 @@ func TestProviderSelfServiceLinkMethodBaseURL(t *testing.T) { func TestViperProvider_Secrets(t *testing.T) { t.Parallel() ctx := context.Background() - p := config.MustNew(t, logrusx.New("", ""), os.Stderr, configx.SkipValidation()) + p := config.MustNew(t, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.SkipValidation()) def := p.SecretsDefault(ctx) assert.NotEmpty(t, def) @@ -479,24 +486,25 @@ func TestViperProvider_Defaults(t *testing.T) { }{ { init: func() *config.Config { - return config.MustNew(t, l, os.Stderr, configx.SkipValidation()) + return config.MustNew(t, l, os.Stderr, &contextx.Default{}, configx.SkipValidation()) }, }, { init: func() *config.Config { return config.MustNew(t, l, os.Stderr, + &contextx.Default{}, configx.WithConfigFiles("stub/.defaults.yml"), configx.SkipValidation()) }, }, { init: func() *config.Config { - return config.MustNew(t, l, os.Stderr, configx.WithConfigFiles("stub/.defaults-password.yml"), configx.SkipValidation()) + return config.MustNew(t, l, os.Stderr, &contextx.Default{}, configx.WithConfigFiles("stub/.defaults-password.yml"), configx.SkipValidation()) }, }, { init: func() *config.Config { - return config.MustNew(t, l, os.Stderr, configx.WithConfigFiles("../../test/e2e/profiles/recovery/.kratos.yml"), configx.SkipValidation()) + return config.MustNew(t, l, os.Stderr, &contextx.Default{}, configx.WithConfigFiles("../../test/e2e/profiles/recovery/.kratos.yml"), configx.SkipValidation()) }, expect: func(t *testing.T, p *config.Config) { assert.True(t, p.SelfServiceFlowRecoveryEnabled(ctx)) @@ -512,7 +520,7 @@ func TestViperProvider_Defaults(t *testing.T) { }, { init: func() *config.Config { - return config.MustNew(t, l, os.Stderr, configx.WithConfigFiles("../../test/e2e/profiles/verification/.kratos.yml"), configx.SkipValidation()) + return config.MustNew(t, l, os.Stderr, &contextx.Default{}, configx.WithConfigFiles("../../test/e2e/profiles/verification/.kratos.yml"), configx.SkipValidation()) }, expect: func(t *testing.T, p *config.Config) { assert.False(t, p.SelfServiceFlowRecoveryEnabled(ctx)) @@ -528,7 +536,7 @@ func TestViperProvider_Defaults(t *testing.T) { }, { init: func() *config.Config { - return config.MustNew(t, l, os.Stderr, configx.WithConfigFiles("../../test/e2e/profiles/oidc/.kratos.yml"), configx.SkipValidation()) + return config.MustNew(t, l, os.Stderr, &contextx.Default{}, configx.WithConfigFiles("../../test/e2e/profiles/oidc/.kratos.yml"), configx.SkipValidation()) }, expect: func(t *testing.T, p *config.Config) { assert.False(t, p.SelfServiceFlowRecoveryEnabled(ctx)) @@ -543,7 +551,7 @@ func TestViperProvider_Defaults(t *testing.T) { }, { init: func() *config.Config { - return config.MustNew(t, l, os.Stderr, configx.WithConfigFiles("stub/.kratos.notify-unknown-recipients.yml"), configx.SkipValidation()) + return config.MustNew(t, l, os.Stderr, &contextx.Default{}, configx.WithConfigFiles("stub/.kratos.notify-unknown-recipients.yml"), configx.SkipValidation()) }, expect: func(t *testing.T, p *config.Config) { assert.True(t, p.SelfServiceFlowRecoveryNotifyUnknownRecipients(ctx)) @@ -572,7 +580,7 @@ func TestViperProvider_Defaults(t *testing.T) { } t.Run("suite=ui_url", func(t *testing.T) { - p := config.MustNew(t, l, os.Stderr, configx.SkipValidation()) + p := config.MustNew(t, l, os.Stderr, &contextx.Default{}, configx.SkipValidation()) assert.Equal(t, "https://www.ory.sh/kratos/docs/fallback/login", p.SelfServiceFlowLoginUI(ctx).String()) assert.Equal(t, "https://www.ory.sh/kratos/docs/fallback/settings", p.SelfServiceFlowSettingsUI(ctx).String()) assert.Equal(t, "https://www.ory.sh/kratos/docs/fallback/registration", p.SelfServiceFlowRegistrationUI(ctx).String()) @@ -585,7 +593,7 @@ func TestViperProvider_ReturnTo(t *testing.T) { t.Parallel() ctx := context.Background() l := logrusx.New("", "") - p := config.MustNew(t, l, os.Stderr, configx.SkipValidation()) + p := config.MustNew(t, l, os.Stderr, &contextx.Default{}, configx.SkipValidation()) p.MustSet(ctx, config.ViperKeySelfServiceBrowserDefaultReturnTo, "https://www.ory.sh/") assert.Equal(t, "https://www.ory.sh/", p.SelfServiceFlowVerificationReturnTo(ctx, urlx.ParseOrPanic("https://www.ory.sh/")).String()) @@ -602,7 +610,7 @@ func TestSession(t *testing.T) { t.Parallel() ctx := context.Background() l := logrusx.New("", "") - p := config.MustNew(t, l, os.Stderr, configx.SkipValidation()) + p := config.MustNew(t, l, os.Stderr, &contextx.Default{}, configx.SkipValidation()) assert.Equal(t, "ory_kratos_session", p.SessionName(ctx)) p.MustSet(ctx, config.ViperKeySessionName, "ory_session") @@ -629,7 +637,7 @@ func TestCookies(t *testing.T) { t.Parallel() ctx := context.Background() l := logrusx.New("", "") - p := config.MustNew(t, l, os.Stderr, configx.SkipValidation()) + p := config.MustNew(t, l, os.Stderr, &contextx.Default{}, configx.SkipValidation()) t.Run("path", func(t *testing.T) { assert.Equal(t, "/", p.CookiePath(ctx)) @@ -676,14 +684,14 @@ func TestViperProvider_DSN(t *testing.T) { ctx := context.Background() t.Run("case=dsn: memory", func(t *testing.T) { - p := config.MustNew(t, logrusx.New("", ""), os.Stderr, configx.SkipValidation()) + p := config.MustNew(t, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.SkipValidation()) p.MustSet(ctx, config.ViperKeyDSN, "memory") assert.Equal(t, config.DefaultSQLiteMemoryDSN, p.DSN(ctx)) }) t.Run("case=dsn: not memory", func(t *testing.T) { - p := config.MustNew(t, logrusx.New("", ""), os.Stderr, configx.SkipValidation()) + p := config.MustNew(t, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.SkipValidation()) dsn := "sqlite://foo.db?_fk=true" p.MustSet(ctx, config.ViperKeyDSN, dsn) @@ -698,7 +706,7 @@ func TestViperProvider_DSN(t *testing.T) { l := logrusx.New("", "", logrusx.WithExitFunc(func(i int) { exitCode = i })) - p := config.MustNew(t, l, os.Stderr, configx.SkipValidation()) + p := config.MustNew(t, l, os.Stderr, &contextx.Default{}, configx.SkipValidation()) assert.Equal(t, dsn, p.DSN(ctx)) assert.NotEqual(t, 0, exitCode) @@ -714,7 +722,7 @@ func TestViperProvider_ParseURIOrFail(t *testing.T) { l := logrusx.New("", "", logrusx.WithExitFunc(func(i int) { exitCode = i })) - p := config.MustNew(t, l, os.Stderr, configx.SkipValidation()) + p := config.MustNew(t, l, os.Stderr, &contextx.Default{}, configx.SkipValidation()) require.Zero(t, exitCode) const testKey = "testKeyNotUsedInTheRealSchema" @@ -768,7 +776,7 @@ func TestViperProvider_HaveIBeenPwned(t *testing.T) { t.Parallel() ctx := context.Background() - p := config.MustNew(t, logrusx.New("", ""), os.Stderr, configx.SkipValidation()) + p := config.MustNew(t, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.SkipValidation()) t.Run("case=hipb: host", func(t *testing.T) { p.MustSet(ctx, config.ViperKeyPasswordHaveIBeenPwnedHost, "foo.bar") assert.Equal(t, "foo.bar", p.PasswordPolicyConfig(ctx).HaveIBeenPwnedHost) @@ -806,7 +814,7 @@ func newTestConfig(t *testing.T) (_ *config.Config, _ *test.Hook, exited *bool) exited = new(bool) l.Logger.Hooks.Add(h) l.Logger.ExitFunc = func(code int) { *exited = true } - config := config.MustNew(t, l, os.Stderr, configx.SkipValidation()) + config := config.MustNew(t, l, os.Stderr, &contextx.Default{}, configx.SkipValidation()) return config, h, exited } @@ -972,7 +980,7 @@ func TestIdentitySchemaValidation(t *testing.T) { l := logrusx.New("kratos-"+tmpConfig.Name(), "test") hook := test.NewLocal(l.Logger) - conf, err := config.New(ctx, l, os.Stderr, configx.WithConfigFiles(tmpConfig.Name())) + conf, err := config.New(ctx, l, os.Stderr, &contextx.Default{}, configx.WithConfigFiles(tmpConfig.Name())) assert.NoError(t, err) // clean the hooks since it will throw an event on first boot @@ -986,7 +994,7 @@ func TestIdentitySchemaValidation(t *testing.T) { t.Run("case=skip invalid schema validation", func(t *testing.T) { ctx := ctx - _, err := config.New(ctx, logrusx.New("", ""), os.Stderr, + _, err := config.New(ctx, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.WithConfigFiles("stub/.kratos.invalid.identities.yaml"), configx.SkipValidation()) assert.NoError(t, err) @@ -995,7 +1003,7 @@ func TestIdentitySchemaValidation(t *testing.T) { t.Run("case=invalid schema should throw error", func(t *testing.T) { ctx := ctx var stdErr bytes.Buffer - _, err := config.New(ctx, logrusx.New("", ""), &stdErr, + _, err := config.New(ctx, logrusx.New("", ""), &stdErr, &contextx.Default{}, configx.WithConfigFiles("stub/.kratos.invalid.identities.yaml")) assert.Error(t, err) assert.Contains(t, err.Error(), "minimum 1 properties allowed, but found 0") @@ -1013,7 +1021,7 @@ func TestIdentitySchemaValidation(t *testing.T) { err := make(chan error, 1) go func(err chan error) { - _, e := config.New(ctx, logrusx.New("", ""), os.Stderr, + _, e := config.New(ctx, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.WithConfigFiles("stub/.kratos.mock.identities.yaml")) err <- e }(err) @@ -1068,7 +1076,7 @@ func TestPasswordless(t *testing.T) { t.Parallel() ctx := context.Background() - conf, err := config.New(ctx, logrusx.New("", ""), os.Stderr, + conf, err := config.New(ctx, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.SkipValidation(), configx.WithValue(config.ViperKeyWebAuthnPasswordless, true)) require.NoError(t, err) @@ -1083,7 +1091,7 @@ func TestPasswordlessCode(t *testing.T) { ctx := context.Background() - conf, err := config.New(ctx, logrusx.New("", ""), os.Stderr, + conf, err := config.New(ctx, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.SkipValidation(), configx.WithValue(config.ViperKeySelfServiceStrategyConfig+".code", map[string]interface{}{ "passwordless_enabled": true, @@ -1100,7 +1108,7 @@ func TestChangeMinPasswordLength(t *testing.T) { t.Run("case=must fail on minimum password length below enforced minimum", func(t *testing.T) { ctx := context.Background() - _, err := config.New(ctx, logrusx.New("", ""), os.Stderr, + _, err := config.New(ctx, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.WithConfigFiles("stub/.kratos.yaml"), configx.WithValue(config.ViperKeyPasswordMinLength, 5)) @@ -1110,7 +1118,7 @@ func TestChangeMinPasswordLength(t *testing.T) { t.Run("case=must not fail on minimum password length above enforced minimum", func(t *testing.T) { ctx := context.Background() - _, err := config.New(ctx, logrusx.New("", ""), os.Stderr, + _, err := config.New(ctx, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.WithConfigFiles("stub/.kratos.yaml"), configx.WithValue(config.ViperKeyPasswordMinLength, 9)) @@ -1123,14 +1131,14 @@ func TestCourierEmailHTTP(t *testing.T) { ctx := context.Background() t.Run("case=configs set", func(t *testing.T) { - conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr, + conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.WithConfigFiles("stub/.kratos.courier.email.http.yaml"), configx.SkipValidation()) assert.Equal(t, "http", conf.CourierEmailStrategy(ctx)) snapshotx.SnapshotT(t, conf.CourierEmailRequestConfig(ctx)) }) t.Run("case=defaults", func(t *testing.T) { - conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr, configx.SkipValidation()) + conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.SkipValidation()) assert.Equal(t, "smtp", conf.CourierEmailStrategy(ctx)) }) @@ -1140,7 +1148,7 @@ func TestCourierChannels(t *testing.T) { t.Parallel() ctx := context.Background() t.Run("case=configs set", func(t *testing.T) { - conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr, configx.WithConfigFiles("stub/.kratos.courier.channels.yaml"), configx.SkipValidation()) + conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.WithConfigFiles("stub/.kratos.courier.channels.yaml"), configx.SkipValidation()) channelConfig, err := conf.CourierChannels(ctx) require.NoError(t, err) @@ -1152,7 +1160,7 @@ func TestCourierChannels(t *testing.T) { }) t.Run("case=defaults", func(t *testing.T) { - conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr, configx.SkipValidation()) + conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.SkipValidation()) channelConfig, err := conf.CourierChannels(ctx) require.NoError(t, err) @@ -1171,7 +1179,7 @@ func TestCourierChannels(t *testing.T) { "smtp://username:pass%2Fword@email-smtp.eu-west-3.amazonaws.com:587/", } { t.Run("case="+tc, func(t *testing.T) { - conf, err := config.New(ctx, logrusx.New("", ""), os.Stderr, configx.WithValue(config.ViperKeyCourierSMTPURL, tc), configx.SkipValidation()) + conf, err := config.New(ctx, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.WithValue(config.ViperKeyCourierSMTPURL, tc), configx.SkipValidation()) require.NoError(t, err) cs, err := conf.CourierChannels(ctx) require.NoError(t, err) @@ -1187,13 +1195,13 @@ func TestCourierMessageTTL(t *testing.T) { ctx := context.Background() t.Run("case=configs set", func(t *testing.T) { - conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr, + conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.WithConfigFiles("stub/.kratos.courier.message_retries.yaml"), configx.SkipValidation()) assert.Equal(t, conf.CourierMessageRetries(ctx), 10) }) t.Run("case=defaults", func(t *testing.T) { - conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr, configx.SkipValidation()) + conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.SkipValidation()) assert.Equal(t, conf.CourierMessageRetries(ctx), 5) }) } @@ -1203,7 +1211,7 @@ func TestOAuth2Provider(t *testing.T) { ctx := context.Background() t.Run("case=configs set", func(t *testing.T) { - conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr, + conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.WithConfigFiles("stub/.kratos.oauth2_provider.yaml"), configx.SkipValidation()) assert.Equal(t, "https://oauth2_provider/", conf.OAuth2ProviderURL(ctx).String()) assert.Equal(t, http.Header{"Authorization": {"Basic"}}, conf.OAuth2ProviderHeader(ctx)) @@ -1211,7 +1219,7 @@ func TestOAuth2Provider(t *testing.T) { }) t.Run("case=defaults", func(t *testing.T) { - conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr, configx.SkipValidation()) + conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.SkipValidation()) assert.Empty(t, conf.OAuth2ProviderURL(ctx)) assert.Empty(t, conf.OAuth2ProviderHeader(ctx)) assert.False(t, conf.OAuth2ProviderOverrideReturnTo(ctx)) @@ -1223,7 +1231,7 @@ func TestWebauthn(t *testing.T) { ctx := context.Background() t.Run("case=multiple origins", func(t *testing.T) { - conf, err := config.New(ctx, logrusx.New("", ""), os.Stderr, + conf, err := config.New(ctx, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.WithConfigFiles("stub/.kratos.webauthn.origins.yaml")) require.NoError(t, err) webAuthnConfig := conf.WebAuthnConfig(ctx) @@ -1236,7 +1244,7 @@ func TestWebauthn(t *testing.T) { }) t.Run("case=one origin", func(t *testing.T) { - conf, err := config.New(ctx, logrusx.New("", ""), os.Stderr, + conf, err := config.New(ctx, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.WithConfigFiles("stub/.kratos.webauthn.origin.yaml")) require.NoError(t, err) webAuthnConfig := conf.WebAuthnConfig(ctx) @@ -1247,7 +1255,7 @@ func TestWebauthn(t *testing.T) { }) t.Run("case=id as origin", func(t *testing.T) { - conf, err := config.New(ctx, logrusx.New("", ""), os.Stderr, + conf, err := config.New(ctx, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.WithConfigFiles("stub/.kratos.yaml")) require.NoError(t, err) webAuthnConfig := conf.WebAuthnConfig(ctx) @@ -1258,7 +1266,7 @@ func TestWebauthn(t *testing.T) { }) t.Run("case=invalid", func(t *testing.T) { - _, err := config.New(ctx, logrusx.New("", ""), os.Stderr, + _, err := config.New(ctx, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.WithConfigFiles("stub/.kratos.webauthn.invalid.yaml")) assert.Error(t, err) }) @@ -1269,19 +1277,19 @@ func TestCourierTemplatesConfig(t *testing.T) { ctx := context.Background() t.Run("case=partial template update allowed", func(t *testing.T) { - _, err := config.New(ctx, logrusx.New("", ""), os.Stderr, + _, err := config.New(ctx, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.WithConfigFiles("stub/.kratos.courier.remote.partial.templates.yaml")) assert.NoError(t, err) }) t.Run("case=load remote template with fallback template overrides path", func(t *testing.T) { - _, err := config.New(ctx, logrusx.New("", ""), os.Stderr, + _, err := config.New(ctx, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.WithConfigFiles("stub/.kratos.courier.remote.templates.yaml")) assert.NoError(t, err) }) t.Run("case=courier template helper", func(t *testing.T) { - c, err := config.New(ctx, logrusx.New("", ""), os.Stderr, + c, err := config.New(ctx, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.WithConfigFiles("stub/.kratos.courier.remote.templates.yaml")) assert.NoError(t, err) @@ -1323,7 +1331,7 @@ func TestCleanup(t *testing.T) { t.Parallel() ctx := context.Background() - p := config.MustNew(t, logrusx.New("", ""), os.Stderr, + p := config.MustNew(t, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.WithConfigFiles("stub/.kratos.yaml")) t.Run("group=cleanup config", func(t *testing.T) { diff --git a/driver/config/test_config.go b/driver/config/test_config.go new file mode 100644 index 000000000000..459ae15ac89c --- /dev/null +++ b/driver/config/test_config.go @@ -0,0 +1,75 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package config + +import ( + "context" + "strings" + + "github.com/knadh/koanf/maps" + + "github.com/ory/kratos/embedx" + "github.com/ory/x/configx" + "github.com/ory/x/contextx" +) + +type ( + TestConfigProvider struct { + contextx.Contextualizer + Options []configx.OptionModifier + } + contextKey int + mapProvider map[string]any +) + +func (t *TestConfigProvider) NewProvider(ctx context.Context, opts ...configx.OptionModifier) (*configx.Provider, error) { + return configx.New(ctx, []byte(embedx.ConfigSchema), append(t.Options, opts...)...) +} + +func (t *TestConfigProvider) Config(ctx context.Context, config *configx.Provider) *configx.Provider { + config = t.Contextualizer.Config(ctx, config) + values, ok := ctx.Value(contextConfigKey).(mapProvider) + if !ok { + return config + } + config, err := t.NewProvider(ctx, configx.WithValues(values)) + if err != nil { + // This is not production code. The provider is only used in tests. + panic(err) + } + return config +} + +const contextConfigKey contextKey = 1 + +var ( + _ contextx.Contextualizer = (*TestConfigProvider)(nil) +) + +func WithConfigValue(ctx context.Context, key string, value any) context.Context { + return WithConfigValues(ctx, map[string]any{key: value}) +} + +func WithConfigValues(ctx context.Context, newValues map[string]any) context.Context { + values, ok := ctx.Value(contextConfigKey).(mapProvider) + if !ok { + values = make(mapProvider) + } + expandedValues := make([]map[string]any, 0, len(newValues)) + for k, v := range newValues { + parts := strings.Split(k, ".") + val := map[string]any{parts[len(parts)-1]: v} + if len(parts) > 1 { + for i := len(parts) - 2; i >= 0; i-- { + val = map[string]any{parts[i]: val} + } + } + expandedValues = append(expandedValues, val) + } + for _, v := range expandedValues { + maps.Merge(v, values) + } + + return context.WithValue(ctx, contextConfigKey, values) +} diff --git a/driver/factory.go b/driver/factory.go index e3470d3cffd9..da0dd5601e2b 100644 --- a/driver/factory.go +++ b/driver/factory.go @@ -38,7 +38,7 @@ func NewWithoutInit(ctx context.Context, stdOutOrErr io.Writer, sl *servicelocat c := newOptions(dOpts).config if c == nil { var err error - c, err = config.New(ctx, l, stdOutOrErr, opts...) + c, err = config.New(ctx, l, stdOutOrErr, sl.Contextualizer(), opts...) if err != nil { l.WithError(err).Error("Unable to instantiate configuration.") return nil, err diff --git a/driver/registry_default_test.go b/driver/registry_default_test.go index 009dd76173d8..020517a41159 100644 --- a/driver/registry_default_test.go +++ b/driver/registry_default_test.go @@ -10,6 +10,8 @@ import ( "os" "testing" + "github.com/ory/x/contextx" + "github.com/ory/kratos/selfservice/flow/recovery" "github.com/ory/kratos/selfservice/flow/verification" @@ -34,26 +36,27 @@ func TestDriverDefault_Hooks(t *testing.T) { t.Parallel() ctx := context.Background() + _, reg := internal.NewVeryFastRegistryWithoutDB(t) + t.Run("type=verification", func(t *testing.T) { t.Parallel() // BEFORE hooks for _, tc := range []struct { uc string - prep func(conf *config.Config) + config map[string]any expect func(reg *driver.RegistryDefault) []verification.PreHookExecutor }{ { uc: "No hooks configured", - prep: func(conf *config.Config) {}, expect: func(reg *driver.RegistryDefault) []verification.PreHookExecutor { return nil }, }, { uc: "Two web_hooks are configured", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceVerificationBeforeHooks, []map[string]interface{}{ - {"hook": "web_hook", "config": map[string]interface{}{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, - {"hook": "web_hook", "config": map[string]interface{}{"url": "bar", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, - }) + config: map[string]any{ + config.ViperKeySelfServiceVerificationBeforeHooks: []map[string]any{ + {"hook": "web_hook", "config": map[string]any{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, + {"hook": "web_hook", "config": map[string]any{"url": "bar", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, + }, }, expect: func(reg *driver.RegistryDefault) []verification.PreHookExecutor { return []verification.PreHookExecutor{ @@ -64,8 +67,9 @@ func TestDriverDefault_Hooks(t *testing.T) { }, } { t.Run(fmt.Sprintf("before/uc=%s", tc.uc), func(t *testing.T) { - conf, reg := internal.NewVeryFastRegistryWithoutDB(t) - tc.prep(conf) + t.Parallel() + + ctx := config.WithConfigValues(ctx, tc.config) h := reg.PreVerificationHooks(ctx) @@ -79,6 +83,7 @@ func TestDriverDefault_Hooks(t *testing.T) { for _, tc := range []struct { uc string prep func(conf *config.Config) + config map[string]any expect func(reg *driver.RegistryDefault) []verification.PostHookExecutor }{ { @@ -88,11 +93,11 @@ func TestDriverDefault_Hooks(t *testing.T) { }, { uc: "Multiple web_hooks configured", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceVerificationAfter+".hooks", []map[string]interface{}{ - {"hook": "web_hook", "config": map[string]interface{}{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, - {"hook": "web_hook", "config": map[string]interface{}{"url": "bar", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, - }) + config: map[string]any{ + config.ViperKeySelfServiceVerificationAfter + ".hooks": []map[string]any{ + {"hook": "web_hook", "config": map[string]any{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, + {"hook": "web_hook", "config": map[string]any{"url": "bar", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, + }, }, expect: func(reg *driver.RegistryDefault) []verification.PostHookExecutor { return []verification.PostHookExecutor{ @@ -103,8 +108,9 @@ func TestDriverDefault_Hooks(t *testing.T) { }, } { t.Run(fmt.Sprintf("after/uc=%s", tc.uc), func(t *testing.T) { - conf, reg := internal.NewVeryFastRegistryWithoutDB(t) - tc.prep(conf) + t.Parallel() + + ctx := config.WithConfigValues(ctx, tc.config) h := reg.PostVerificationHooks(ctx) @@ -120,21 +126,20 @@ func TestDriverDefault_Hooks(t *testing.T) { // BEFORE hooks for _, tc := range []struct { uc string - prep func(conf *config.Config) + config map[string]any expect func(reg *driver.RegistryDefault) []recovery.PreHookExecutor }{ { uc: "No hooks configured", - prep: func(conf *config.Config) {}, expect: func(reg *driver.RegistryDefault) []recovery.PreHookExecutor { return nil }, }, { uc: "Two web_hooks are configured", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceRecoveryBeforeHooks, []map[string]interface{}{ - {"hook": "web_hook", "config": map[string]interface{}{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, - {"hook": "web_hook", "config": map[string]interface{}{"url": "bar", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, - }) + config: map[string]any{ + config.ViperKeySelfServiceRecoveryBeforeHooks: []map[string]any{ + {"hook": "web_hook", "config": map[string]any{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, + {"hook": "web_hook", "config": map[string]any{"url": "bar", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, + }, }, expect: func(reg *driver.RegistryDefault) []recovery.PreHookExecutor { return []recovery.PreHookExecutor{ @@ -145,8 +150,9 @@ func TestDriverDefault_Hooks(t *testing.T) { }, } { t.Run(fmt.Sprintf("before/uc=%s", tc.uc), func(t *testing.T) { - conf, reg := internal.NewVeryFastRegistryWithoutDB(t) - tc.prep(conf) + t.Parallel() + + ctx := config.WithConfigValues(ctx, tc.config) h := reg.PreRecoveryHooks(ctx) @@ -159,21 +165,20 @@ func TestDriverDefault_Hooks(t *testing.T) { // AFTER hooks for _, tc := range []struct { uc string - prep func(conf *config.Config) + config map[string]any expect func(reg *driver.RegistryDefault) []recovery.PostHookExecutor }{ { uc: "No hooks configured", - prep: func(conf *config.Config) {}, expect: func(reg *driver.RegistryDefault) []recovery.PostHookExecutor { return nil }, }, { uc: "Multiple web_hooks configured", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceRecoveryAfter+".hooks", []map[string]interface{}{ - {"hook": "web_hook", "config": map[string]interface{}{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, - {"hook": "web_hook", "config": map[string]interface{}{"url": "bar", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, - }) + config: map[string]any{ + config.ViperKeySelfServiceRecoveryAfter + ".hooks": []map[string]any{ + {"hook": "web_hook", "config": map[string]any{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, + {"hook": "web_hook", "config": map[string]any{"url": "bar", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, + }, }, expect: func(reg *driver.RegistryDefault) []recovery.PostHookExecutor { return []recovery.PostHookExecutor{ @@ -184,8 +189,9 @@ func TestDriverDefault_Hooks(t *testing.T) { }, } { t.Run(fmt.Sprintf("after/uc=%s", tc.uc), func(t *testing.T) { - conf, reg := internal.NewVeryFastRegistryWithoutDB(t) - tc.prep(conf) + t.Parallel() + + ctx := config.WithConfigValues(ctx, tc.config) h := reg.PostRecoveryHooks(ctx) @@ -201,12 +207,11 @@ func TestDriverDefault_Hooks(t *testing.T) { // BEFORE hooks for _, tc := range []struct { uc string - prep func(conf *config.Config) + config map[string]any expect func(reg *driver.RegistryDefault) []registration.PreHookExecutor }{ { - uc: "No hooks configured", - prep: func(conf *config.Config) {}, + uc: "No hooks configured", expect: func(reg *driver.RegistryDefault) []registration.PreHookExecutor { return []registration.PreHookExecutor{ hook.NewTwoStepRegistration(reg), @@ -215,11 +220,11 @@ func TestDriverDefault_Hooks(t *testing.T) { }, { uc: "Two web_hooks are configured", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceRegistrationBeforeHooks, []map[string]interface{}{ - {"hook": "web_hook", "config": map[string]interface{}{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, - {"hook": "web_hook", "config": map[string]interface{}{"url": "bar", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, - }) + config: map[string]any{ + config.ViperKeySelfServiceRegistrationBeforeHooks: []map[string]any{ + {"hook": "web_hook", "config": map[string]any{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, + {"hook": "web_hook", "config": map[string]any{"url": "bar", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, + }, }, expect: func(reg *driver.RegistryDefault) []registration.PreHookExecutor { return []registration.PreHookExecutor{ @@ -231,8 +236,9 @@ func TestDriverDefault_Hooks(t *testing.T) { }, } { t.Run(fmt.Sprintf("before/uc=%s", tc.uc), func(t *testing.T) { - conf, reg := internal.NewVeryFastRegistryWithoutDB(t) - tc.prep(conf) + t.Parallel() + + ctx := config.WithConfigValues(ctx, tc.config) h := reg.PreRegistrationHooks(ctx) @@ -245,21 +251,20 @@ func TestDriverDefault_Hooks(t *testing.T) { // AFTER hooks for _, tc := range []struct { uc string - prep func(conf *config.Config) + config map[string]any expect func(reg *driver.RegistryDefault) []registration.PostHookPostPersistExecutor }{ { uc: "No hooks configured", - prep: func(conf *config.Config) {}, expect: func(reg *driver.RegistryDefault) []registration.PostHookPostPersistExecutor { return nil }, }, { uc: "Only session hook configured for password strategy", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceVerificationEnabled, true) - conf.MustSet(ctx, config.ViperKeySelfServiceRegistrationAfter+".password.hooks", []map[string]interface{}{ + config: map[string]any{ + config.ViperKeySelfServiceVerificationEnabled: true, + config.ViperKeySelfServiceRegistrationAfter + ".password.hooks": []map[string]any{ {"hook": "session"}, - }) + }, }, expect: func(reg *driver.RegistryDefault) []registration.PostHookPostPersistExecutor { return []registration.PostHookPostPersistExecutor{ @@ -270,12 +275,12 @@ func TestDriverDefault_Hooks(t *testing.T) { }, { uc: "A session hook and a web_hook are configured for password strategy", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceVerificationEnabled, true) - conf.MustSet(ctx, config.ViperKeySelfServiceRegistrationAfter+".password.hooks", []map[string]interface{}{ - {"hook": "web_hook", "config": map[string]interface{}{"headers": map[string]string{"X-Custom-Header": "test"}, "url": "foo", "method": "POST", "body": "bar"}}, + config: map[string]any{ + config.ViperKeySelfServiceVerificationEnabled: true, + config.ViperKeySelfServiceRegistrationAfter + ".password.hooks": []map[string]any{ + {"hook": "web_hook", "config": map[string]any{"headers": map[string]string{"X-Custom-Header": "test"}, "url": "foo", "method": "POST", "body": "bar"}}, {"hook": "session"}, - }) + }, }, expect: func(reg *driver.RegistryDefault) []registration.PostHookPostPersistExecutor { return []registration.PostHookPostPersistExecutor{ @@ -287,11 +292,11 @@ func TestDriverDefault_Hooks(t *testing.T) { }, { uc: "Two web_hooks are configured on a global level", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceRegistrationAfter+".hooks", []map[string]interface{}{ - {"hook": "web_hook", "config": map[string]interface{}{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, - {"hook": "web_hook", "config": map[string]interface{}{"url": "bar", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, - }) + config: map[string]any{ + config.ViperKeySelfServiceRegistrationAfter + ".hooks": []map[string]any{ + {"hook": "web_hook", "config": map[string]any{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, + {"hook": "web_hook", "config": map[string]any{"url": "bar", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, + }, }, expect: func(reg *driver.RegistryDefault) []registration.PostHookPostPersistExecutor { return []registration.PostHookPostPersistExecutor{ @@ -302,15 +307,15 @@ func TestDriverDefault_Hooks(t *testing.T) { }, { uc: "Hooks are configured on a global level, as well as on a strategy level", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceRegistrationAfter+".password.hooks", []map[string]interface{}{ - {"hook": "web_hook", "config": map[string]interface{}{"url": "foo", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, + config: map[string]any{ + config.ViperKeySelfServiceRegistrationAfter + ".password.hooks": []map[string]any{ + {"hook": "web_hook", "config": map[string]any{"url": "foo", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, {"hook": "session"}, - }) - conf.MustSet(ctx, config.ViperKeySelfServiceRegistrationAfter+".hooks", []map[string]interface{}{ - {"hook": "web_hook", "config": map[string]interface{}{"url": "bar", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, - }) - conf.MustSet(ctx, config.ViperKeySelfServiceVerificationEnabled, true) + }, + config.ViperKeySelfServiceRegistrationAfter + ".hooks": []map[string]any{ + {"hook": "web_hook", "config": map[string]any{"url": "bar", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, + }, + config.ViperKeySelfServiceVerificationEnabled: true, }, expect: func(reg *driver.RegistryDefault) []registration.PostHookPostPersistExecutor { return []registration.PostHookPostPersistExecutor{ @@ -322,10 +327,10 @@ func TestDriverDefault_Hooks(t *testing.T) { }, { uc: "show_verification_ui is configured", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceRegistrationAfter+".hooks", []map[string]interface{}{ + config: map[string]any{ + config.ViperKeySelfServiceRegistrationAfter + ".hooks": []map[string]any{ {"hook": "show_verification_ui"}, - }) + }, }, expect: func(reg *driver.RegistryDefault) []registration.PostHookPostPersistExecutor { return []registration.PostHookPostPersistExecutor{ @@ -335,8 +340,9 @@ func TestDriverDefault_Hooks(t *testing.T) { }, } { t.Run(fmt.Sprintf("after/uc=%s", tc.uc), func(t *testing.T) { - conf, reg := internal.NewVeryFastRegistryWithoutDB(t) - tc.prep(conf) + t.Parallel() + + ctx := config.WithConfigValues(ctx, tc.config) h := reg.PostRegistrationPostPersistHooks(ctx, identity.CredentialsTypePassword) @@ -352,21 +358,20 @@ func TestDriverDefault_Hooks(t *testing.T) { // BEFORE hooks for _, tc := range []struct { uc string - prep func(conf *config.Config) + config map[string]any expect func(reg *driver.RegistryDefault) []login.PreHookExecutor }{ { uc: "No hooks configured", - prep: func(conf *config.Config) {}, expect: func(reg *driver.RegistryDefault) []login.PreHookExecutor { return nil }, }, { uc: "Two web_hooks are configured", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceLoginBeforeHooks, []map[string]interface{}{ - {"hook": "web_hook", "config": map[string]interface{}{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, - {"hook": "web_hook", "config": map[string]interface{}{"url": "bar", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, - }) + config: map[string]any{ + config.ViperKeySelfServiceLoginBeforeHooks: []map[string]any{ + {"hook": "web_hook", "config": map[string]any{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, + {"hook": "web_hook", "config": map[string]any{"url": "bar", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, + }, }, expect: func(reg *driver.RegistryDefault) []login.PreHookExecutor { return []login.PreHookExecutor{ @@ -377,8 +382,9 @@ func TestDriverDefault_Hooks(t *testing.T) { }, } { t.Run(fmt.Sprintf("before/uc=%s", tc.uc), func(t *testing.T) { - conf, reg := internal.NewVeryFastRegistryWithoutDB(t) - tc.prep(conf) + t.Parallel() + + ctx := config.WithConfigValues(ctx, tc.config) h := reg.PreLoginHooks(ctx) @@ -391,20 +397,19 @@ func TestDriverDefault_Hooks(t *testing.T) { // AFTER hooks for _, tc := range []struct { uc string - prep func(conf *config.Config) + config map[string]any expect func(reg *driver.RegistryDefault) []login.PostHookExecutor }{ { uc: "No hooks configured", - prep: func(conf *config.Config) {}, expect: func(reg *driver.RegistryDefault) []login.PostHookExecutor { return nil }, }, { uc: "Only revoke_active_sessions hook configured for password strategy", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceLoginAfter+".password.hooks", []map[string]interface{}{ + config: map[string]any{ + config.ViperKeySelfServiceLoginAfter + ".password.hooks": []map[string]any{ {"hook": "revoke_active_sessions"}, - }) + }, }, expect: func(reg *driver.RegistryDefault) []login.PostHookExecutor { return []login.PostHookExecutor{ @@ -414,10 +419,10 @@ func TestDriverDefault_Hooks(t *testing.T) { }, { uc: "Only require_verified_address hook configured for password strategy", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceLoginAfter+".password.hooks", []map[string]interface{}{ + config: map[string]any{ + config.ViperKeySelfServiceLoginAfter + ".password.hooks": []map[string]any{ {"hook": "require_verified_address"}, - }) + }, }, expect: func(reg *driver.RegistryDefault) []login.PostHookExecutor { return []login.PostHookExecutor{ @@ -427,12 +432,12 @@ func TestDriverDefault_Hooks(t *testing.T) { }, { uc: "A revoke_active_sessions hook, require_verified_address hook and a web_hook are configured for password strategy", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceLoginAfter+".password.hooks", []map[string]interface{}{ - {"hook": "web_hook", "config": map[string]interface{}{"headers": map[string]string{"X-Custom-Header": "test"}, "url": "foo", "method": "POST", "body": "bar"}}, + config: map[string]any{ + config.ViperKeySelfServiceLoginAfter + ".password.hooks": []map[string]any{ + {"hook": "web_hook", "config": map[string]any{"headers": map[string]string{"X-Custom-Header": "test"}, "url": "foo", "method": "POST", "body": "bar"}}, {"hook": "require_verified_address"}, {"hook": "revoke_active_sessions"}, - }) + }, }, expect: func(reg *driver.RegistryDefault) []login.PostHookExecutor { return []login.PostHookExecutor{ @@ -444,11 +449,11 @@ func TestDriverDefault_Hooks(t *testing.T) { }, { uc: "Two web_hooks are configured on a global level", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceLoginAfter+".hooks", []map[string]interface{}{ - {"hook": "web_hook", "config": map[string]interface{}{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, - {"hook": "web_hook", "config": map[string]interface{}{"url": "bar", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, - }) + config: map[string]any{ + config.ViperKeySelfServiceLoginAfter + ".hooks": []map[string]any{ + {"hook": "web_hook", "config": map[string]any{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, + {"hook": "web_hook", "config": map[string]any{"url": "bar", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, + }, }, expect: func(reg *driver.RegistryDefault) []login.PostHookExecutor { return []login.PostHookExecutor{ @@ -459,15 +464,15 @@ func TestDriverDefault_Hooks(t *testing.T) { }, { uc: "Hooks are configured on a global level, as well as on a strategy level", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceLoginAfter+".password.hooks", []map[string]interface{}{ - {"hook": "web_hook", "config": map[string]interface{}{"url": "foo", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, + config: map[string]any{ + config.ViperKeySelfServiceLoginAfter + ".password.hooks": []map[string]any{ + {"hook": "web_hook", "config": map[string]any{"url": "foo", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, {"hook": "revoke_active_sessions"}, {"hook": "require_verified_address"}, - }) - conf.MustSet(ctx, config.ViperKeySelfServiceLoginAfter+".hooks", []map[string]interface{}{ - {"hook": "web_hook", "config": map[string]interface{}{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, - }) + }, + config.ViperKeySelfServiceLoginAfter + ".hooks": []map[string]any{ + {"hook": "web_hook", "config": map[string]any{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, + }, }, expect: func(reg *driver.RegistryDefault) []login.PostHookExecutor { return []login.PostHookExecutor{ @@ -479,8 +484,9 @@ func TestDriverDefault_Hooks(t *testing.T) { }, } { t.Run(fmt.Sprintf("after/uc=%s", tc.uc), func(t *testing.T) { - conf, reg := internal.NewVeryFastRegistryWithoutDB(t) - tc.prep(conf) + t.Parallel() + + ctx := config.WithConfigValues(ctx, tc.config) h := reg.PostLoginHooks(ctx, identity.CredentialsTypePassword) @@ -496,21 +502,20 @@ func TestDriverDefault_Hooks(t *testing.T) { // BEFORE hooks for _, tc := range []struct { uc string - prep func(conf *config.Config) + config map[string]any expect func(reg *driver.RegistryDefault) []settings.PreHookExecutor }{ { uc: "No hooks configured", - prep: func(conf *config.Config) {}, expect: func(reg *driver.RegistryDefault) []settings.PreHookExecutor { return nil }, }, { uc: "Two web_hooks are configured", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceSettingsBeforeHooks, []map[string]interface{}{ - {"hook": "web_hook", "config": map[string]interface{}{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, - {"hook": "web_hook", "config": map[string]interface{}{"url": "bar", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, - }) + config: map[string]any{ + config.ViperKeySelfServiceSettingsBeforeHooks: []map[string]any{ + {"hook": "web_hook", "config": map[string]any{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, + {"hook": "web_hook", "config": map[string]any{"url": "bar", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, + }, }, expect: func(reg *driver.RegistryDefault) []settings.PreHookExecutor { return []settings.PreHookExecutor{ @@ -521,8 +526,9 @@ func TestDriverDefault_Hooks(t *testing.T) { }, } { t.Run(fmt.Sprintf("before/uc=%s", tc.uc), func(t *testing.T) { - conf, reg := internal.NewVeryFastRegistryWithoutDB(t) - tc.prep(conf) + t.Parallel() + + ctx := config.WithConfigValues(ctx, tc.config) h := reg.PreSettingsHooks(ctx) @@ -535,18 +541,17 @@ func TestDriverDefault_Hooks(t *testing.T) { // AFTER hooks for _, tc := range []struct { uc string - prep func(conf *config.Config) + config map[string]any expect func(reg *driver.RegistryDefault) []settings.PostHookPostPersistExecutor }{ { uc: "No hooks configured", - prep: func(conf *config.Config) {}, expect: func(reg *driver.RegistryDefault) []settings.PostHookPostPersistExecutor { return nil }, }, { uc: "Only verify hook configured for the strategy", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceVerificationEnabled, true) + config: map[string]any{ + config.ViperKeySelfServiceVerificationEnabled: true, // I think this is a bug as there is a hook named verify defined for both profile and password // strategies. Instead of using it, the code makes use of the property used above and which // is defined in an entirely different flow (verification). @@ -559,11 +564,11 @@ func TestDriverDefault_Hooks(t *testing.T) { }, { uc: "A verify hook and a web_hook are configured for profile strategy", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceSettingsAfter+".profile.hooks", []map[string]interface{}{ - {"hook": "web_hook", "config": map[string]interface{}{"headers": []map[string]string{{"X-Custom-Header": "test"}}, "url": "foo", "method": "POST", "body": "bar"}}, - }) - conf.MustSet(ctx, config.ViperKeySelfServiceVerificationEnabled, true) + config: map[string]any{ + config.ViperKeySelfServiceSettingsAfter + ".profile.hooks": []map[string]any{ + {"hook": "web_hook", "config": map[string]any{"headers": []map[string]string{{"X-Custom-Header": "test"}}, "url": "foo", "method": "POST", "body": "bar"}}, + }, + config.ViperKeySelfServiceVerificationEnabled: true, }, expect: func(reg *driver.RegistryDefault) []settings.PostHookPostPersistExecutor { return []settings.PostHookPostPersistExecutor{ @@ -574,11 +579,11 @@ func TestDriverDefault_Hooks(t *testing.T) { }, { uc: "Two web_hooks are configured on a global level", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceSettingsAfter+".hooks", []map[string]interface{}{ - {"hook": "web_hook", "config": map[string]interface{}{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, - {"hook": "web_hook", "config": map[string]interface{}{"url": "bar", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, - }) + config: map[string]any{ + config.ViperKeySelfServiceSettingsAfter + ".hooks": []map[string]any{ + {"hook": "web_hook", "config": map[string]any{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, + {"hook": "web_hook", "config": map[string]any{"url": "bar", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, + }, }, expect: func(reg *driver.RegistryDefault) []settings.PostHookPostPersistExecutor { return []settings.PostHookPostPersistExecutor{ @@ -589,14 +594,14 @@ func TestDriverDefault_Hooks(t *testing.T) { }, { uc: "Hooks are configured on a global level, as well as on a strategy level", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceVerificationEnabled, true) - conf.MustSet(ctx, config.ViperKeySelfServiceSettingsAfter+".profile.hooks", []map[string]interface{}{ - {"hook": "web_hook", "config": map[string]interface{}{"url": "foo", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, - }) - conf.MustSet(ctx, config.ViperKeySelfServiceSettingsAfter+".hooks", []map[string]interface{}{ - {"hook": "web_hook", "config": map[string]interface{}{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, - }) + config: map[string]any{ + config.ViperKeySelfServiceVerificationEnabled: true, + config.ViperKeySelfServiceSettingsAfter + ".profile.hooks": []map[string]any{ + {"hook": "web_hook", "config": map[string]any{"url": "foo", "method": "GET", "headers": map[string]string{"X-Custom-Header": "test"}}}, + }, + config.ViperKeySelfServiceSettingsAfter + ".hooks": []map[string]any{ + {"hook": "web_hook", "config": map[string]any{"url": "foo", "method": "POST", "headers": map[string]string{"X-Custom-Header": "test"}}}, + }, }, expect: func(reg *driver.RegistryDefault) []settings.PostHookPostPersistExecutor { return []settings.PostHookPostPersistExecutor{ @@ -607,8 +612,9 @@ func TestDriverDefault_Hooks(t *testing.T) { }, } { t.Run(fmt.Sprintf("after/uc=%s", tc.uc), func(t *testing.T) { - conf, reg := internal.NewVeryFastRegistryWithoutDB(t) - tc.prep(conf) + t.Parallel() + + ctx := config.WithConfigValues(ctx, tc.config) h := reg.PostSettingsPostPersistHooks(ctx, "profile") @@ -623,62 +629,64 @@ func TestDriverDefault_Hooks(t *testing.T) { func TestDriverDefault_Strategies(t *testing.T) { t.Parallel() ctx := context.Background() + _, reg := internal.NewVeryFastRegistryWithoutDB(t) + t.Run("case=registration", func(t *testing.T) { t.Parallel() for _, tc := range []struct { name string - prep func(conf *config.Config) + config map[string]any expect []string }{ { name: "no strategies", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", false) - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false) + config: map[string]any{ + config.ViperKeySelfServiceStrategyConfig + ".password.enabled": false, + config.ViperKeySelfServiceStrategyConfig + ".code.enabled": false, }, expect: []string{"profile"}, }, { name: "only password", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", true) - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false) + config: map[string]any{ + config.ViperKeySelfServiceStrategyConfig + ".password.enabled": true, + config.ViperKeySelfServiceStrategyConfig + ".code.enabled": false, }, expect: []string{"password", "profile"}, }, { name: "oidc and password", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".oidc.enabled", true) - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", true) - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false) + config: map[string]any{ + config.ViperKeySelfServiceStrategyConfig + ".oidc.enabled": true, + config.ViperKeySelfServiceStrategyConfig + ".password.enabled": true, + config.ViperKeySelfServiceStrategyConfig + ".code.enabled": false, }, expect: []string{"password", "oidc", "profile"}, }, { name: "oidc, password and totp", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".oidc.enabled", true) - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", true) - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".totp.enabled", true) - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false) + config: map[string]any{ + config.ViperKeySelfServiceStrategyConfig + ".oidc.enabled": true, + config.ViperKeySelfServiceStrategyConfig + ".password.enabled": true, + config.ViperKeySelfServiceStrategyConfig + ".totp.enabled": true, + config.ViperKeySelfServiceStrategyConfig + ".code.enabled": false, }, expect: []string{"password", "oidc", "profile"}, }, { name: "password and code", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", true) - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", true) + config: map[string]any{ + config.ViperKeySelfServiceStrategyConfig + ".password.enabled": true, + config.ViperKeySelfServiceStrategyConfig + ".code.enabled": true, }, expect: []string{"password", "profile", "code"}, }, } { t.Run(fmt.Sprintf("subcase=%s", tc.name), func(t *testing.T) { - conf, reg := internal.NewVeryFastRegistryWithoutDB(t) - tc.prep(conf) + t.Parallel() - s := reg.RegistrationStrategies(context.Background()) + ctx := config.WithConfigValues(ctx, tc.config) + s := reg.RegistrationStrategies(ctx) require.Len(t, s, len(tc.expect)) for k, e := range tc.expect { assert.Equal(t, e, s[k].ID().String()) @@ -689,68 +697,69 @@ func TestDriverDefault_Strategies(t *testing.T) { t.Run("case=login", func(t *testing.T) { t.Parallel() + for _, tc := range []struct { name string - prep func(conf *config.Config) + config map[string]any expect []string }{ { name: "no strategies", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", false) - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false) + config: map[string]any{ + config.ViperKeySelfServiceStrategyConfig + ".password.enabled": false, + config.ViperKeySelfServiceStrategyConfig + ".code.enabled": false, }, }, { name: "only password", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", true) - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false) + config: map[string]any{ + config.ViperKeySelfServiceStrategyConfig + ".password.enabled": true, + config.ViperKeySelfServiceStrategyConfig + ".code.enabled": false, }, expect: []string{"password"}, }, { name: "oidc and password", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".oidc.enabled", true) - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", true) - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false) + config: map[string]any{ + config.ViperKeySelfServiceStrategyConfig + ".oidc.enabled": true, + config.ViperKeySelfServiceStrategyConfig + ".password.enabled": true, + config.ViperKeySelfServiceStrategyConfig + ".code.enabled": false, }, expect: []string{"password", "oidc"}, }, { name: "oidc, password and totp", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".oidc.enabled", true) - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", true) - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".totp.enabled", true) - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false) + config: map[string]any{ + config.ViperKeySelfServiceStrategyConfig + ".oidc.enabled": true, + config.ViperKeySelfServiceStrategyConfig + ".password.enabled": true, + config.ViperKeySelfServiceStrategyConfig + ".totp.enabled": true, + config.ViperKeySelfServiceStrategyConfig + ".code.enabled": false, }, expect: []string{"password", "oidc", "totp"}, }, { name: "password and code", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", true) - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", true) + config: map[string]any{ + config.ViperKeySelfServiceStrategyConfig + ".password.enabled": true, + config.ViperKeySelfServiceStrategyConfig + ".code.enabled": true, }, expect: []string{"password", "code"}, }, { name: "code is enabled if passwordless_enabled is true", - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", false) - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false) - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.passwordless_enabled", true) + config: map[string]any{ + config.ViperKeySelfServiceStrategyConfig + ".password.enabled": false, + config.ViperKeySelfServiceStrategyConfig + ".code.enabled": false, + config.ViperKeySelfServiceStrategyConfig + ".code.passwordless_enabled": true, }, expect: []string{"code"}, }, } { t.Run(fmt.Sprintf("run=%s", tc.name), func(t *testing.T) { - conf, reg := internal.NewVeryFastRegistryWithoutDB(t) - tc.prep(conf) + t.Parallel() - s := reg.LoginStrategies(context.Background()) + ctx := config.WithConfigValues(ctx, tc.config) + s := reg.LoginStrategies(ctx) require.Len(t, s, len(tc.expect)) for k, e := range tc.expect { assert.Equal(t, e, s[k].ID().String()) @@ -762,27 +771,28 @@ func TestDriverDefault_Strategies(t *testing.T) { t.Run("case=recovery", func(t *testing.T) { t.Parallel() for k, tc := range []struct { - prep func(conf *config.Config) + config map[string]any expect []string }{ { - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false) - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".link.enabled", false) + config: map[string]any{ + config.ViperKeySelfServiceStrategyConfig + ".code.enabled": false, + config.ViperKeySelfServiceStrategyConfig + ".link.enabled": false, }, }, { - prep: func(conf *config.Config) { - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", true) - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".link.enabled", true) + config: map[string]any{ + config.ViperKeySelfServiceStrategyConfig + ".code.enabled": true, + config.ViperKeySelfServiceStrategyConfig + ".link.enabled": true, }, expect: []string{"code", "link"}, }, } { t.Run(fmt.Sprintf("run=%d", k), func(t *testing.T) { - conf, reg := internal.NewVeryFastRegistryWithoutDB(t) - tc.prep(conf) + t.Parallel() + + ctx := config.WithConfigValues(ctx, tc.config) - s := reg.RecoveryStrategies(context.Background()) + s := reg.RecoveryStrategies(ctx) require.Len(t, s, len(tc.expect)) for k, e := range tc.expect { assert.Equal(t, e, s[k].RecoveryStrategyID()) @@ -796,81 +806,55 @@ func TestDriverDefault_Strategies(t *testing.T) { l := logrusx.New("", "") for k, tc := range []struct { - prep func(t *testing.T) *config.Config - expect []string + configOptions []configx.OptionModifier + expect []string }{ { - prep: func(t *testing.T) *config.Config { - c := config.MustNew(t, l, - os.Stderr, - configx.WithValues(map[string]interface{}{ - config.ViperKeyDSN: config.DefaultSQLiteMemoryDSN, - config.ViperKeySelfServiceStrategyConfig + ".password.enabled": false, - config.ViperKeySelfServiceStrategyConfig + ".oidc.enabled": false, - config.ViperKeySelfServiceStrategyConfig + ".profile.enabled": false, - }), - configx.SkipValidation()) - return c - }, - }, - { - prep: func(t *testing.T) *config.Config { - c := config.MustNew(t, l, - os.Stderr, - configx.WithValues(map[string]interface{}{ - config.ViperKeyDSN: config.DefaultSQLiteMemoryDSN, - config.ViperKeySelfServiceStrategyConfig + ".profile.enabled": true, - config.ViperKeySelfServiceStrategyConfig + ".password.enabled": false, - }), - configx.SkipValidation()) - return c - }, + configOptions: []configx.OptionModifier{configx.WithValues(map[string]any{ + config.ViperKeyDSN: config.DefaultSQLiteMemoryDSN, + config.ViperKeySelfServiceStrategyConfig + ".password.enabled": false, + config.ViperKeySelfServiceStrategyConfig + ".oidc.enabled": false, + config.ViperKeySelfServiceStrategyConfig + ".profile.enabled": false, + })}, + }, + { + configOptions: []configx.OptionModifier{configx.WithValues(map[string]any{ + config.ViperKeyDSN: config.DefaultSQLiteMemoryDSN, + config.ViperKeySelfServiceStrategyConfig + ".profile.enabled": true, + config.ViperKeySelfServiceStrategyConfig + ".password.enabled": false, + })}, expect: []string{"profile"}, }, { - prep: func(t *testing.T) *config.Config { - c := config.MustNew(t, l, - os.Stderr, - configx.WithValues(map[string]interface{}{ - config.ViperKeyDSN: config.DefaultSQLiteMemoryDSN, - config.ViperKeySelfServiceStrategyConfig + ".profile.enabled": true, - config.ViperKeySelfServiceStrategyConfig + ".password.enabled": false, - config.ViperKeySelfServiceStrategyConfig + ".totp.enabled": true, - }), - configx.SkipValidation()) - return c - }, + configOptions: []configx.OptionModifier{configx.WithValues(map[string]any{ + config.ViperKeyDSN: config.DefaultSQLiteMemoryDSN, + config.ViperKeySelfServiceStrategyConfig + ".profile.enabled": true, + config.ViperKeySelfServiceStrategyConfig + ".password.enabled": false, + config.ViperKeySelfServiceStrategyConfig + ".totp.enabled": true, + })}, expect: []string{"profile", "totp"}, }, { - prep: func(t *testing.T) *config.Config { - return config.MustNew(t, l, - os.Stderr, - configx.WithValues(map[string]interface{}{ - config.ViperKeyDSN: config.DefaultSQLiteMemoryDSN, - }), - configx.SkipValidation()) - }, + configOptions: []configx.OptionModifier{configx.WithValues(map[string]any{ + config.ViperKeyDSN: config.DefaultSQLiteMemoryDSN, + })}, expect: []string{"password", "profile"}, }, { - prep: func(t *testing.T) *config.Config { - return config.MustNew(t, l, - os.Stderr, - configx.WithConfigFiles("../test/e2e/profiles/verification/.kratos.yml"), - configx.WithValue(config.ViperKeyDSN, config.DefaultSQLiteMemoryDSN), - configx.SkipValidation()) + configOptions: []configx.OptionModifier{ + configx.WithConfigFiles("../test/e2e/profiles/verification/.kratos.yml"), + configx.WithValue(config.ViperKeyDSN, config.DefaultSQLiteMemoryDSN), }, expect: []string{"password", "profile"}, }, } { t.Run(fmt.Sprintf("run=%d", k), func(t *testing.T) { - conf := tc.prep(t) + conf := config.MustNew(t, l, os.Stderr, &contextx.Default{}, append(tc.configOptions, configx.SkipValidation())...) - reg, err := driver.NewRegistryFromDSN(ctx, conf, logrusx.New("", "")) + reg, err := driver.NewRegistryFromDSN(ctx, conf, l) require.NoError(t, err) - s := reg.SettingsStrategies(context.Background()) + s := reg.SettingsStrategies(ctx) require.Len(t, s, len(tc.expect)) for k, e := range tc.expect { @@ -924,12 +908,16 @@ func TestDefaultRegistry_AllStrategies(t *testing.T) { func TestGetActiveRecoveryStrategy(t *testing.T) { t.Parallel() - conf, reg := internal.NewVeryFastRegistryWithoutDB(t) + ctx := context.Background() + _, reg := internal.NewVeryFastRegistryWithoutDB(t) + t.Run("returns error if active strategy is disabled", func(t *testing.T) { - conf.Set(context.Background(), "selfservice.methods.code.enabled", false) - conf.Set(context.Background(), config.ViperKeySelfServiceRecoveryUse, "code") + ctx := config.WithConfigValues(ctx, map[string]any{ + "selfservice.methods.code.enabled": false, + config.ViperKeySelfServiceRecoveryUse: "code", + }) - _, err := reg.GetActiveRecoveryStrategy(context.Background()) + _, err := reg.GetActiveRecoveryStrategy(ctx) require.Error(t, err) }) @@ -938,10 +926,12 @@ func TestGetActiveRecoveryStrategy(t *testing.T) { "code", "link", } { t.Run(fmt.Sprintf("strategy=%s", sID), func(t *testing.T) { - conf.Set(context.Background(), fmt.Sprintf("selfservice.methods.%s.enabled", sID), true) - conf.Set(context.Background(), config.ViperKeySelfServiceRecoveryUse, sID) + ctx := config.WithConfigValues(ctx, map[string]any{ + fmt.Sprintf("selfservice.methods.%s.enabled", sID): true, + config.ViperKeySelfServiceRecoveryUse: sID, + }) - s, err := reg.GetActiveRecoveryStrategy(context.Background()) + s, err := reg.GetActiveRecoveryStrategy(ctx) require.NoError(t, err) require.Equal(t, sID, s.RecoveryStrategyID()) }) @@ -951,12 +941,14 @@ func TestGetActiveRecoveryStrategy(t *testing.T) { func TestGetActiveVerificationStrategy(t *testing.T) { t.Parallel() - conf, reg := internal.NewVeryFastRegistryWithoutDB(t) + ctx := context.Background() + _, reg := internal.NewVeryFastRegistryWithoutDB(t) t.Run("returns error if active strategy is disabled", func(t *testing.T) { - conf.Set(context.Background(), "selfservice.methods.code.enabled", false) - conf.Set(context.Background(), config.ViperKeySelfServiceVerificationUse, "code") - - _, err := reg.GetActiveVerificationStrategy(context.Background()) + ctx := config.WithConfigValues(ctx, map[string]any{ + "selfservice.methods.code.enabled": false, + config.ViperKeySelfServiceVerificationUse: "code", + }) + _, err := reg.GetActiveVerificationStrategy(ctx) require.Error(t, err) }) @@ -965,10 +957,12 @@ func TestGetActiveVerificationStrategy(t *testing.T) { "code", "link", } { t.Run(fmt.Sprintf("strategy=%s", sID), func(t *testing.T) { - conf.Set(context.Background(), fmt.Sprintf("selfservice.methods.%s.enabled", sID), true) - conf.Set(context.Background(), config.ViperKeySelfServiceVerificationUse, sID) + ctx := config.WithConfigValues(ctx, map[string]any{ + fmt.Sprintf("selfservice.methods.%s.enabled", sID): true, + config.ViperKeySelfServiceVerificationUse: sID, + }) - s, err := reg.GetActiveVerificationStrategy(context.Background()) + s, err := reg.GetActiveVerificationStrategy(ctx) require.NoError(t, err) require.Equal(t, sID, s.VerificationStrategyID()) }) diff --git a/go.mod b/go.mod index 537942ebacbd..67e7a524c134 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/ory/kratos -go 1.21 +go 1.22 replace ( github.com/go-sql-driver/mysql => github.com/go-sql-driver/mysql v1.7.2-0.20231005084435-37980127edfb diff --git a/hydra/hydra_test.go b/hydra/hydra_test.go index b2e252be5b18..d022ae6021cf 100644 --- a/hydra/hydra_test.go +++ b/hydra/hydra_test.go @@ -13,6 +13,7 @@ import ( "github.com/ory/kratos/driver/config" "github.com/ory/kratos/hydra" "github.com/ory/x/configx" + "github.com/ory/x/contextx" "github.com/ory/x/logrusx" "github.com/ory/x/sqlxx" "github.com/ory/x/urlx" @@ -25,11 +26,12 @@ func requestFromChallenge(s string) *http.Request { func TestGetLoginChallengeID(t *testing.T) { uuidChallenge := "b346a452-e8fb-4828-8ef8-a4dbc98dc23a" blobChallenge := "1337deadbeefcafe" - defaultConfig := config.MustNew(t, logrusx.New("", ""), os.Stderr, configx.SkipValidation()) + defaultConfig := config.MustNew(t, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.SkipValidation()) configWithHydra := config.MustNew( t, logrusx.New("", ""), os.Stderr, + &contextx.Default{}, configx.SkipValidation(), configx.WithValues(map[string]interface{}{ config.ViperKeyOAuth2ProviderURL: "https://hydra", diff --git a/internal/driver.go b/internal/driver.go index a6f1f13d7954..3499a83b5b9b 100644 --- a/internal/driver.go +++ b/internal/driver.go @@ -9,9 +9,10 @@ import ( "runtime" "testing" + "github.com/ory/x/contextx" + "github.com/sirupsen/logrus" - "github.com/ory/x/contextx" "github.com/ory/x/jsonnetsecure" "github.com/gofrs/uuid" @@ -36,9 +37,8 @@ func init() { }) } -func NewConfigurationWithDefaults(t testing.TB) *config.Config { - c := config.MustNew(t, logrusx.New("", ""), - os.Stderr, +func NewConfigurationWithDefaults(t testing.TB, opts ...configx.OptionModifier) *config.Config { + configOpts := append([]configx.OptionModifier{ configx.WithValues(map[string]interface{}{ "log.level": "error", config.ViperKeyDSN: dbal.NewSQLiteTestDatabase(t), @@ -53,14 +53,19 @@ func NewConfigurationWithDefaults(t testing.TB) *config.Config { config.ViperKeySecretsCipher: []string{"secret-thirty-two-character-long"}, }), configx.SkipValidation(), + }, opts...) + c := config.MustNew(t, logrusx.New("", ""), + os.Stderr, + &config.TestConfigProvider{Contextualizer: &contextx.Default{}, Options: configOpts}, + configOpts..., ) return c } // NewFastRegistryWithMocks returns a registry with several mocks and an SQLite in memory database that make testing // easier and way faster. This suite does not work for e2e or advanced integration tests. -func NewFastRegistryWithMocks(t *testing.T) (*config.Config, *driver.RegistryDefault) { - conf, reg := NewRegistryDefaultWithDSN(t, "") +func NewFastRegistryWithMocks(t *testing.T, opts ...configx.OptionModifier) (*config.Config, *driver.RegistryDefault) { + conf, reg := NewRegistryDefaultWithDSN(t, "", opts...) reg.WithCSRFTokenGenerator(x.FakeCSRFTokenGenerator) reg.WithCSRFHandler(x.NewFakeCSRFHandler("")) reg.WithHooks(map[string]func(config.SelfServiceHook) interface{}{ @@ -76,16 +81,17 @@ func NewFastRegistryWithMocks(t *testing.T) (*config.Config, *driver.RegistryDef } // NewRegistryDefaultWithDSN returns a more standard registry without mocks. Good for e2e and advanced integration testing! -func NewRegistryDefaultWithDSN(t testing.TB, dsn string) (*config.Config, *driver.RegistryDefault) { +func NewRegistryDefaultWithDSN(t testing.TB, dsn string, opts ...configx.OptionModifier) (*config.Config, *driver.RegistryDefault) { ctx := context.Background() - c := NewConfigurationWithDefaults(t) - c.MustSet(ctx, config.ViperKeyDSN, stringsx.Coalesce(dsn, dbal.NewSQLiteTestDatabase(t))) + c := NewConfigurationWithDefaults(t, append(opts, configx.WithValues(map[string]interface{}{ + config.ViperKeyDSN: stringsx.Coalesce(dsn, dbal.NewSQLiteTestDatabase(t)), + "dev": true, + }))...) reg, err := driver.NewRegistryFromDSN(ctx, c, logrusx.New("", "", logrusx.ForceLevel(logrus.ErrorLevel))) require.NoError(t, err) - reg.Config().MustSet(ctx, "dev", true) pool := jsonnetsecure.NewProcessPool(runtime.GOMAXPROCS(0)) t.Cleanup(pool.Close) - require.NoError(t, reg.Init(context.Background(), &contextx.Default{}, driver.SkipNetworkInit, driver.WithDisabledMigrationLogging(), driver.WithJsonnetPool(pool))) + require.NoError(t, reg.Init(context.Background(), &config.TestConfigProvider{Contextualizer: &contextx.Default{}}, driver.SkipNetworkInit, driver.WithDisabledMigrationLogging(), driver.WithJsonnetPool(pool))) require.NoError(t, reg.Persister().MigrateUp(context.Background())) // always migrate up actual, err := reg.Persister().DetermineNetwork(context.Background()) diff --git a/internal/testhelpers/config.go b/internal/testhelpers/config.go index 2a24709c0745..8e17a6ab3a12 100644 --- a/internal/testhelpers/config.go +++ b/internal/testhelpers/config.go @@ -8,11 +8,10 @@ import ( "encoding/base64" "testing" - "github.com/ory/kratos/driver/config" - "github.com/spf13/pflag" "github.com/stretchr/testify/require" + "github.com/ory/kratos/driver/config" "github.com/ory/x/configx" "github.com/ory/x/randx" ) @@ -24,6 +23,20 @@ func UseConfigFile(t *testing.T, path string) *pflag.FlagSet { return flags } +func DefaultIdentitySchemaConfig(url string) map[string]any { + return map[string]any{ + config.ViperKeyDefaultIdentitySchemaID: "default", + config.ViperKeyIdentitySchemas: config.Schemas{ + {ID: "default", URL: url}, + }, + } +} + +func WithDefaultIdentitySchema(ctx context.Context, url string) context.Context { + return config.WithConfigValues(ctx, DefaultIdentitySchemaConfig(url)) +} + +// Deprecated: Use context-based WithDefaultIdentitySchema instead func SetDefaultIdentitySchema(conf *config.Config, url string) func() { schemaUrl, _ := conf.DefaultIdentityTraitsSchemaURL(context.Background()) conf.MustSet(context.Background(), config.ViperKeyDefaultIdentitySchemaID, "default") @@ -37,13 +50,29 @@ func SetDefaultIdentitySchema(conf *config.Config, url string) func() { } } -// UseIdentitySchema registeres an identity schema in the config with a random ID and returns the ID +// WithAddIdentitySchema registers an identity schema in the config with a random ID and returns the ID +// +// It also registers a test cleanup function, to reset the schemas to the original values, after the test finishes +func WithAddIdentitySchema(ctx context.Context, t *testing.T, conf *config.Config, url string) (context.Context, string) { + id := randx.MustString(16, randx.Alpha) + schemas, err := conf.IdentityTraitsSchemas(ctx) + require.NoError(t, err) + + return config.WithConfigValue(ctx, config.ViperKeyIdentitySchemas, append(schemas, config.Schema{ + ID: id, + URL: url, + })), id +} + +// UseIdentitySchema registers an identity schema in the config with a random ID and returns the ID // -// It also registeres a test cleanup function, to reset the schemas to the original values, after the test finishes +// It also registers a test cleanup function, to reset the schemas to the original values, after the test finishes +// Deprecated: Use context-based WithAddIdentitySchema instead func UseIdentitySchema(t *testing.T, conf *config.Config, url string) (id string) { id = randx.MustString(16, randx.Alpha) schemas, err := conf.IdentityTraitsSchemas(context.Background()) require.NoError(t, err) + conf.MustSet(context.Background(), config.ViperKeyIdentitySchemas, append(schemas, config.Schema{ ID: id, URL: url, @@ -54,7 +83,12 @@ func UseIdentitySchema(t *testing.T, conf *config.Config, url string) (id string return id } -// SetDefaultIdentitySchemaFromRaw allows setting the default identity schema from a raw JSON string. +// WithDefaultIdentitySchemaFromRaw allows setting the default identity schema from a raw JSON string. +func WithDefaultIdentitySchemaFromRaw(ctx context.Context, schema []byte) context.Context { + return WithDefaultIdentitySchema(ctx, "base64://"+base64.URLEncoding.EncodeToString(schema)) +} + +// Deprecated: Use context-based WithDefaultIdentitySchemaFromRaw instead func SetDefaultIdentitySchemaFromRaw(conf *config.Config, schema []byte) { conf.MustSet(context.Background(), config.ViperKeyDefaultIdentitySchemaID, "default") conf.MustSet(context.Background(), config.ViperKeyIdentitySchemas, config.Schemas{ diff --git a/internal/testhelpers/network.go b/internal/testhelpers/network.go index 888f46b583f5..10978dba6b05 100644 --- a/internal/testhelpers/network.go +++ b/internal/testhelpers/network.go @@ -20,7 +20,7 @@ func NewNetworkUnlessExisting(t *testing.T, ctx context.Context, p persistence.P } n := networkx.NewNetwork() - require.NoError(t, p.GetConnection(context.Background()).Create(n)) + require.NoError(t, p.GetConnection(ctx).Create(n)) return n.ID, p.WithNetworkID(n.ID) } diff --git a/persistence/sql/persister_hmac_test.go b/persistence/sql/persister_hmac_test.go index c7adcdce3a1e..7b8cc8575368 100644 --- a/persistence/sql/persister_hmac_test.go +++ b/persistence/sql/persister_hmac_test.go @@ -64,7 +64,7 @@ var _ persisterDependencies = &logRegistryOnly{} func TestPersisterHMAC(t *testing.T) { ctx := context.Background() - conf := config.MustNew(t, logrusx.New("", ""), os.Stderr, configx.SkipValidation()) + conf := config.MustNew(t, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.SkipValidation()) conf.MustSet(ctx, config.ViperKeySecretsDefault, []string{"foobarbaz"}) c, err := pop.NewConnection(&pop.ConnectionDetails{URL: "sqlite://foo?mode=memory"}) require.NoError(t, err) diff --git a/session/test/persistence.go b/session/test/persistence.go index 8e8cbfeb18b2..0db6964468d8 100644 --- a/session/test/persistence.go +++ b/session/test/persistence.go @@ -42,7 +42,7 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface { return func(t *testing.T) { _, p := testhelpers.NewNetworkUnlessExisting(t, ctx, p) - testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/identity.schema.json") + ctx := testhelpers.WithDefaultIdentitySchema(ctx, "file://./stub/identity.schema.json") t.Run("case=not found", func(t *testing.T) { _, err := p.GetSession(ctx, x.NewUUID(), session.ExpandNothing) @@ -611,10 +611,7 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface { }) t.Run("extend session lifespan but min time is not yet reached", func(t *testing.T) { - conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour*2) - t.Cleanup(func() { - conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, nil) - }) + ctx := config.WithConfigValues(ctx, map[string]any{config.ViperKeySessionRefreshMinTimeLeft: 2 * time.Hour}) var expected session.Session require.NoError(t, faker.FakeData(&expected)) @@ -629,23 +626,19 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface { }) t.Run("extend session lifespan", func(t *testing.T) { - conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour) - t.Cleanup(func() { - conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, nil) - }) + ctx := config.WithConfigValues(ctx, map[string]any{config.ViperKeySessionRefreshMinTimeLeft: 2 * time.Hour}) - conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour*2) var expected session.Session require.NoError(t, faker.FakeData(&expected)) expected.ExpiresAt = time.Now().Add(time.Hour).UTC() require.NoError(t, p.CreateIdentity(ctx, expected.Identity)) require.NoError(t, p.UpsertSession(ctx, &expected)) - expectedExpiry := expected.Refresh(ctx, conf).ExpiresAt.Round(time.Minute) + expectedExpiry := expected.Refresh(ctx, conf).ExpiresAt require.NoError(t, p.ExtendSession(ctx, expected.ID)) actual, err := p.GetSession(ctx, expected.ID, session.ExpandNothing) require.NoError(t, err) - assert.Equal(t, expectedExpiry, actual.ExpiresAt.Round(time.Minute)) + assert.GreaterOrEqual(t, 10*time.Second, expectedExpiry.Sub(actual.ExpiresAt).Abs()) }) t.Run("extend session lifespan on CockroachDB", func(t *testing.T) { @@ -653,23 +646,19 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface { t.Skip("Skipping test because driver is not CockroachDB") } - conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour) - t.Cleanup(func() { - conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, nil) - }) + ctx := config.WithConfigValue(ctx, config.ViperKeySessionRefreshMinTimeLeft, 2*time.Hour) - conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour*2) var expected session.Session require.NoError(t, faker.FakeData(&expected)) expected.ExpiresAt = time.Now().Add(time.Hour).UTC() require.NoError(t, p.CreateIdentity(ctx, expected.Identity)) require.NoError(t, p.UpsertSession(ctx, &expected)) - expectedExpiry := expected.Refresh(ctx, conf).ExpiresAt.Round(time.Minute) + expectedExpiry := expected.Refresh(ctx, conf).ExpiresAt - var foundExpectedCockroachError bool + foundExpectedCockroachError := false g := errgroup.Group{} - for i := 0; i < 10; i++ { + for range 10 { g.Go(func() error { err := p.ExtendSession(ctx, expected.ID) if errors.Is(err, sqlcon.ErrNoRows) { @@ -683,7 +672,7 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface { actual, err := p.GetSession(ctx, expected.ID, session.ExpandNothing) require.NoError(t, err) - assert.Equal(t, expectedExpiry, actual.ExpiresAt.Round(time.Minute)) + assert.LessOrEqual(t, expectedExpiry.Sub(actual.ExpiresAt).Abs(), 10*time.Second) assert.True(t, foundExpectedCockroachError, "We expect to find a not found error caused by ... FOR UPDATE SKIP LOCKED") }) } diff --git a/x/redir_test.go b/x/redir_test.go index bbb8417f8b91..1c4a191b429b 100644 --- a/x/redir_test.go +++ b/x/redir_test.go @@ -4,7 +4,6 @@ package x_test import ( - "context" "fmt" "io" "net/http" @@ -12,6 +11,8 @@ import ( "strings" "testing" + "github.com/ory/x/configx" + "github.com/julienschmidt/httprouter" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -22,17 +23,16 @@ import ( ) func TestRedirectToPublicAdminRoute(t *testing.T) { - ctx := context.Background() - conf, reg := internal.NewFastRegistryWithMocks(t) pub := x.NewRouterPublic() adm := x.NewRouterAdmin() adminTS := httptest.NewServer(adm) pubTS := httptest.NewServer(pub) t.Cleanup(pubTS.Close) t.Cleanup(adminTS.Close) - - conf.MustSet(ctx, config.ViperKeyAdminBaseURL, adminTS.URL) - conf.MustSet(ctx, config.ViperKeyPublicBaseURL, pubTS.URL) + _, reg := internal.NewFastRegistryWithMocks(t, configx.WithValues(map[string]any{ + config.ViperKeyAdminBaseURL: adminTS.URL, + config.ViperKeyPublicBaseURL: pubTS.URL, + })) pub.POST("/privileged", x.RedirectToAdminRoute(reg)) pub.POST("/admin/privileged", x.RedirectToAdminRoute(reg))