Skip to content

Commit

Permalink
rollback downloader changes - regressed
Browse files Browse the repository at this point in the history
  • Loading branch information
rusq committed Mar 17, 2023
1 parent 2924b0e commit ff21836
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 58 deletions.
67 changes: 32 additions & 35 deletions downloader/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"path/filepath"
"runtime/trace"
"sync"
"sync/atomic"

"errors"

Expand All @@ -23,10 +22,10 @@ import (
)

const (
retries = 3 // default number of retries if download fails
numWorkers = 4 // number of download processes
maxEvtPerSec = 5000 // default API limit, in events per second.
downloadBufSz = 100 // default download channel buffer.
defRetries = 3 // default number of retries if download fails
defNumWorkers = 4 // number of download processes
defLimit = 5000 // default API limit, in events per second.
defFileBufSz = 100 // default download channel buffer.
)

// Client is the instance of the downloader.
Expand All @@ -39,9 +38,10 @@ type Client struct {
retries int
workers int

mu sync.Mutex // mutex prevents race condition when starting/stopping
fileRequests chan fileRequest
wg *sync.WaitGroup
started atomic.Bool
started bool

nameFn FilenameFunc
}
Expand All @@ -55,8 +55,6 @@ var Filename FilenameFunc = stdFilenameFn

