Skip to content

Commit

Permalink
fix(sftp): return scanner from (*sftpReader).initScanner
Browse files Browse the repository at this point in the history
This prevents a race condition between two calls to ReadBatch clobbering
each other.
  • Loading branch information
ooesili committed Jan 9, 2025
1 parent a2838f1 commit 9cf4382
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions internal/impl/sftp/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,18 +209,19 @@ func (s *sftpReader) ReadBatch(ctx context.Context) (service.MessageBatch, servi
}

func (s *sftpReader) tryReadBatch(ctx context.Context) (service.MessageBatch, service.AckFunc, error) {
if err := s.initScanner(ctx); err != nil {
scanner, err := s.initScanner(ctx)
if err != nil {
return nil, nil, err
}

s.stateLock.Lock()
defer s.stateLock.Unlock()

parts, codecAckFn, err := s.scanner.NextBatch(ctx)
parts, codecAckFn, err := scanner.NextBatch(ctx)
if err != nil {
if ctx.Err() != nil {
return nil, nil, ctx.Err()
}
s.stateLock.Lock()
defer s.stateLock.Unlock()

_ = s.scanner.Close(ctx)
s.scanner = nil
s.currentPath = ""
Expand All @@ -237,12 +238,12 @@ func (s *sftpReader) tryReadBatch(ctx context.Context) (service.MessageBatch, se
return parts, codecAckFn, nil
}

func (s *sftpReader) initScanner(ctx context.Context) error {
func (s *sftpReader) initScanner(ctx context.Context) (codec.DeprecatedFallbackStream, error) {
s.stateLock.Lock()
scanner := s.scanner
s.stateLock.Unlock()
if scanner != nil {
return nil
return scanner, nil
}

var file *sftp.File
Expand All @@ -252,10 +253,10 @@ func (s *sftpReader) initScanner(ctx context.Context) error {
var err error
path, ok, err = s.pathProvider.Next(ctx)
if err != nil {
return fmt.Errorf("finding next file path: %w", err)
return nil, fmt.Errorf("finding next file path: %w", err)
}
if !ok {
return service.ErrEndOfInput
return nil, service.ErrEndOfInput
}

file, err = s.client.Open(path)
Expand All @@ -278,14 +279,14 @@ func (s *sftpReader) initScanner(ctx context.Context) error {
scanner, err := s.scannerCtor.Create(file, s.newCodecAckFn(path), details)
if err != nil {
_ = file.Close()
return fmt.Errorf("creating scanner: %w", err)
return nil, fmt.Errorf("creating scanner: %w", err)
}

s.stateLock.Lock()
s.scanner = scanner
s.currentPath = path
s.stateLock.Unlock()
return nil
return scanner, nil
}
}

Expand All @@ -307,7 +308,6 @@ func (s *sftpReader) newCodecAckFn(path string) service.AckFunc {
}
}

time.Sleep(time.Millisecond * 100)
return nil
}
}
Expand Down

0 comments on commit 9cf4382

Please sign in to comment.