Skip to content

Commit

Permalink
use atomic.Bool in downloader
Browse files Browse the repository at this point in the history
  • Loading branch information
rusq committed Mar 12, 2023
1 parent f58c7b4 commit 7ae2895
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 27 deletions.
39 changes: 18 additions & 21 deletions downloader/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"path/filepath"
"runtime/trace"
"sync"
"sync/atomic"

"errors"

Expand Down Expand Up @@ -38,10 +39,9 @@ type Client struct {
retries int
workers int

mu sync.Mutex // mutex prevents race condition when starting/stopping
fileRequests chan fileRequest
wg *sync.WaitGroup
started bool
started atomic.Bool

nameFn FilenameFunc
}
Expand Down Expand Up @@ -148,18 +148,15 @@ type fileRequest struct {
// Start starts an async file downloader. If the downloader is already
// started, it does nothing.
func (c *Client) Start(ctx context.Context) {
c.mu.Lock()
defer c.mu.Unlock()

if c.started {
if c.started.Load() {
// already started
return
}
req := make(chan fileRequest, downloadBufSz)

c.fileRequests = req
c.wg = c.startWorkers(ctx, req)
c.started = true
c.started.Store(true)
}

// startWorkers starts download workers. It returns a sync.WaitGroup. If the
Expand Down Expand Up @@ -203,7 +200,7 @@ func (c *Client) worker(ctx context.Context, reqC <-chan fileRequest) {
c.l().Printf("error saving %q to %q: %s", c.nameFn(req.File), req.Directory, err)
break
}
c.l().Printf("file %q saved to %s: %d bytes written", c.nameFn(req.File), req.Directory, n)
c.l().Debugf("file %q saved to %s: %d bytes written", c.nameFn(req.File), req.Directory, n)
}
}
}
Expand All @@ -223,7 +220,13 @@ func (c *Client) AsyncDownloader(ctx context.Context, dir string, fileDlQueue <-
req := make(chan fileRequest)
go func() {
defer close(req)
for f := range fileDlQueue {
select {
case <-ctx.Done():
return
case f, more := <-fileDlQueue:
if !more {
return
}
req <- fileRequest{Directory: dir, File: f}
}
}()
Expand Down Expand Up @@ -254,13 +257,14 @@ func (c *Client) saveFile(ctx context.Context, dir string, sf *slack.File) (int6
tf.Close()
os.Remove(tf.Name())
}()
tfReset := func() error { _, err := tf.Seek(0, io.SeekStart); return err }

if err := network.WithRetry(ctx, c.limiter, c.retries, func() error {
region := trace.StartRegion(ctx, "GetFile")
defer region.End()

if err := c.client.GetFile(sf.URLPrivateDownload, tf); err != nil {
if _, err := tf.Seek(0, io.SeekStart); err != nil {
if err := tfReset(); err != nil {
c.l().Debugf("seek error: %s", err)
}
return fmt.Errorf("download to %q failed, [src=%s]: %w", filePath, sf.URLPrivateDownload, err)
Expand All @@ -272,7 +276,7 @@ func (c *Client) saveFile(ctx context.Context, dir string, sf *slack.File) (int6

// at this point, temporary file position would be at EOF, we need to reset
// it prior to copying.
if _, err := tf.Seek(0, io.SeekStart); err != nil {
if err := tfReset(); err != nil {
return 0, err
}

Expand All @@ -296,10 +300,7 @@ func stdFilenameFn(f *slack.File) string {

// Stop waits for all transfers to finish, and stops the downloader.
func (c *Client) Stop() {
c.mu.Lock()
defer c.mu.Unlock()

if !c.started {
if !c.started.Load() {
return
}

Expand All @@ -310,7 +311,7 @@ func (c *Client) Stop() {

c.fileRequests = nil
c.wg = nil
c.started = false
c.started.Store(false)
}

var ErrNotStarted = errors.New("downloader not started")
Expand All @@ -321,11 +322,7 @@ var ErrNotStarted = errors.New("downloader not started")
// is full, will block until it becomes empty. It returns the filepath within the
// filesystem.
func (c *Client) DownloadFile(dir string, f slack.File) (string, error) {
c.mu.Lock()
started := c.started
c.mu.Unlock()

if !started {
if !c.started.Load() {
return "", ErrNotStarted
}
c.fileRequests <- fileRequest{Directory: dir, File: &f}
Expand Down
8 changes: 4 additions & 4 deletions downloader/downloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ func TestClient_Start(t *testing.T) {
c.Start(context.Background())
defer c.Stop()

assert.True(t, c.started)
assert.True(t, c.started.Load())
assert.NotNil(t, c.wg)
assert.NotNil(t, c.fileRequests)
})
Expand All @@ -404,17 +404,17 @@ func TestClient_Stop(t *testing.T) {
t.Run("ensure stopped", func(t *testing.T) {
c := clientWithMock(t, tmpdir)
c.Start(context.Background())
assert.True(t, c.started)
assert.True(t, c.started.Load())

c.Stop()
assert.False(t, c.started)
assert.False(t, c.started.Load())
assert.Nil(t, c.fileRequests)
assert.Nil(t, c.wg)
})
t.Run("stop on stopped downloader does nothing", func(t *testing.T) {
c := clientWithMock(t, tmpdir)
c.Stop()
assert.False(t, c.started)
assert.False(t, c.started.Load())
assert.Nil(t, c.fileRequests)
assert.Nil(t, c.wg)
})
Expand Down
7 changes: 5 additions & 2 deletions internal/chunk/processor/standard.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ func (s *Standard) Files(ctx context.Context, channelID string, parent slack.Mes
// ignore files if requested
return nil
}
st, err := s.Recorder.State()
if err != nil {
return err
}
// custom file processor, because we need to donwload those files
for i := range ff {
if ff[i].Mode == "hidden_by_limit" {
Expand All @@ -57,8 +61,7 @@ func (s *Standard) Files(ctx context.Context, channelID string, parent slack.Mes
if err != nil {
return err
}
s, _ := s.Recorder.State()
s.AddFile(channelID, ff[i].ID, filename)
st.AddFile(channelID, ff[i].ID, filename)
}
return nil
}
Expand Down

0 comments on commit 7ae2895

Please sign in to comment.