Skip to content

Commit

Permalink
Refactor func type to single method interface
Browse files Browse the repository at this point in the history
- This will allow us to use some kind of stateful waiters later on.
  • Loading branch information
dnnrly committed Jun 24, 2022
1 parent 2c66588 commit ea89ab6
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 26 deletions.
30 changes: 20 additions & 10 deletions waitfor.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,45 @@ package waitfor
import (
"context"
"fmt"
"google.golang.org/grpc/credentials/insecure"
"net"
"net/http"
"time"

"google.golang.org/grpc/credentials/insecure"

"golang.org/x/sync/errgroup"
"google.golang.org/grpc"

"github.com/spf13/afero"
)

type Waiter interface {
Wait(name string, target *TargetConfig) error
}

// WaiterFunc is used to implement waiting for a specific type of target.
// The name is used in the error and target is the actual destination being tested.
type WaiterFunc func(name string, target *TargetConfig) error

func (w WaiterFunc) Wait(name string, target *TargetConfig) error {
return w(name, target)
}

type Logger func(string, ...interface{})

// NullLogger can be used in place of a real logging function
var NullLogger = func(f string, a ...interface{}) {}

// SupportedWaiters is a mapping of known protocol names to waiter implementations
var SupportedWaiters = map[string]WaiterFunc{
"http": HTTPWaiter,
"tcp": TCPWaiter,
"grpc": GRPCWaiter,
var SupportedWaiters = map[string]Waiter{
"http": WaiterFunc(HTTPWaiter),
"tcp": WaiterFunc(TCPWaiter),
"grpc": WaiterFunc(GRPCWaiter),
}

// WaitOn implements waiting for many targets, using the location of config file provided with named targets to wait until
// all of those targets are responding as expected
func WaitOn(config *Config, logger Logger, targets []string, waiters map[string]WaiterFunc) error {
func WaitOn(config *Config, logger Logger, targets []string, waiters map[string]Waiter) error {

for _, target := range targets {
if !config.GotTarget(target) {
Expand Down Expand Up @@ -80,7 +90,7 @@ func OpenConfig(configFile, defaultTimeout, defaultHTTPTimeout string, fs afero.
return config, nil
}

func waitOnTargets(logger Logger, targets map[string]TargetConfig, waiters map[string]WaiterFunc) error {
func waitOnTargets(logger Logger, targets map[string]TargetConfig, waiters map[string]Waiter) error {
var eg errgroup.Group

for name, target := range targets {
Expand Down Expand Up @@ -108,14 +118,14 @@ func waitOnTargets(logger Logger, targets map[string]TargetConfig, waiters map[s
return nil
}

func waitOnSingleTarget(name string, logger Logger, target TargetConfig, waiter WaiterFunc) error {
func waitOnSingleTarget(name string, logger Logger, target TargetConfig, waiter Waiter) error {
end := time.Now().Add(target.Timeout)

err := waiter(name, &target)
err := waiter.Wait(name, &target)
for err != nil && end.After(time.Now()) {
logger("error while waiting for %s: %v", name, err)
time.Sleep(time.Second)
err = waiter(name, &target)
err = waiter.Wait(name, &target)
}

if err != nil {
Expand Down
33 changes: 17 additions & 16 deletions waitfor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ package waitfor
import (
"errors"
"fmt"
"github.com/phayes/freeport"
"google.golang.org/grpc"
"net"
"testing"
"time"

"github.com/phayes/freeport"
"google.golang.org/grpc"

"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -80,12 +81,12 @@ func TestOpenConfig_defaultHTTPTimeoutCanBeSet(t *testing.T) {
}

func TestWaitOn_errorsInvalidTarget(t *testing.T) {
err := WaitOn(NewConfig(), NullLogger, []string{"localhost"}, map[string]WaiterFunc{})
err := WaitOn(NewConfig(), NullLogger, []string{"localhost"}, map[string]Waiter{})
assert.Error(t, err)
}

func TestRun_errorsOnParseFailure(t *testing.T) {
err := WaitOn(NewConfig(), NullLogger, []string{"http://localhost"}, map[string]WaiterFunc{})
err := WaitOn(NewConfig(), NullLogger, []string{"http://localhost"}, map[string]Waiter{})
assert.Error(t, err)
}

Expand All @@ -97,7 +98,7 @@ func TestWaitOnSingleTarget_succeedsImmediately(t *testing.T) {
"name",
doLog,
TargetConfig{Timeout: time.Second * 2},
func(name string, target *TargetConfig) error { return nil },
WaiterFunc(func(name string, target *TargetConfig) error { return nil }),
)

assert.NoError(t, err)
Expand All @@ -115,12 +116,12 @@ func TestWaitOnSingleTarget_succeedsAfterWaiting(t *testing.T) {
"name",
doLog,
TargetConfig{Timeout: time.Second * 2},
func(name string, target *TargetConfig) error {
WaiterFunc(func(name string, target *TargetConfig) error {
if waitUntil.After(time.Now()) {
return fmt.Errorf("there was an error")
}
return nil
},
}),
)

assert.NoError(t, err)
Expand All @@ -136,9 +137,9 @@ func TestWaitOnSingleTarget_failsIfTimerExpires(t *testing.T) {
"name",
doLog,
TargetConfig{Timeout: time.Second * 2},
func(name string, target *TargetConfig) error {
WaiterFunc(func(name string, target *TargetConfig) error {
return fmt.Errorf("")
},
}),
)

assert.Error(t, err)
Expand All @@ -149,7 +150,7 @@ func TestWaitOnTargets_failsForUnknownType(t *testing.T) {
err := waitOnTargets(
NullLogger,
map[string]TargetConfig{"unkown": {Type: "unknown type"}},
map[string]WaiterFunc{"type": func(string, *TargetConfig) error { return errors.New("") }},
map[string]Waiter{"type": WaiterFunc(func(string, *TargetConfig) error { return errors.New("") })},
)

require.Error(t, err)
Expand All @@ -162,9 +163,9 @@ func TestWaitOnTargets_selectsCorrectWaiter(t *testing.T) {
map[string]TargetConfig{
"type 1": {Type: "t1"},
},
map[string]WaiterFunc{
"t1": func(string, *TargetConfig) error { return nil },
"t2": func(string, *TargetConfig) error { return errors.New("an error") },
map[string]Waiter{
"t1": WaiterFunc(func(string, *TargetConfig) error { return nil }),
"t2": WaiterFunc(func(string, *TargetConfig) error { return errors.New("an error") }),
},
)

Expand All @@ -178,9 +179,9 @@ func TestWaitOnTargets_failsWhenWaiterFails(t *testing.T) {
"type 1": {Type: "t1"},
"type 2": {Type: "t2"},
},
map[string]WaiterFunc{
"t1": func(string, *TargetConfig) error { return nil },
"t2": func(string, *TargetConfig) error { return errors.New("an error") },
map[string]Waiter{
"t1": WaiterFunc(func(string, *TargetConfig) error { return nil }),
"t2": WaiterFunc(func(string, *TargetConfig) error { return errors.New("an error") }),
},
)

Expand Down

0 comments on commit ea89ab6

Please sign in to comment.