Skip to content

Commit

Permalink
[bug] - add context timeout to ssh verification (#3161)
Browse files Browse the repository at this point in the history
* add context timeout to ssh verification

* fix test
  • Loading branch information
ahrav authored Aug 2, 2024
1 parent 2961322 commit c549b5b
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 10 deletions.
4 changes: 2 additions & 2 deletions pkg/detectors/privatekey/privatekey.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) (result
wg.Add(1)
go func() {
defer wg.Done()
user, err := verifyGitHubUser(parsedKey)
user, err := verifyGitHubUser(ctx, parsedKey)
if err != nil && !errors.Is(err, errPermissionDenied) {
verificationErrors.Add(err)
}
Expand All @@ -122,7 +122,7 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) (result
wg.Add(1)
go func() {
defer wg.Done()
user, err := verifyGitLabUser(parsedKey)
user, err := verifyGitLabUser(ctx, parsedKey)
if err != nil && !errors.Is(err, errPermissionDenied) {
verificationErrors.Add(err)
}
Expand Down
30 changes: 24 additions & 6 deletions pkg/detectors/privatekey/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package privatekey

import (
"bytes"
"context"
"errors"
"fmt"
"net"
Expand All @@ -26,7 +27,7 @@ var gitlabFingerprints = map[string]string{
"SHA256:ROQFvPThGrW4RuWLoL9tq9I9zJ42fK4XywyRtbOz/EQ": "RSA",
}

func firstResponseFromSSH(parsedKey any, username, hostport string) (string, error) {
func firstResponseFromSSH(ctx context.Context, parsedKey any, username, hostport string) (string, error) {
signer, err := ssh.NewSignerFromKey(parsedKey)
if err != nil {
return "", err
Expand Down Expand Up @@ -58,7 +59,7 @@ func firstResponseFromSSH(parsedKey any, username, hostport string) (string, err
},
}

client, err := ssh.Dial("tcp", hostport, config)
client, err := sshDialWithContext(ctx, "tcp", hostport, config)
if err != nil {
if strings.Contains(err.Error(), "unable to authenticate") {
return "", errPermissionDenied
Expand All @@ -85,10 +86,27 @@ func firstResponseFromSSH(parsedKey any, username, hostport string) (string, err
return output.String(), err
}

func sshDialWithContext(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
d := net.Dialer{Timeout: config.Timeout}
conn, err := d.DialContext(ctx, network, addr)
if err != nil {
return nil, fmt.Errorf("error dialing %s: %w", addr, err)
}

ncc, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
if err != nil {
conn.Close()
return nil, fmt.Errorf("error creating SSH connection to %s: %w", addr, err)
}

client := ssh.NewClient(ncc, chans, reqs)
return client, nil
}

var errPermissionDenied = errors.New("permission denied")

func verifyGitHubUser(parsedKey any) (*string, error) {
output, err := firstResponseFromSSH(parsedKey, "git", "github.com:22")
func verifyGitHubUser(ctx context.Context, parsedKey any) (*string, error) {
output, err := firstResponseFromSSH(ctx, parsedKey, "git", "github.com:22")
if err != nil {
return nil, err
}
Expand All @@ -105,8 +123,8 @@ func verifyGitHubUser(parsedKey any) (*string, error) {
return nil, nil
}

func verifyGitLabUser(parsedKey any) (*string, error) {
output, err := firstResponseFromSSH(parsedKey, "git", "gitlab.com:22")
func verifyGitLabUser(ctx context.Context, parsedKey any) (*string, error) {
output, err := firstResponseFromSSH(ctx, parsedKey, "git", "gitlab.com:22")
if err != nil {
return nil, err
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/detectors/privatekey/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ import (
"testing"
"time"

"github.com/trufflesecurity/trufflehog/v3/pkg/common"
"golang.org/x/crypto/ssh"

"github.com/trufflesecurity/trufflehog/v3/pkg/common"
)

func TestFirstResponseFromSSH(t *testing.T) {
Expand All @@ -23,7 +24,7 @@ func TestFirstResponseFromSSH(t *testing.T) {
t.Fatalf("could not parse test secret: %s", err)
}

output, err := firstResponseFromSSH(parsedKey, "git", "github.com:22")
output, err := firstResponseFromSSH(ctx, parsedKey, "git", "github.com:22")
if err != nil {
t.Fail()
}
Expand Down

0 comments on commit c549b5b

Please sign in to comment.