From 9cf438259e225d8f49e52a818fe69c30971b928c Mon Sep 17 00:00:00 2001 From: Wesley Merkel Date: Thu, 9 Jan 2025 09:49:21 -0700 Subject: [PATCH] fix(sftp): return scanner from (*sftpReader).initScanner This prevents a race condition between two calls to ReadBatch clobbering each other. --- internal/impl/sftp/input.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/internal/impl/sftp/input.go b/internal/impl/sftp/input.go index 8e3c9f17d..f685c1761 100644 --- a/internal/impl/sftp/input.go +++ b/internal/impl/sftp/input.go @@ -209,18 +209,19 @@ func (s *sftpReader) ReadBatch(ctx context.Context) (service.MessageBatch, servi } func (s *sftpReader) tryReadBatch(ctx context.Context) (service.MessageBatch, service.AckFunc, error) { - if err := s.initScanner(ctx); err != nil { + scanner, err := s.initScanner(ctx) + if err != nil { return nil, nil, err } - s.stateLock.Lock() - defer s.stateLock.Unlock() - - parts, codecAckFn, err := s.scanner.NextBatch(ctx) + parts, codecAckFn, err := scanner.NextBatch(ctx) if err != nil { if ctx.Err() != nil { return nil, nil, ctx.Err() } + s.stateLock.Lock() + defer s.stateLock.Unlock() + _ = s.scanner.Close(ctx) s.scanner = nil s.currentPath = "" @@ -237,12 +238,12 @@ func (s *sftpReader) tryReadBatch(ctx context.Context) (service.MessageBatch, se return parts, codecAckFn, nil } -func (s *sftpReader) initScanner(ctx context.Context) error { +func (s *sftpReader) initScanner(ctx context.Context) (codec.DeprecatedFallbackStream, error) { s.stateLock.Lock() scanner := s.scanner s.stateLock.Unlock() if scanner != nil { - return nil + return scanner, nil } var file *sftp.File @@ -252,10 +253,10 @@ func (s *sftpReader) initScanner(ctx context.Context) error { var err error path, ok, err = s.pathProvider.Next(ctx) if err != nil { - return fmt.Errorf("finding next file path: %w", err) + return nil, fmt.Errorf("finding next file path: %w", err) } if !ok { - return service.ErrEndOfInput + return nil, service.ErrEndOfInput } file, err = s.client.Open(path) @@ -278,14 +279,14 @@ func (s *sftpReader) initScanner(ctx context.Context) error { scanner, err := s.scannerCtor.Create(file, s.newCodecAckFn(path), details) if err != nil { _ = file.Close() - return fmt.Errorf("creating scanner: %w", err) + return nil, fmt.Errorf("creating scanner: %w", err) } s.stateLock.Lock() s.scanner = scanner s.currentPath = path s.stateLock.Unlock() - return nil + return scanner, nil } } @@ -307,7 +308,6 @@ func (s *sftpReader) newCodecAckFn(path string) service.AckFunc { } } - time.Sleep(time.Millisecond * 100) return nil } }