Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cancelable ReadDirContext #565

Merged
merged 2 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 38 additions & 29 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sftp

import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
Expand Down Expand Up @@ -321,19 +322,27 @@ func (c *Client) Walk(root string) *fs.Walker {
return fs.WalkFS(root, c)
}

// ReadDir reads the directory named by dirname and returns a list of
// directory entries.
// ReadDir reads the directory named by p
// and returns a list of directory entries.
func (c *Client) ReadDir(p string) ([]os.FileInfo, error) {
handle, err := c.opendir(p)
return c.ReadDirContext(context.Background(), p)
}

// ReadDirContext reads the directory named by p
// and returns a list of directory entries.
// The passed context can be used to cancel the operation
// returning all entries listed up to the cancellation.
func (c *Client) ReadDirContext(ctx context.Context, p string) ([]os.FileInfo, error) {
handle, err := c.opendir(ctx, p)
if err != nil {
return nil, err
}
defer c.close(handle) // this has to defer earlier than the lock below
var attrs []os.FileInfo
var entries []os.FileInfo
var done = false
for !done {
id := c.nextID()
typ, data, err1 := c.sendPacket(nil, &sshFxpReaddirPacket{
typ, data, err1 := c.sendPacket(ctx, nil, &sshFxpReaddirPacket{
ID: id,
Handle: handle,
})
Expand All @@ -358,7 +367,7 @@ func (c *Client) ReadDir(p string) ([]os.FileInfo, error) {
if filename == "." || filename == ".." {
continue
}
attrs = append(attrs, fileInfoFromStat(attr, path.Base(filename)))
entries = append(entries, fileInfoFromStat(attr, path.Base(filename)))
}
case sshFxpStatus:
// TODO(dfc) scope warning!
Expand All @@ -371,12 +380,12 @@ func (c *Client) ReadDir(p string) ([]os.FileInfo, error) {
if err == io.EOF {
err = nil
}
return attrs, err
return entries, err
}

func (c *Client) opendir(path string) (string, error) {
func (c *Client) opendir(ctx context.Context, path string) (string, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpOpendirPacket{
typ, data, err := c.sendPacket(ctx, nil, &sshFxpOpendirPacket{
ID: id,
Path: path,
})
Expand Down Expand Up @@ -412,7 +421,7 @@ func (c *Client) Stat(p string) (os.FileInfo, error) {
// If 'p' is a symbolic link, the returned FileInfo structure describes the symbolic link.
func (c *Client) Lstat(p string) (os.FileInfo, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpLstatPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpLstatPacket{
ID: id,
Path: p,
})
Expand All @@ -437,7 +446,7 @@ func (c *Client) Lstat(p string) (os.FileInfo, error) {
// ReadLink reads the target of a symbolic link.
func (c *Client) ReadLink(p string) (string, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpReadlinkPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpReadlinkPacket{
ID: id,
Path: p,
})
Expand Down Expand Up @@ -466,7 +475,7 @@ func (c *Client) ReadLink(p string) (string, error) {
// Link creates a hard link at 'newname', pointing at the same inode as 'oldname'
func (c *Client) Link(oldname, newname string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpHardlinkPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpHardlinkPacket{
ID: id,
Oldpath: oldname,
Newpath: newname,
Expand All @@ -485,7 +494,7 @@ func (c *Client) Link(oldname, newname string) error {
// Symlink creates a symbolic link at 'newname', pointing at target 'oldname'
func (c *Client) Symlink(oldname, newname string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpSymlinkPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpSymlinkPacket{
ID: id,
Linkpath: newname,
Targetpath: oldname,
Expand All @@ -503,7 +512,7 @@ func (c *Client) Symlink(oldname, newname string) error {

func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpFsetstatPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpFsetstatPacket{
ID: id,
Handle: handle,
Flags: flags,
Expand All @@ -523,7 +532,7 @@ func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error
// setstat is a convience wrapper to allow for changing of various parts of the file descriptor.
func (c *Client) setstat(path string, flags uint32, attrs interface{}) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpSetstatPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpSetstatPacket{
ID: id,
Path: path,
Flags: flags,
Expand Down Expand Up @@ -593,7 +602,7 @@ func (c *Client) OpenFile(path string, f int) (*File, error) {

func (c *Client) open(path string, pflags uint32) (*File, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpOpenPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpOpenPacket{
ID: id,
Path: path,
Pflags: pflags,
Expand Down Expand Up @@ -621,7 +630,7 @@ func (c *Client) open(path string, pflags uint32) (*File, error) {
// immediately after this request has been sent.
func (c *Client) close(handle string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpClosePacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpClosePacket{
ID: id,
Handle: handle,
})
Expand All @@ -638,7 +647,7 @@ func (c *Client) close(handle string) error {

func (c *Client) stat(path string) (*FileStat, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpStatPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpStatPacket{
ID: id,
Path: path,
})
Expand All @@ -662,7 +671,7 @@ func (c *Client) stat(path string) (*FileStat, error) {

func (c *Client) fstat(handle string) (*FileStat, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpFstatPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpFstatPacket{
ID: id,
Handle: handle,
})
Expand Down Expand Up @@ -691,7 +700,7 @@ func (c *Client) fstat(handle string) (*FileStat, error) {
func (c *Client) StatVFS(path string) (*StatVFS, error) {
// send the StatVFS packet to the server
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpStatvfsPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpStatvfsPacket{
ID: id,
Path: path,
})
Expand Down Expand Up @@ -746,7 +755,7 @@ func (c *Client) Remove(path string) error {

func (c *Client) removeFile(path string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpRemovePacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRemovePacket{
ID: id,
Filename: path,
})
Expand All @@ -764,7 +773,7 @@ func (c *Client) removeFile(path string) error {
// RemoveDirectory removes a directory path.
func (c *Client) RemoveDirectory(path string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpRmdirPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRmdirPacket{
ID: id,
Path: path,
})
Expand All @@ -782,7 +791,7 @@ func (c *Client) RemoveDirectory(path string) error {
// Rename renames a file.
func (c *Client) Rename(oldname, newname string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpRenamePacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRenamePacket{
ID: id,
Oldpath: oldname,
Newpath: newname,
Expand All @@ -802,7 +811,7 @@ func (c *Client) Rename(oldname, newname string) error {
// which will replace newname if it already exists.
func (c *Client) PosixRename(oldname, newname string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpPosixRenamePacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpPosixRenamePacket{
ID: id,
Oldpath: oldname,
Newpath: newname,
Expand All @@ -824,7 +833,7 @@ func (c *Client) PosixRename(oldname, newname string) error {
// or relative pathnames without a leading slash into absolute paths.
func (c *Client) RealPath(path string) (string, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpRealpathPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRealpathPacket{
ID: id,
Path: path,
})
Expand Down Expand Up @@ -861,7 +870,7 @@ func (c *Client) Getwd() (string, error) {
// parent folder does not exist (the method cannot create complete paths).
func (c *Client) Mkdir(path string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpMkdirPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpMkdirPacket{
ID: id,
Path: path,
})
Expand Down Expand Up @@ -1007,7 +1016,7 @@ func (f *File) Read(b []byte) (int, error) {
func (f *File) readChunkAt(ch chan result, b []byte, off int64) (n int, err error) {
for err == nil && n < len(b) {
id := f.c.nextID()
typ, data, err := f.c.sendPacket(ch, &sshFxpReadPacket{
typ, data, err := f.c.sendPacket(context.Background(), ch, &sshFxpReadPacket{
ID: id,
Handle: f.handle,
Offset: uint64(off) + uint64(n),
Expand Down Expand Up @@ -1475,7 +1484,7 @@ func (f *File) Write(b []byte) (int, error) {
}

func (f *File) writeChunkAt(ch chan result, b []byte, off int64) (int, error) {
typ, data, err := f.c.sendPacket(ch, &sshFxpWritePacket{
typ, data, err := f.c.sendPacket(context.Background(), ch, &sshFxpWritePacket{
ID: f.c.nextID(),
Handle: f.handle,
Offset: uint64(off),
Expand Down Expand Up @@ -1933,7 +1942,7 @@ func (f *File) Chmod(mode os.FileMode) error {
// Sync requires the server to support the [email protected] extension.
func (f *File) Sync() error {
id := f.c.nextID()
typ, data, err := f.c.sendPacket(nil, &sshFxpFsyncPacket{
typ, data, err := f.c.sendPacket(context.Background(), nil, &sshFxpFsyncPacket{
ID: id,
Handle: f.handle,
})
Expand Down
12 changes: 9 additions & 3 deletions conn.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sftp

import (
"context"
"encoding"
"fmt"
"io"
Expand Down Expand Up @@ -128,14 +129,19 @@ type idmarshaler interface {
encoding.BinaryMarshaler
}

func (c *clientConn) sendPacket(ch chan result, p idmarshaler) (byte, []byte, error) {
func (c *clientConn) sendPacket(ctx context.Context, ch chan result, p idmarshaler) (byte, []byte, error) {
if cap(ch) < 1 {
ch = make(chan result, 1)
}

c.dispatchRequest(ch, p)
s := <-ch
return s.typ, s.data, s.err

select {
case <-ctx.Done():
return 0, nil, ctx.Err()
case s := <-ch:
return s.typ, s.data, s.err
}
}

// dispatchRequest should ideally only be called by race-detection tests outside of this file,
Expand Down
3 changes: 2 additions & 1 deletion server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sftp

import (
"bytes"
"context"
"errors"
"io"
"os"
Expand Down Expand Up @@ -76,7 +77,7 @@ func TestInvalidExtendedPacket(t *testing.T) {
defer server.Close()

badPacket := sshFxpTestBadExtendedPacket{client.nextID(), "thisDoesn'tExist", "foobar"}
typ, data, err := client.clientConn.sendPacket(nil, badPacket)
typ, data, err := client.clientConn.sendPacket(context.Background(), nil, badPacket)
if err != nil {
t.Fatalf("unexpected error from sendPacket: %s", err)
}
Expand Down