From 7ae2895766682379ef475909688ca309ea1843a7 Mon Sep 17 00:00:00 2001 From: Rustam Gilyazov <16064414+rusq@users.noreply.github.com> Date: Sun, 12 Mar 2023 17:07:16 +1000 Subject: [PATCH] use atomic.Bool in downloader --- downloader/downloader.go | 39 +++++++++++++--------------- downloader/downloader_test.go | 8 +++--- internal/chunk/processor/standard.go | 7 +++-- 3 files changed, 27 insertions(+), 27 deletions(-) diff --git a/downloader/downloader.go b/downloader/downloader.go index 1a8332a5..0c4bcf5b 100644 --- a/downloader/downloader.go +++ b/downloader/downloader.go @@ -10,6 +10,7 @@ import ( "path/filepath" "runtime/trace" "sync" + "sync/atomic" "errors" @@ -38,10 +39,9 @@ type Client struct { retries int workers int - mu sync.Mutex // mutex prevents race condition when starting/stopping fileRequests chan fileRequest wg *sync.WaitGroup - started bool + started atomic.Bool nameFn FilenameFunc } @@ -148,10 +148,7 @@ type fileRequest struct { // Start starts an async file downloader. If the downloader is already // started, it does nothing. func (c *Client) Start(ctx context.Context) { - c.mu.Lock() - defer c.mu.Unlock() - - if c.started { + if c.started.Load() { // already started return } @@ -159,7 +156,7 @@ func (c *Client) Start(ctx context.Context) { c.fileRequests = req c.wg = c.startWorkers(ctx, req) - c.started = true + c.started.Store(true) } // startWorkers starts download workers. It returns a sync.WaitGroup. If the @@ -203,7 +200,7 @@ func (c *Client) worker(ctx context.Context, reqC <-chan fileRequest) { c.l().Printf("error saving %q to %q: %s", c.nameFn(req.File), req.Directory, err) break } - c.l().Printf("file %q saved to %s: %d bytes written", c.nameFn(req.File), req.Directory, n) + c.l().Debugf("file %q saved to %s: %d bytes written", c.nameFn(req.File), req.Directory, n) } } } @@ -223,7 +220,13 @@ func (c *Client) AsyncDownloader(ctx context.Context, dir string, fileDlQueue <- req := make(chan fileRequest) go func() { defer close(req) - for f := range fileDlQueue { + select { + case <-ctx.Done(): + return + case f, more := <-fileDlQueue: + if !more { + return + } req <- fileRequest{Directory: dir, File: f} } }() @@ -254,13 +257,14 @@ func (c *Client) saveFile(ctx context.Context, dir string, sf *slack.File) (int6 tf.Close() os.Remove(tf.Name()) }() + tfReset := func() error { _, err := tf.Seek(0, io.SeekStart); return err } if err := network.WithRetry(ctx, c.limiter, c.retries, func() error { region := trace.StartRegion(ctx, "GetFile") defer region.End() if err := c.client.GetFile(sf.URLPrivateDownload, tf); err != nil { - if _, err := tf.Seek(0, io.SeekStart); err != nil { + if err := tfReset(); err != nil { c.l().Debugf("seek error: %s", err) } return fmt.Errorf("download to %q failed, [src=%s]: %w", filePath, sf.URLPrivateDownload, err) @@ -272,7 +276,7 @@ func (c *Client) saveFile(ctx context.Context, dir string, sf *slack.File) (int6 // at this point, temporary file position would be at EOF, we need to reset // it prior to copying. - if _, err := tf.Seek(0, io.SeekStart); err != nil { + if err := tfReset(); err != nil { return 0, err } @@ -296,10 +300,7 @@ func stdFilenameFn(f *slack.File) string { // Stop waits for all transfers to finish, and stops the downloader. func (c *Client) Stop() { - c.mu.Lock() - defer c.mu.Unlock() - - if !c.started { + if !c.started.Load() { return } @@ -310,7 +311,7 @@ func (c *Client) Stop() { c.fileRequests = nil c.wg = nil - c.started = false + c.started.Store(false) } var ErrNotStarted = errors.New("downloader not started") @@ -321,11 +322,7 @@ var ErrNotStarted = errors.New("downloader not started") // is full, will block until it becomes empty. It returns the filepath within the // filesystem. func (c *Client) DownloadFile(dir string, f slack.File) (string, error) { - c.mu.Lock() - started := c.started - c.mu.Unlock() - - if !started { + if !c.started.Load() { return "", ErrNotStarted } c.fileRequests <- fileRequest{Directory: dir, File: &f} diff --git a/downloader/downloader_test.go b/downloader/downloader_test.go index 2774127a..ae1398d0 100644 --- a/downloader/downloader_test.go +++ b/downloader/downloader_test.go @@ -393,7 +393,7 @@ func TestClient_Start(t *testing.T) { c.Start(context.Background()) defer c.Stop() - assert.True(t, c.started) + assert.True(t, c.started.Load()) assert.NotNil(t, c.wg) assert.NotNil(t, c.fileRequests) }) @@ -404,17 +404,17 @@ func TestClient_Stop(t *testing.T) { t.Run("ensure stopped", func(t *testing.T) { c := clientWithMock(t, tmpdir) c.Start(context.Background()) - assert.True(t, c.started) + assert.True(t, c.started.Load()) c.Stop() - assert.False(t, c.started) + assert.False(t, c.started.Load()) assert.Nil(t, c.fileRequests) assert.Nil(t, c.wg) }) t.Run("stop on stopped downloader does nothing", func(t *testing.T) { c := clientWithMock(t, tmpdir) c.Stop() - assert.False(t, c.started) + assert.False(t, c.started.Load()) assert.Nil(t, c.fileRequests) assert.Nil(t, c.wg) }) diff --git a/internal/chunk/processor/standard.go b/internal/chunk/processor/standard.go index c116e8ff..24745df2 100644 --- a/internal/chunk/processor/standard.go +++ b/internal/chunk/processor/standard.go @@ -46,6 +46,10 @@ func (s *Standard) Files(ctx context.Context, channelID string, parent slack.Mes // ignore files if requested return nil } + st, err := s.Recorder.State() + if err != nil { + return err + } // custom file processor, because we need to donwload those files for i := range ff { if ff[i].Mode == "hidden_by_limit" { @@ -57,8 +61,7 @@ func (s *Standard) Files(ctx context.Context, channelID string, parent slack.Mes if err != nil { return err } - s, _ := s.Recorder.State() - s.AddFile(channelID, ff[i].ID, filename) + st.AddFile(channelID, ff[i].ID, filename) } return nil }