Skip to content

Commit

Permalink
backport logger and network changes
Browse files Browse the repository at this point in the history
  • Loading branch information
rusq committed Apr 7, 2024
1 parent 5e57b38 commit 3cb1506
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 77 deletions.
2 changes: 0 additions & 2 deletions export/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"golang.org/x/sync/errgroup"

"github.com/rusq/slackdump/v2"
"github.com/rusq/slackdump/v2/internal/network"
"github.com/rusq/slackdump/v2/internal/structures"
"github.com/rusq/slackdump/v2/internal/structures/files/dl"
"github.com/rusq/slackdump/v2/logger"
Expand All @@ -38,7 +37,6 @@ func New(sd *slackdump.Session, fs fsadapter.FS, cfg Options) *Export {
if cfg.Logger == nil {
cfg.Logger = logger.Default
}
network.SetLogger(cfg.Logger)

se := &Export{
fs: fs,
Expand Down
50 changes: 30 additions & 20 deletions internal/network/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ var (
// The wait time for a transient error depends on the current retry
// attempt number and is calculated as: (attempt+2)^3 seconds, capped at
// maxAllowedWaitTime.
maxAllowedWaitTime = 5 * time.Minute
lg logger.Interface = logger.Default
maxAllowedWaitTime = 5 * time.Minute

// waitFn returns the amount of time to wait before retrying depending on
// the current attempt. This variable exists to reduce the test time.
waitFn = cubicWait
Expand All @@ -38,17 +38,36 @@ var (

// ErrRetryFailed is returned if number of retry attempts exceeded the retry attempts limit and
// function wasn't able to complete without errors.
var ErrRetryFailed = errors.New("callback was unable to complete without errors within the allowed number of retries")
type ErrRetryFailed struct {
Err error
}

func (e *ErrRetryFailed) Error() string {
return fmt.Sprintf("callback was unable to complete without errors within the allowed number of retries: %s", e.Err)
}

func (e *ErrRetryFailed) Unwrap() error {
return e.Err
}

func (e *ErrRetryFailed) Is(target error) bool {
_, ok := target.(*ErrRetryFailed)
return ok
}

// WithRetry will run the callback function fn. If the function returns
// slack.RateLimitedError, it will delay, and then call it again up to
// maxAttempts times. It will return an error if it runs out of attempts.
func WithRetry(ctx context.Context, lim *rate.Limiter, maxAttempts int, fn func() error) error {
tracelogf(ctx, "info", "maxAttempts=%d", maxAttempts)
var ok bool
if maxAttempts == 0 {
maxAttempts = defNumAttempts
}

var lastErr error
for attempt := 0; attempt < maxAttempts; attempt++ {
// calling wait to ensure that we don't exceed the rate limit
var err error
trace.WithRegion(ctx, "WithRetry.wait", func() {
err = lim.Wait(ctx)
Expand All @@ -59,9 +78,11 @@ func WithRetry(ctx context.Context, lim *rate.Limiter, maxAttempts int, fn func(

cbErr := fn()
if cbErr == nil {
tracelogf(ctx, "info", "success")
ok = true
break
}
lastErr = cbErr

tracelogf(ctx, "error", "WithRetry: %[1]s (%[1]T) after %[2]d attempts", cbErr, attempt+1)
var (
Expand All @@ -71,22 +92,22 @@ func WithRetry(ctx context.Context, lim *rate.Limiter, maxAttempts int, fn func(
)
switch {
case errors.As(cbErr, &rle):
tracelogf(ctx, "info", "got rate limited, sleeping %s", rle.RetryAfter)
tracelogf(ctx, "info", "got rate limited, sleeping %s (%s)", rle.RetryAfter, cbErr)
time.Sleep(rle.RetryAfter)
continue
case errors.As(cbErr, &sce):
if isRecoverable(sce.Code) {
// possibly transient error
delay := waitFn(attempt)
tracelogf(ctx, "info", "got server error %d, sleeping %s", sce.Code, delay)
tracelogf(ctx, "info", "got server error %d, sleeping %s (%s)", sce.Code, delay, cbErr)
time.Sleep(delay)
continue
}
case errors.As(cbErr, &ne):
if ne.Op == "read" || ne.Op == "write" {
// possibly transient error
delay := netWaitFn(attempt)
tracelogf(ctx, "info", "got network error %s, sleeping %s", ne.Op, delay)
tracelogf(ctx, "info", "got network error %s on %q, sleeping %s", cbErr, ne.Op, delay)
time.Sleep(delay)
continue
}
Expand All @@ -95,7 +116,7 @@ func WithRetry(ctx context.Context, lim *rate.Limiter, maxAttempts int, fn func(
return fmt.Errorf("callback error: %w", cbErr)
}
if !ok {
return ErrRetryFailed
return &ErrRetryFailed{Err: lastErr}
}
return nil
}
Expand All @@ -109,7 +130,7 @@ func isRecoverable(statusCode int) bool {
// where x is the current attempt number. The maximum wait time is capped at 5
// minutes.
func cubicWait(attempt int) time.Duration {
x := attempt + 2 // this is to ensure that we sleep at least 8 seconds.
x := attempt + 1 // this is to ensure that we sleep at least a second.
delay := time.Duration(x*x*x) * time.Second
if delay > maxAllowedWaitTime {
return maxAllowedWaitTime
Expand All @@ -128,22 +149,11 @@ func expWait(attempt int) time.Duration {
func tracelogf(ctx context.Context, category string, fmt string, a ...any) {
mu.RLock()
defer mu.RUnlock()

lg := logger.FromContext(ctx)
trace.Logf(ctx, category, fmt, a...)
lg.Debugf(fmt, a...)
}

// SetLogger sets the package logger.
func SetLogger(l logger.Interface) {
mu.Lock()
defer mu.Unlock()
if l == nil {
l = logger.Default
return
}
lg = l
}

// SetMaxAllowedWaitTime sets the maximum time to wait for a transient error.
func SetMaxAllowedWaitTime(d time.Duration) {
mu.Lock()
Expand Down
112 changes: 63 additions & 49 deletions internal/network/network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"github.com/slack-go/slack"
"github.com/stretchr/testify/assert"
"golang.org/x/time/rate"
)

Expand Down Expand Up @@ -121,7 +122,7 @@ func Test_withRetry(t *testing.T) {
true,
calcRunDuration(testRateLimit, 2),
},
{"rate limiter test 4 limited attempts, 100 ms each",
{"rate limiter test 4 lmited attempts, 100 ms each",

Check failure on line 125 in internal/network/network_test.go

View workflow job for this annotation

GitHub Actions / Check for spelling errors

lmited ==> limited
args{
context.Background(),
rate.NewLimiter(10.0, 1),
Expand Down Expand Up @@ -177,36 +178,70 @@ func Test_withRetry(t *testing.T) {
}
})
}
}
t.Run("500 error handling", func(t *testing.T) {
waitFn = func(attempt int) time.Duration { return 50 * time.Millisecond }
defer func() {
waitFn = cubicWait
}()

var codes = []int{500, 502, 503, 504, 598}
for _, code := range codes {
var thisCode = code
// This test is to ensure that we handle 500 errors correctly.
t.Run(fmt.Sprintf("%d error", code), func(t *testing.T) {

const (
testRetryCount = 1
waitThreshold = 100 * time.Millisecond
)

// Create a test server that returns a 500 error.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(thisCode)
}))
defer ts.Close()

// Create a new client with the test server as the endpoint.
client := slack.New("token", slack.OptionAPIURL(ts.URL+"/"))

start := time.Now()
// Call the client with a retry.
err := WithRetry(context.Background(), rate.NewLimiter(1, 1), testRetryCount, func() error {
_, err := client.GetConversationHistory(&slack.GetConversationHistoryParameters{})
if err == nil {
return errors.New("expected error, got nil")
}
return err
})
if err == nil {
t.Fatal("expected error, got nil")
}

func Test500ErrorHandling(t *testing.T) {
waitFn = func(attempt int) time.Duration { return 50 * time.Millisecond }
defer func() {
waitFn = cubicWait
}()
dur := time.Since(start)
if dur < waitFn(testRetryCount-1)-waitThreshold || waitFn(testRetryCount-1)+waitThreshold < dur {
t.Errorf("expected duration to be around %s, got %s", waitFn(testRetryCount), dur)
}

var codes = []int{500, 502, 503, 504, 598}
for _, code := range codes {
var thisCode = code
// This test is to ensure that we handle 500 errors correctly.
t.Run(fmt.Sprintf("%d error", code), func(t *testing.T) {
})
}
t.Run("404 error", func(t *testing.T) {
t.Parallel()

const (
testRetryCount = 1
waitThreshold = 100 * time.Millisecond
)

// Create a test server that returns a 500 error.
// Create a test server that returns a 404 error.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(thisCode)
w.WriteHeader(404)
}))
defer ts.Close()

// Create a new client with the test server as the endpoint.
client := slack.New("token", slack.OptionAPIURL(ts.URL+"/"))

start := time.Now()
// Call the client with a retry.
start := time.Now()
err := WithRetry(context.Background(), rate.NewLimiter(1, 1), testRetryCount, func() error {
_, err := client.GetConversationHistory(&slack.GetConversationHistoryParameters{})
if err == nil {
Expand All @@ -217,46 +252,25 @@ func Test500ErrorHandling(t *testing.T) {
if err == nil {
t.Fatal("expected error, got nil")
}

dur := time.Since(start)
if dur < waitFn(testRetryCount-1)-waitThreshold || waitFn(testRetryCount-1)+waitThreshold < dur {
t.Errorf("expected duration to be around %s, got %s", waitFn(testRetryCount), dur)
if dur > 500*time.Millisecond { // 404 error should not be retried
t.Errorf("expected no sleep, but slept for %s", dur)
}

})
}
t.Run("404 error", func(t *testing.T) {
})
t.Run("meaningful error message", func(t *testing.T) {
t.Parallel()

const (
testRetryCount = 1
)

// Create a test server that returns a 404 error.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
}))
defer ts.Close()

// Create a new client with the test server as the endpoint.
client := slack.New("token", slack.OptionAPIURL(ts.URL+"/"))

// Call the client with a retry.
start := time.Now()
err := WithRetry(context.Background(), rate.NewLimiter(1, 1), testRetryCount, func() error {
_, err := client.GetConversationHistory(&slack.GetConversationHistoryParameters{})
if err == nil {
return errors.New("expected error, got nil")
}
return err
})
var errFunc = func() error {
return slack.StatusCodeError{Code: 500, Status: "Internal Server Error"}
}
err := WithRetry(context.Background(), rate.NewLimiter(1, 1), 1, errFunc)
if err == nil {
t.Fatal("expected error, got nil")
}
dur := time.Since(start)
if dur > 500*time.Millisecond { // 404 error should not be retried
t.Errorf("expected no sleep, but slept for %s", dur)
}
assert.ErrorContains(t, err, "Internal Server Error")
assert.ErrorIs(t, err, &ErrRetryFailed{})
var sce slack.StatusCodeError
assert.ErrorAs(t, errors.Unwrap(err), &sce)
})
}

Expand Down
42 changes: 38 additions & 4 deletions logger/logger.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,57 @@
package logger

import (
"io"
"context"
"log"
"os"

"github.com/rusq/dlog"
)

// Interface is the interface for a logger.
type Interface interface {
Debug(...any)
Debugf(fmt string, a ...any)
Debugln(...any)
Print(...any)
Printf(fmt string, a ...any)
Println(...any)
IsDebug() bool
}

// Default is the default logger. It logs to stderr and debug logging can be
// enabled by setting the DEBUG environment variable to 1. For example:
//
// DEBUG=1 slackdump
var Default = dlog.New(log.Default().Writer(), "", log.LstdFlags, os.Getenv("DEBUG") == "1")

// note: previously ioutil.Discard which is not deprecated in favord of io.Discard
// so this is valid only from go1.16
var Silent = dlog.New(io.Discard, "", log.LstdFlags, false)
// Silent is a logger that does not log anything.
var Silent = silent{}

// Silent is a logger that does not log anything.
type silent struct{}

func (s silent) Debug(...any) {}
func (s silent) Debugf(fmt string, a ...any) {}
func (s silent) Debugln(...any) {}
func (s silent) Print(...any) {}
func (s silent) Printf(fmt string, a ...any) {}
func (s silent) Println(...any) {}
func (s silent) IsDebug() bool { return false }

type logCtx uint8

const (
logCtxKey logCtx = iota
)

func NewContext(ctx context.Context, l Interface) context.Context {
return context.WithValue(ctx, logCtxKey, l)
}

func FromContext(ctx context.Context) Interface {
if l, ok := ctx.Value(logCtxKey).(Interface); ok {
return l
}
return Default
}
16 changes: 16 additions & 0 deletions logger/logger_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package logger

import "testing"

func BenchmarkSlientPrintf(b *testing.B) {
var l = Silent
for i := 0; i < b.N; i++ {
l.Printf("hello world, %s, %d", "foo", i)
}
// This benchmark compares the performance of the Silent logger when
// using io.Discard, and when using a no-op function.
// io.Discard: BenchmarkSlientPrintf-16 93075956 12.92 ns/op 8 B/op 0 allocs/op
// no-op func: BenchmarkSlientPrintf-16 1000000000 0.2364 ns/op 0 B/op 0 allocs/op
//
// Oh, look! We have an WINNER. The no-op function wins, no surprises.
}
2 changes: 0 additions & 2 deletions slackdump.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,6 @@ func NewWithOptions(ctx context.Context, authProvider auth.Provider, opts Option
fs: fsadapter.NewDirectory("."), // default is to save attachments to the current directory.
}

network.SetLogger(sd.l())

if err := os.MkdirAll(opts.CacheDir, 0700); err != nil {
return nil, fmt.Errorf("failed to create the cache directory: %s", err)
}
Expand Down

0 comments on commit 3cb1506

Please sign in to comment.