diff --git a/CHANGELOG.md b/CHANGELOG.md index 04fb930f26..bc2eda0ac9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ All notable changes to this project will be documented in this file. - Fix an issue in `aws_sqs` with refreshing in-flight message leases which could prevent acks from processed. (@rockwotj) - Fix an issue with `postgres_cdc` with TOAST values not being propagated with `REPLICA IDENTITY FULL`. (@rockwotj) - Fix a initial snapshot streaming consistency issue with `postgres_cdc`. (@rockwotj) +- Fix bug in `sftp` input where the last file was not deleted when `watcher` and `delete_on_finish` were enabled (@ooesili) ## 4.44.0 - 2024-12-13 diff --git a/internal/impl/sftp/client_pool.go b/internal/impl/sftp/client_pool.go new file mode 100644 index 0000000000..9fa40dce19 --- /dev/null +++ b/internal/impl/sftp/client_pool.go @@ -0,0 +1,126 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sftp + +import ( + "errors" + "io/fs" + "sync" + + "github.com/pkg/sftp" +) + +func newClientPool(newClient func() (*sftp.Client, error)) (*clientPool, error) { + client, err := newClient() + if err != nil { + return nil, err + } + return &clientPool{ + newClient: newClient, + client: client, + }, nil +} + +type clientPool struct { + newClient func() (*sftp.Client, error) + + lock sync.Mutex + client *sftp.Client + closed bool +} + +func (c *clientPool) Open(path string) (*sftp.File, error) { + return clientPoolDoReturning(c, func(client *sftp.Client) (*sftp.File, error) { + return client.Open(path) + }) +} + +func (c *clientPool) Glob(path string) ([]string, error) { + return clientPoolDoReturning(c, func(client *sftp.Client) ([]string, error) { + return client.Glob(path) + }) +} + +func (c *clientPool) Stat(path string) (fs.FileInfo, error) { + return clientPoolDoReturning(c, func(client *sftp.Client) (fs.FileInfo, error) { + return client.Stat(path) + }) +} + +func (c *clientPool) Remove(path string) error { + return clientPoolDo(c, func(client *sftp.Client) error { + return client.Remove(path) + }) +} + +func (c *clientPool) Close() error { + c.lock.Lock() + defer c.lock.Unlock() + + if c.closed { + return nil + } + c.closed = true + + if c.client != nil { + err := c.client.Close() + c.client = nil + return err + } + return nil +} + +func clientPoolDo(c *clientPool, fn func(*sftp.Client) error) error { + _, err := clientPoolDoReturning(c, func(client *sftp.Client) (struct{}, error) { + err := fn(client) + return struct{}{}, err + }) + return err +} + +func clientPoolDoReturning[T any](c *clientPool, fn func(*sftp.Client) (T, error)) (T, error) { + c.lock.Lock() + defer c.lock.Unlock() + + var zero T + + // In the case that the clientPool is used from an AckFn after the input is + // closed, we create temporary client to fulfil the operation, then + // immediately close it. 
+ if c.closed { + client, err := c.newClient() + if err != nil { + return zero, err + } + result, err := fn(client) + _ = client.Close() + return result, err + } + + if c.client == nil { + client, err := c.newClient() + if err != nil { + return zero, err + } + c.client = client + } + + result, err := fn(c.client) + if errors.Is(err, sftp.ErrSSHFxConnectionLost) { + _ = c.client.Close() + c.client = nil + } + return result, err +} diff --git a/internal/impl/sftp/input.go b/internal/impl/sftp/input.go index e4fb3bed94..f685c1761e 100644 --- a/internal/impl/sftp/input.go +++ b/internal/impl/sftp/input.go @@ -121,13 +121,13 @@ type sftpReader struct { watcherPollInterval time.Duration watcherMinAge time.Duration - pathProvider pathProvider - // State - scannerMut sync.Mutex - client *sftp.Client + stateLock sync.Mutex scanner codec.DeprecatedFallbackStream currentPath string + + client *clientPool + pathProvider pathProvider } func newSFTPReaderFromParsed(conf *service.ParsedConfig, mgr *service.Resources) (s *sftpReader, err error) { @@ -173,106 +173,45 @@ func newSFTPReaderFromParsed(conf *service.ParsedConfig, mgr *service.Resources) return } -func (s *sftpReader) Connect(ctx context.Context) (err error) { - s.scannerMut.Lock() - defer s.scannerMut.Unlock() +func (s *sftpReader) Connect(ctx context.Context) error { + s.stateLock.Lock() + defer s.stateLock.Unlock() - if s.scanner != nil { - return nil - } - - if s.client == nil { - if s.client, err = s.creds.GetClient(s.mgr.FS(), s.address); err != nil { - return + client, err := newClientPool(func() (*sftp.Client, error) { + return s.creds.GetClient(s.mgr.FS(), s.address) + }) + if err != nil { + if errors.Is(err, sftp.ErrSSHFxConnectionLost) { + err = service.ErrNotConnected } + return err } + s.client = client if s.pathProvider == nil { - s.pathProvider = s.getFilePathProvider(ctx) - } - - var nextPath string - var file *sftp.File - for { - if nextPath, err = s.pathProvider.Next(ctx, s.client); err != nil { - if 
errors.Is(err, sftp.ErrSshFxConnectionLost) { - _ = s.client.Close() - s.client = nil - return - } - if errors.Is(err, errEndOfPaths) { - err = service.ErrEndOfInput - } - return - } - - if file, err = s.client.Open(nextPath); err != nil { - if errors.Is(err, sftp.ErrSshFxConnectionLost) { - _ = s.client.Close() - s.client = nil - } - - s.log.With("path", nextPath, "err", err.Error()).Warn("Unable to open previously identified file") - if os.IsNotExist(err) { - // If we failed to open the file because it no longer exists - // then we can "ack" the path as we're done with it. - _ = s.pathProvider.Ack(ctx, nextPath, nil) - } else { - // Otherwise we "nack" it with the error as we'll want to - // reprocess it again later. - _ = s.pathProvider.Ack(ctx, nextPath, err) - } - } else { - break - } + s.pathProvider = s.getFilePathProvider(client) } + return nil +} - details := service.NewScannerSourceDetails() - details.SetName(nextPath) - if s.scanner, err = s.scannerCtor.Create(file, func(ctx context.Context, aErr error) (outErr error) { - _ = s.pathProvider.Ack(ctx, nextPath, aErr) - if aErr != nil { - return nil - } - if s.deleteOnFinish { - s.scannerMut.Lock() - client := s.client - if client == nil { - if client, outErr = s.creds.GetClient(s.mgr.FS(), s.address); outErr != nil { - outErr = fmt.Errorf("obtain private client: %w", outErr) - } - defer func() { - _ = client.Close() - }() - } - if outErr == nil { - if outErr = client.Remove(nextPath); outErr != nil { - outErr = fmt.Errorf("remove %v: %w", nextPath, outErr) - } - } - s.scannerMut.Unlock() +func (s *sftpReader) ReadBatch(ctx context.Context) (service.MessageBatch, service.AckFunc, error) { + parts, codecAckFn, err := s.tryReadBatch(ctx) + if err != nil { + if errors.Is(err, sftp.ErrSSHFxConnectionLost) { + s.stateLock.Lock() + s.closeScanner(ctx) + s.stateLock.Unlock() + err = service.ErrNotConnected } - return - }, details); err != nil { - _ = file.Close() - _ = s.pathProvider.Ack(ctx, nextPath, err) - 
return err + return nil, nil, err } - s.currentPath = nextPath - - s.log.Debugf("Consuming from file '%v'", nextPath) - return + return parts, codecAckFn, nil } -func (s *sftpReader) ReadBatch(ctx context.Context) (service.MessageBatch, service.AckFunc, error) { - s.scannerMut.Lock() - scanner := s.scanner - client := s.client - currentPath := s.currentPath - s.scannerMut.Unlock() - - if scanner == nil || client == nil { - return nil, nil, service.ErrNotConnected +func (s *sftpReader) tryReadBatch(ctx context.Context) (service.MessageBatch, service.AckFunc, error) { + scanner, err := s.initScanner(ctx) + if err != nil { + return nil, nil, err } parts, codecAckFn, err := scanner.NextBatch(ctx) @@ -280,13 +219,12 @@ func (s *sftpReader) ReadBatch(ctx context.Context) (service.MessageBatch, servi if ctx.Err() != nil { return nil, nil, ctx.Err() } - _ = scanner.Close(ctx) - s.scannerMut.Lock() - if s.currentPath == currentPath { - s.scanner = nil - s.currentPath = "" - } - s.scannerMut.Unlock() + s.stateLock.Lock() + defer s.stateLock.Unlock() + + _ = s.scanner.Close(ctx) + s.scanner = nil + s.currentPath = "" if errors.Is(err, io.EOF) { err = service.ErrNotConnected } @@ -294,42 +232,109 @@ func (s *sftpReader) ReadBatch(ctx context.Context) (service.MessageBatch, servi } for _, part := range parts { - part.MetaSetMut("sftp_path", currentPath) + part.MetaSetMut("sftp_path", s.currentPath) } - return parts, func(ctx context.Context, res error) error { - return codecAckFn(ctx, res) - }, nil + return parts, codecAckFn, nil } -func (s *sftpReader) Close(ctx context.Context) error { - s.scannerMut.Lock() +func (s *sftpReader) initScanner(ctx context.Context) (codec.DeprecatedFallbackStream, error) { + s.stateLock.Lock() scanner := s.scanner - s.scanner = nil - client := s.client - s.client = nil - s.paths = nil - s.scannerMut.Unlock() - + s.stateLock.Unlock() if scanner != nil { - if err := scanner.Close(ctx); err != nil { - s.log.With("error", err).Warn("Failed to close 
consumed file") + return scanner, nil + } + + var file *sftp.File + var path string + for { + var ok bool + var err error + path, ok, err = s.pathProvider.Next(ctx) + if err != nil { + return nil, fmt.Errorf("finding next file path: %w", err) + } + if !ok { + return nil, service.ErrEndOfInput + } + + file, err = s.client.Open(path) + if err != nil { + s.log.With("path", path, "err", err.Error()).Warn("Unable to open previously identified file") + if os.IsNotExist(err) { + // If we failed to open the file because it no longer exists then we + // can "ack" the path as we're done with it. Otherwise we "nack" it + // with the error as we'll want to reprocess it again later. + err = nil + } + if ackErr := s.pathProvider.Ack(ctx, path, err); ackErr != nil { + s.log.With("error", ackErr).Warnf("Failed to acknowledge path: %s", path) + } + continue } + + details := service.NewScannerSourceDetails() + details.SetName(path) + scanner, err := s.scannerCtor.Create(file, s.newCodecAckFn(path), details) + if err != nil { + _ = file.Close() + return nil, fmt.Errorf("creating scanner: %w", err) + } + + s.stateLock.Lock() + s.scanner = scanner + s.currentPath = path + s.stateLock.Unlock() + return scanner, nil } - if client != nil { - if err := client.Close(); err != nil { - s.log.With("error", err).Error("Failed to close client") +} + +func (s *sftpReader) newCodecAckFn(path string) service.AckFunc { + return func(ctx context.Context, aErr error) error { + s.stateLock.Lock() + defer s.stateLock.Unlock() + + if err := s.pathProvider.Ack(ctx, path, aErr); err != nil { + s.log.With("error", err).Warnf("Failed to acknowledge path: %s", path) + } + if aErr != nil { + return nil } + + if s.deleteOnFinish { + if err := s.client.Remove(path); err != nil { + return fmt.Errorf("remove %v: %w", path, err) + } + } + + return nil } - return nil } -//------------------------------------------------------------------------------ +func (s *sftpReader) Close(ctx context.Context) error { + 
s.stateLock.Lock() + defer s.stateLock.Unlock() -var errEndOfPaths = errors.New("end of paths") + s.closeScanner(ctx) + if err := s.client.Close(); err != nil { + s.log.With("error", err).Error("Failed to close client") + } + return nil +} + +func (s *sftpReader) closeScanner(ctx context.Context) { + if s.scanner != nil { + if err := s.scanner.Close(ctx); err != nil { + s.log.With("error", err).Error("Failed to close scanner") + } + s.scanner = nil + s.currentPath = "" + } +} type pathProvider interface { - Next(context.Context, *sftp.Client) (string, error) + Next(context.Context) (string, bool, error) Ack(context.Context, string, error) error } @@ -337,13 +342,13 @@ type staticPathProvider struct { expandedPaths []string } -func (s *staticPathProvider) Next(ctx context.Context, client *sftp.Client) (string, error) { +func (s *staticPathProvider) Next(context.Context) (string, bool, error) { if len(s.expandedPaths) == 0 { - return "", errEndOfPaths + return "", false, nil } - nextPath := s.expandedPaths[0] + path := s.expandedPaths[0] s.expandedPaths = s.expandedPaths[1:] - return nextPath, nil + return path, true, nil } func (s *staticPathProvider) Ack(context.Context, string, error) error { @@ -351,6 +356,7 @@ func (s *staticPathProvider) Ack(context.Context, string, error) error { } type watcherPathProvider struct { + client *clientPool mgr *service.Resources cacheName string pollInterval time.Duration @@ -362,32 +368,41 @@ type watcherPathProvider struct { followUpPoll bool } -func (w *watcherPathProvider) Next(ctx context.Context, client *sftp.Client) (string, error) { - if len(w.expandedPaths) > 0 { - nextPath := w.expandedPaths[0] - w.expandedPaths = w.expandedPaths[1:] - return nextPath, nil - } +func (w *watcherPathProvider) Next(ctx context.Context) (string, bool, error) { + for { + if len(w.expandedPaths) > 0 { + nextPath := w.expandedPaths[0] + w.expandedPaths = w.expandedPaths[1:] + return nextPath, true, nil + } - if waitFor := 
time.Until(w.nextPoll); waitFor > 0 { - w.nextPoll = time.Now().Add(w.pollInterval) - select { - case <-time.After(waitFor): - case <-ctx.Done(): - return "", ctx.Err() + if waitFor := time.Until(w.nextPoll); w.nextPoll.IsZero() || waitFor > 0 { + w.nextPoll = time.Now().Add(w.pollInterval) + select { + case <-time.After(waitFor): + case <-ctx.Done(): + return "", false, ctx.Err() + } + } + + if err := w.findNewPaths(ctx); err != nil { + return "", false, fmt.Errorf("expanding new paths: %w", err) } + w.followUpPoll = true } +} +func (w *watcherPathProvider) findNewPaths(ctx context.Context) error { if cerr := w.mgr.AccessCache(ctx, w.cacheName, func(cache service.Cache) { for _, p := range w.targetPaths { - paths, err := client.Glob(p) + paths, err := w.client.Glob(p) if err != nil { w.mgr.Logger().With("error", err, "path", p).Warn("Failed to scan files from path") continue } for _, path := range paths { - info, err := client.Stat(path) + info, err := w.client.Stat(path) if err != nil { w.mgr.Logger().With("error", err, "path", path).Warn("Failed to stat path") continue @@ -414,10 +429,10 @@ func (w *watcherPathProvider) Next(ctx context.Context, client *sftp.Client) (st } } }); cerr != nil { - return "", fmt.Errorf("error obtaining cache: %v", cerr) + return fmt.Errorf("error obtaining cache: %v", cerr) } - w.followUpPoll = true - return w.Next(ctx, client) + + return nil } func (w *watcherPathProvider) Ack(ctx context.Context, name string, err error) (outErr error) { @@ -433,21 +448,22 @@ func (w *watcherPathProvider) Ack(ctx context.Context, name string, err error) ( return } -func (s *sftpReader) getFilePathProvider(_ context.Context) pathProvider { +func (s *sftpReader) getFilePathProvider(client *clientPool) pathProvider { if !s.watcherEnabled { - var filepaths []string - for _, p := range s.paths { - paths, err := s.client.Glob(p) + provider := &staticPathProvider{} + for _, path := range s.paths { + expandedPaths, err := client.Glob(path) if err != nil { 
- s.log.Warnf("Failed to scan files from path %v: %v", p, err) + s.log.Warnf("Failed to scan files from path %v: %v", path, err) continue } - filepaths = append(filepaths, paths...) + provider.expandedPaths = append(provider.expandedPaths, expandedPaths...) } - return &staticPathProvider{expandedPaths: filepaths} + return provider } return &watcherPathProvider{ + client: client, mgr: s.mgr, cacheName: s.watcherCache, pollInterval: s.watcherPollInterval, diff --git a/internal/impl/sftp/integration_test.go b/internal/impl/sftp/integration_test.go index be86760960..6b89bd2693 100644 --- a/internal/impl/sftp/integration_test.go +++ b/internal/impl/sftp/integration_test.go @@ -15,15 +15,22 @@ package sftp import ( + "context" + "errors" + "fmt" "io/fs" "os" + "strings" + "sync" "testing" "time" "github.com/ory/dockertest/v3" + "github.com/pkg/sftp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/redpanda-data/benthos/v4/public/service" "github.com/redpanda-data/benthos/v4/public/service/integration" // Bring in memory cache. 
@@ -39,34 +46,7 @@ func TestIntegrationSFTP(t *testing.T) { integration.CheckSkip(t) t.Parallel() - pool, err := dockertest.NewPool("") - require.NoError(t, err) - - pool.MaxWait = time.Second * 30 - resource, err := pool.RunWithOptions(&dockertest.RunOptions{ - Repository: "atmoz/sftp", - Tag: "alpine", - Cmd: []string{ - // https://github.com/atmoz/sftp/issues/401 - "/bin/sh", "-c", "ulimit -n 65535 && exec /entrypoint " + sftpUsername + ":" + sftpPassword + ":1001:100:upload", - }, - }) - require.NoError(t, err) - t.Cleanup(func() { - assert.NoError(t, pool.Purge(resource)) - }) - - _ = resource.Expire(900) - - creds := credentials{ - Username: sftpUsername, - Password: sftpPassword, - } - - require.NoError(t, pool.Retry(func() error { - _, err = creds.GetClient(&osPT{}, "localhost:"+resource.GetPort("22/tcp")) - return err - })) + resource := setupDockerPool(t) t.Run("sftp", func(t *testing.T) { template := ` @@ -129,6 +109,133 @@ cache_resources: }) } +func TestIntegrationSFTPDeleteOnFinish(t *testing.T) { + integration.CheckSkip(t) + t.Parallel() + + resource := setupDockerPool(t) + + client, err := getClient(resource) + require.NoError(t, err) + + writeSFTPFile(t, client, "/upload/1.txt", "data-1") + writeSFTPFile(t, client, "/upload/2.txt", "data-2") + writeSFTPFile(t, client, "/upload/3.txt", "data-3") + + config := ` +output: + drop: {} + +input: + sftp: + address: localhost:$PORT + paths: + - /upload/*.txt + credentials: + username: foo + password: pass + delete_on_finish: true + watcher: + enabled: true + poll_interval: 100ms + cache: files_memory + +cache_resources: + - label: files_memory + memory: + default_ttl: 900s +` + config = strings.NewReplacer( + "$PORT", resource.GetPort("22/tcp"), + ).Replace(config) + + var receivedPathsMut sync.Mutex + var receivedPaths []string + + builder := service.NewStreamBuilder() + require.NoError(t, builder.SetYAML(config)) + require.NoError(t, builder.AddConsumerFunc(func(_ context.Context, msg 
*service.Message) error { + receivedPathsMut.Lock() + defer receivedPathsMut.Unlock() + path, ok := msg.MetaGet("sftp_path") + if !ok { + return errors.New("sftp_path metadata not found") + } + receivedPaths = append(receivedPaths, path) + return nil + })) + stream, err := builder.Build() + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + runErr := make(chan error) + go func() { runErr <- stream.Run(ctx) }() + defer func() { + cancel() + err := <-runErr + if err != context.Canceled { + require.NoError(t, err, "stream.Run() failed") + } + }() + + require.EventuallyWithT(t, func(c *assert.CollectT) { + receivedPathsMut.Lock() + defer receivedPathsMut.Unlock() + assert.Len(c, receivedPaths, 3) + + files, err := client.Glob("/upload/*.txt") + assert.NoError(c, err) + assert.Empty(c, files) + }, time.Second*10, time.Millisecond*100) +} + +func setupDockerPool(t *testing.T) *dockertest.Resource { + t.Helper() + + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + pool.MaxWait = time.Second * 30 + resource, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "atmoz/sftp", + Tag: "alpine", + Cmd: []string{ + // https://github.com/atmoz/sftp/issues/401 + "/bin/sh", "-c", "ulimit -n 65535 && exec /entrypoint " + sftpUsername + ":" + sftpPassword + ":1001:100:upload", + }, + }) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, pool.Purge(resource)) + }) + + _ = resource.Expire(900) + + // wait for server to be ready to accept connections + require.NoError(t, pool.Retry(func() error { + _, err := getClient(resource) + return err + })) + + return resource +} +func getClient(resource *dockertest.Resource) (*sftp.Client, error) { + creds := credentials{ + Username: sftpUsername, + Password: sftpPassword, + } + return creds.GetClient(&osPT{}, "localhost:"+resource.GetPort("22/tcp")) +} + +func writeSFTPFile(t *testing.T, client *sftp.Client, path, data string) { + t.Helper() + file, err := 
client.Create(path) + require.NoError(t, err, "creating file") + defer file.Close() + _, err = fmt.Fprint(file, data) + require.NoError(t, err, "writing file contents") +} + type osPT struct{} func (o *osPT) Open(name string) (fs.File, error) {