// Downloader is the file downloader interface. It exists primarily for mocking
// in tests.
//
//go:generate mockgen -destination ../internal/mocks/mock_downloader/mock_downloader.go github.com/rusq/slackdump/v2/downloader Downloader
type Downloader interface {
// GetFile retrieves a given file from its private download URL
GetFile(downloadURL string, writer io.Writer) error
Expand All @@ -78,7 +76,7 @@ func Limiter(l *rate.Limiter) Option {
func Retries(n int) Option {
return func(c *Client) {
if n <= 0 {
n = retries
n = defRetries
}
c.retries = n
}
Expand All @@ -88,7 +86,7 @@ func Retries(n int) Option {
func Workers(n int) Option {
return func(c *Client) {
if n <= 0 {
n = numWorkers
n = defNumWorkers
}
c.workers = n
}
Expand Down Expand Up @@ -124,9 +122,9 @@ func New(client Downloader, fs fsadapter.FS, opts ...Option) *Client {
c := &Client{
client: client,
fs: fs,
limiter: rate.NewLimiter(maxEvtPerSec, 1),
retries: retries,
workers: numWorkers,
limiter: rate.NewLimiter(defLimit, 1),
retries: defRetries,
workers: defNumWorkers,
nameFn: Filename,
}
for _, opt := range opts {
Expand All @@ -148,22 +146,25 @@ type fileRequest struct {
// Start starts an async file downloader. If the downloader is already
// started, it does nothing.
func (c *Client) Start(ctx context.Context) {
if c.started.Load() {
c.mu.Lock()
defer c.mu.Unlock()

if c.started {
// already started
return
}
req := make(chan fileRequest, downloadBufSz)
req := make(chan fileRequest, defFileBufSz)

c.fileRequests = req
c.wg = c.startWorkers(ctx, req)
c.started.Store(true)
c.started = true
}

// startWorkers starts download workers. It returns a sync.WaitGroup. If the
// req channel is closed, workers will stop, and wg.Wait() completes.
func (c *Client) startWorkers(ctx context.Context, req <-chan fileRequest) *sync.WaitGroup {
if c.workers == 0 {
c.workers = numWorkers
c.workers = defNumWorkers
}
var wg sync.WaitGroup
// create workers
Expand All @@ -190,17 +191,13 @@ func (c *Client) worker(ctx context.Context, reqC <-chan fileRequest) {
if !moar {
return
}
if req.File.Mode == "hidden_by_limit" {
c.l().Printf("skipping %q because it's hidden by Slack due to 90 days limit", c.nameFn(req.File))
continue
}
c.l().Debugf("saving %q to %s, size: %d", c.nameFn(req.File), req.Directory, req.File.Size)
n, err := c.saveFile(ctx, req.Directory, req.File)
if err != nil {
c.l().Printf("error saving %q to %q: %s", c.nameFn(req.File), req.Directory, err)
break
}
c.l().Debugf("file %q saved to %s: %d bytes written", c.nameFn(req.File), req.Directory, n)
c.l().Printf("file %q saved to %s: %d bytes written", c.nameFn(req.File), req.Directory, n)
}
}
}
Expand All @@ -220,13 +217,7 @@ func (c *Client) AsyncDownloader(ctx context.Context, dir string, fileDlQueue <-
req := make(chan fileRequest)
go func() {
defer close(req)
select {
case <-ctx.Done():
return
case f, more := <-fileDlQueue:
if !more {
return
}
for f := range fileDlQueue {
req <- fileRequest{Directory: dir, File: f}
}
}()
Expand Down Expand Up @@ -257,14 +248,13 @@ func (c *Client) saveFile(ctx context.Context, dir string, sf *slack.File) (int6
tf.Close()
os.Remove(tf.Name())
}()
tfReset := func() error { _, err := tf.Seek(0, io.SeekStart); return err }

if err := network.WithRetry(ctx, c.limiter, c.retries, func() error {
region := trace.StartRegion(ctx, "GetFile")
defer region.End()

if err := c.client.GetFile(sf.URLPrivateDownload, tf); err != nil {
if err := tfReset(); err != nil {
if _, err := tf.Seek(0, io.SeekStart); err != nil {
c.l().Debugf("seek error: %s", err)
}
return fmt.Errorf("download to %q failed, [src=%s]: %w", filePath, sf.URLPrivateDownload, err)
Expand All @@ -276,7 +266,7 @@ func (c *Client) saveFile(ctx context.Context, dir string, sf *slack.File) (int6

// at this point, the temporary file position is at EOF, so we need to reset
// it prior to copying.
if err := tfReset(); err != nil {
if _, err := tf.Seek(0, io.SeekStart); err != nil {
return 0, err
}

Expand All @@ -300,7 +290,10 @@ func stdFilenameFn(f *slack.File) string {

// Stop waits for all transfers to finish, and stops the downloader.
func (c *Client) Stop() {
if !c.started.Load() {
c.mu.Lock()
defer c.mu.Unlock()

if !c.started {
return
}

Expand All @@ -311,7 +304,7 @@ func (c *Client) Stop() {

c.fileRequests = nil
c.wg = nil
c.started.Store(false)
c.started = false
}

var ErrNotStarted = errors.New("downloader not started")
Expand All @@ -322,7 +315,11 @@ var ErrNotStarted = errors.New("downloader not started")
// is full, will block until it becomes empty. It returns the filepath within the
// filesystem.
func (c *Client) DownloadFile(dir string, f slack.File) (string, error) {
if !c.started.Load() {
c.mu.Lock()
started := c.started
c.mu.Unlock()

if !started {
return "", ErrNotStarted
}
c.fileRequests <- fileRequest{Directory: dir, File: &f}
Expand Down
46 changes: 23 additions & 23 deletions downloader/downloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ func TestSession_SaveFileTo(t *testing.T) {
{
"ok",
fields{
l: rate.NewLimiter(maxEvtPerSec, 1),
l: rate.NewLimiter(defLimit, 1),
fs: fsadapter.NewDirectory(tmpdir),
retries: retries,
workers: numWorkers,
retries: defRetries,
workers: defNumWorkers,
nameFn: Filename,
},
args{
Expand All @@ -79,10 +79,10 @@ func TestSession_SaveFileTo(t *testing.T) {
{
"getfile rekt",
fields{
l: rate.NewLimiter(maxEvtPerSec, 1),
l: rate.NewLimiter(defLimit, 1),
fs: fsadapter.NewDirectory(tmpdir),
retries: retries,
workers: numWorkers,
retries: defRetries,
workers: defNumWorkers,
nameFn: Filename,
},
args{
Expand Down Expand Up @@ -152,10 +152,10 @@ func TestSession_saveFile(t *testing.T) {
{
"ok",
fields{
l: rate.NewLimiter(maxEvtPerSec, 1),
l: rate.NewLimiter(defLimit, 1),
fs: fsadapter.NewDirectory(tmpdir),
retries: retries,
workers: numWorkers,
retries: defRetries,
workers: defNumWorkers,
nameFn: Filename,
},
args{
Expand All @@ -175,10 +175,10 @@ func TestSession_saveFile(t *testing.T) {
{
"getfile rekt",
fields{
l: rate.NewLimiter(maxEvtPerSec, 1),
l: rate.NewLimiter(defLimit, 1),
fs: fsadapter.NewDirectory(tmpdir),
retries: retries,
workers: numWorkers,
retries: defRetries,
workers: defNumWorkers,
nameFn: Filename,
},
args{
Expand Down Expand Up @@ -243,7 +243,7 @@ func Test_filename(t *testing.T) {
}

func TestSession_newFileDownloader(t *testing.T) {
tl := rate.NewLimiter(maxEvtPerSec, 1)
tl := rate.NewLimiter(defLimit, 1)
tmpdir, err := os.MkdirTemp("", "")
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -283,16 +283,16 @@ func TestSession_newFileDownloader(t *testing.T) {
}

func TestSession_worker(t *testing.T) {
tl := rate.NewLimiter(maxEvtPerSec, 1)
tl := rate.NewLimiter(defLimit, 1)
tmpdir := t.TempDir()

newClient := func(mc *mock_downloader.MockDownloader) *Client {
return &Client{
client: mc,
fs: fsadapter.NewDirectory(tmpdir),
limiter: tl,
retries: retries,
workers: numWorkers,
retries: defRetries,
workers: defNumWorkers,
nameFn: Filename,
}
}
Expand Down Expand Up @@ -359,14 +359,14 @@ func TestClient_startWorkers(t *testing.T) {
client: dc,
fs: fsadapter.NewDirectory(t.TempDir()),
limiter: rate.NewLimiter(5000, 1),
workers: numWorkers,
workers: defNumWorkers,
nameFn: Filename,
}

dc.EXPECT().GetFile(gomock.Any(), gomock.Any()).Times(qSz).Return(nil)

fileQueue := makeFileReqQ(qSz, t.TempDir())
fileChan := slice2chan(fileQueue, downloadBufSz)
fileChan := slice2chan(fileQueue, defFileBufSz)
wg := cl.startWorkers(context.Background(), fileChan)

wg.Wait()
Expand All @@ -393,7 +393,7 @@ func TestClient_Start(t *testing.T) {
c.Start(context.Background())
defer c.Stop()

assert.True(t, c.started.Load())
assert.True(t, c.started)
assert.NotNil(t, c.wg)
assert.NotNil(t, c.fileRequests)
})
Expand All @@ -404,17 +404,17 @@ func TestClient_Stop(t *testing.T) {
t.Run("ensure stopped", func(t *testing.T) {
c := clientWithMock(t, tmpdir)
c.Start(context.Background())
assert.True(t, c.started.Load())
assert.True(t, c.started)

c.Stop()
assert.False(t, c.started.Load())
assert.False(t, c.started)
assert.Nil(t, c.fileRequests)
assert.Nil(t, c.wg)
})
t.Run("stop on stopped downloader does nothing", func(t *testing.T) {
c := clientWithMock(t, tmpdir)
c.Stop()
assert.False(t, c.started.Load())
assert.False(t, c.started)
assert.Nil(t, c.fileRequests)
assert.Nil(t, c.wg)
})
Expand All @@ -427,7 +427,7 @@ func clientWithMock(t *testing.T, dir string) *Client {
client: dc,
fs: fsadapter.NewDirectory(dir),
limiter: rate.NewLimiter(5000, 1),
workers: numWorkers,
workers: defNumWorkers,
nameFn: Filename,
}
return c
Expand Down

0 comments on commit ff21836

Please sign in to comment.