diff --git a/ssh/agent/client_test.go b/ssh/agent/client_test.go index c27eaa969f..fdc8000654 100644 --- a/ssh/agent/client_test.go +++ b/ssh/agent/client_test.go @@ -369,7 +369,8 @@ func TestAuth(t *testing.T) { go func() { conn, _, _, err := ssh.NewServerConn(a, &serverConf) if err != nil { - t.Fatalf("Server: %v", err) + t.Errorf("NewServerConn error: %v", err) + return } conn.Close() }() diff --git a/ssh/agent/server_test.go b/ssh/agent/server_test.go index 038018ebb1..0af85457e4 100644 --- a/ssh/agent/server_test.go +++ b/ssh/agent/server_test.go @@ -53,10 +53,11 @@ func TestSetupForwardAgent(t *testing.T) { incoming := make(chan *ssh.ServerConn, 1) go func() { conn, _, _, err := ssh.NewServerConn(a, &serverConf) + incoming <- conn if err != nil { - t.Fatalf("Server: %v", err) + t.Errorf("NewServerConn error: %v", err) + return } - incoming <- conn }() conf := ssh.ClientConfig{ @@ -71,8 +72,10 @@ func TestSetupForwardAgent(t *testing.T) { if err := ForwardToRemote(client, socket); err != nil { t.Fatalf("SetupForwardAgent: %v", err) } - server := <-incoming + if server == nil { + t.Fatal("Unable to get server") + } ch, reqs, err := server.OpenChannel(channelType, nil) if err != nil { t.Fatalf("OpenChannel(%q): %v", channelType, err) diff --git a/ssh/benchmark_test.go b/ssh/benchmark_test.go index a13235d743..b356330b46 100644 --- a/ssh/benchmark_test.go +++ b/ssh/benchmark_test.go @@ -6,6 +6,7 @@ package ssh import ( "errors" + "fmt" "io" "net" "testing" @@ -90,16 +91,16 @@ func BenchmarkEndToEnd(b *testing.B) { go func() { newCh, err := server.Accept() if err != nil { - b.Fatalf("Client: %v", err) + panic(fmt.Sprintf("Client: %v", err)) } ch, incoming, err := newCh.Accept() if err != nil { - b.Fatalf("Accept: %v", err) + panic(fmt.Sprintf("Accept: %v", err)) } go DiscardRequests(incoming) for i := 0; i < b.N; i++ { if _, err := io.ReadFull(ch, output); err != nil { - b.Fatalf("ReadFull: %v", err) + panic(fmt.Sprintf("ReadFull: %v", err)) } } ch.Close() diff --git a/ssh/common_test.go b/ssh/common_test.go index 96744dcf0f..a7beee8e88 100644 --- a/ssh/common_test.go +++ b/ssh/common_test.go @@ -82,11 +82,11 @@ func TestFindAgreedAlgorithms(t *testing.T) { } cases := []testcase{ - testcase{ + { name: "standard", }, - testcase{ + { name: "no common hostkey", serverIn: kexInitMsg{ ServerHostKeyAlgos: []string{"hostkey2"}, @@ -94,7 +94,7 @@ func TestFindAgreedAlgorithms(t *testing.T) { wantErr: true, }, - testcase{ + { name: "no common kex", serverIn: kexInitMsg{ KexAlgos: []string{"kex2"}, @@ -102,7 +102,7 @@ func TestFindAgreedAlgorithms(t *testing.T) { wantErr: true, }, - testcase{ + { name: "no common cipher", serverIn: kexInitMsg{ CiphersClientServer: []string{"cipher2"}, @@ -110,7 +110,7 @@ func TestFindAgreedAlgorithms(t *testing.T) { wantErr: true, }, - testcase{ + { name: "client decides cipher", serverIn: kexInitMsg{ CiphersClientServer: []string{"cipher1", "cipher2"}, diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go index f190cbfa91..879143a6b5 100644 --- a/ssh/handshake_test.go +++ b/ssh/handshake_test.go @@ -148,6 +148,7 @@ func TestHandshakeBasic(t *testing.T) { clientDone := make(chan int, 0) gotHalf := make(chan int, 0) const N = 20 + errorCh := make(chan error, 1) go func() { defer close(clientDone) @@ -158,7 +159,9 @@ func TestHandshakeBasic(t *testing.T) { for i := 0; i < N; i++ { p := []byte{msgRequestSuccess, byte(i)} if err := trC.writePacket(p); err != nil { - t.Fatalf("sendPacket: %v", err) + errorCh <- err + trC.Close() + return } if (i % 10) == 5 { <-gotHalf @@ -177,16 +180,15 @@ func TestHandshakeBasic(t *testing.T) { checker.waitCall <- 1 } } + errorCh <- nil }() // Server checks that client messages come in cleanly i := 0 - err = nil for ; i < N; i++ { - var p []byte - p, err = trS.readPacket() - if err != nil { - break + p, err := trS.readPacket() + if err != nil && err != io.EOF { + t.Fatalf("server error: %v", err) } if (i % 10) == 5 { gotHalf <- 1 @@ -198,8 +200,8 @@ func TestHandshakeBasic(t *testing.T) { } } <-clientDone - if err != nil && err != io.EOF { - t.Fatalf("server error: %v", err) + if err := <-errorCh; err != nil { + t.Fatalf("sendPacket: %v", err) } if i != N { t.Errorf("received %d messages, want 10.", i) @@ -345,16 +347,16 @@ func TestHandshakeAutoRekeyRead(t *testing.T) { // While we read out the packet, a key change will be // initiated. - done := make(chan int, 1) + errorCh := make(chan error, 1) go func() { - defer close(done) - if _, err := trC.readPacket(); err != nil { - t.Fatalf("readPacket(client): %v", err) - } - + _, err := trC.readPacket() + errorCh <- err }() - <-done + if err := <-errorCh; err != nil { + t.Fatalf("readPacket(client): %v", err) + } + <-sync.called } diff --git a/ssh/mux_test.go b/ssh/mux_test.go index 393017c08c..1db3be54a0 100644 --- a/ssh/mux_test.go +++ b/ssh/mux_test.go @@ -5,6 +5,8 @@ package ssh import ( + "errors" + "fmt" "io" "sync" "testing" @@ -29,14 +31,21 @@ func channelPair(t *testing.T) (*channel, *channel, *mux) { go func() { newCh, ok := <-s.incomingChannels if !ok { - t.Fatalf("No incoming channel") + t.Error("no incoming channel") + close(res) + return } if newCh.ChannelType() != "chan" { - t.Fatalf("got type %q want chan", newCh.ChannelType()) + t.Errorf("got type %q want chan", newCh.ChannelType()) + newCh.Reject(Prohibited, fmt.Sprintf("got type %q want chan", newCh.ChannelType())) + close(res) + return } ch, _, err := newCh.Accept() if err != nil { - t.Fatalf("Accept %v", err) + t.Errorf("accept: %v", err) + close(res) + return } res <- ch.(*channel) }() @@ -45,8 +54,12 @@ func channelPair(t *testing.T) (*channel, *channel, *mux) { if err != nil { t.Fatalf("OpenChannel: %v", err) } + w := <-res + if w == nil { + t.Fatal("unable to get write channel") + } - return <-res, ch, c + return w, ch, c } // Test that stderr and stdout can be addressed from different @@ -74,14 +87,14 @@ func TestMuxChannelExtendedThreadSafety(t *testing.T) { go func() { c, err := io.ReadAll(reader) if string(c) != magic { - t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err) + t.Errorf("stdout read got %q, want %q (error %s)", c, magic, err) } rd.Done() }() go func() { c, err := io.ReadAll(reader.Stderr()) if string(c) != magic { - t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err) + t.Errorf("stderr read got %q, want %q (error %s)", c, magic, err) } rd.Done() }() @@ -102,11 +115,13 @@ func TestMuxReadWrite(t *testing.T) { go func() { _, err := s.Write([]byte(magic)) if err != nil { - t.Fatalf("Write: %v", err) + t.Errorf("Write: %v", err) + return } _, err = s.Extended(1).Write([]byte(magicExt)) if err != nil { - t.Fatalf("Write: %v", err) + t.Errorf("Write: %v", err) + return } }() @@ -215,10 +230,13 @@ func TestMuxReject(t *testing.T) { go func() { ch, ok := <-server.incomingChannels if !ok { - t.Fatalf("Accept") + t.Error("cannot accept channel") + return } if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" { - t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData()) + t.Errorf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData()) + ch.Reject(RejectionReason(UnknownChannelType), UnknownChannelType.String()) + return } ch.Reject(RejectionReason(42), "message") }() @@ -294,7 +312,7 @@ func TestMuxUnknownChannelRequests(t *testing.T) { defer serverPipe.Close() defer client.Close() - kDone := make(chan struct{}) + kDone := make(chan error, 1) go func() { // Ignore unknown channel messages that don't want a reply. err := serverPipe.writePacket(Marshal(channelRequestMsg{ @@ -304,7 +322,8 @@ func TestMuxUnknownChannelRequests(t *testing.T) { RequestSpecificData: []byte{}, })) if err != nil { - t.Fatalf("send: %v", err) + kDone <- fmt.Errorf("send: %w", err) + return } // Send a keepalive, which should get a channel failure message @@ -316,44 +335,53 @@ func TestMuxUnknownChannelRequests(t *testing.T) { RequestSpecificData: []byte{}, })) if err != nil { - t.Fatalf("send: %v", err) + kDone <- fmt.Errorf("send: %w", err) + return } packet, err := serverPipe.readPacket() if err != nil { - t.Fatalf("read packet: %v", err) + kDone <- fmt.Errorf("read packet: %w", err) + return } decoded, err := decode(packet) if err != nil { - t.Fatalf("decode failed: %v", err) + kDone <- fmt.Errorf("decode failed: %w", err) + return } switch msg := decoded.(type) { case *channelRequestFailureMsg: if msg.PeersID != 2 { - t.Fatalf("received response to wrong message: %v", msg) + kDone <- fmt.Errorf("received response to wrong message: %v", msg) + return + } default: - t.Fatalf("unexpected channel message: %v", msg) + kDone <- fmt.Errorf("unexpected channel message: %v", msg) + return } - kDone <- struct{}{} + kDone <- nil // Receive and respond to the keepalive to confirm the mux is // still processing requests. packet, err = serverPipe.readPacket() if err != nil { - t.Fatalf("read packet: %v", err) + kDone <- fmt.Errorf("read packet: %w", err) + return } if packet[0] != msgGlobalRequest { - t.Fatalf("expected global request") + kDone <- errors.New("expected global request") + return } err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{ Data: []byte{}, })) if err != nil { - t.Fatalf("failed to send failure msg: %v", err) + kDone <- fmt.Errorf("failed to send failure msg: %w", err) + return } close(kDone) @@ -362,7 +390,10 @@ func TestMuxUnknownChannelRequests(t *testing.T) { // Wait for the server to send the keepalive message and receive back a // response. select { - case <-kDone: + case err := <-kDone: + if err != nil { + t.Fatal(err) + } case <-time.After(10 * time.Second): t.Fatalf("server never received ack") } @@ -373,7 +404,10 @@ func TestMuxUnknownChannelRequests(t *testing.T) { } select { - case <-kDone: + case err := <-kDone: + if err != nil { + t.Fatal(err) + } case <-time.After(10 * time.Second): t.Fatalf("server never shut down") } @@ -385,20 +419,23 @@ func TestMuxClosedChannel(t *testing.T) { defer serverPipe.Close() defer client.Close() - kDone := make(chan struct{}) + kDone := make(chan error, 1) go func() { // Open the channel. packet, err := serverPipe.readPacket() if err != nil { - t.Fatalf("read packet: %v", err) + kDone <- fmt.Errorf("read packet: %w", err) + return } if packet[0] != msgChannelOpen { - t.Fatalf("expected chan open") + kDone <- errors.New("expected chan open") + return } var openMsg channelOpenMsg if err := Unmarshal(packet, &openMsg); err != nil { - t.Fatalf("unmarshal: %v", err) + kDone <- fmt.Errorf("unmarshal: %w", err) + return } // Send back the opened channel confirmation. @@ -409,7 +446,8 @@ func TestMuxClosedChannel(t *testing.T) { MaxPacketSize: channelMaxPacket, })) if err != nil { - t.Fatalf("send: %v", err) + kDone <- fmt.Errorf("send: %w", err) + return } // Close the channel. @@ -417,7 +455,8 @@ func TestMuxClosedChannel(t *testing.T) { PeersID: openMsg.PeersID, })) if err != nil { - t.Fatalf("send: %v", err) + kDone <- fmt.Errorf("send: %w", err) + return } // Send a keepalive message on the channel we just closed. @@ -428,43 +467,51 @@ func TestMuxClosedChannel(t *testing.T) { RequestSpecificData: []byte{}, })) if err != nil { - t.Fatalf("send: %v", err) + kDone <- fmt.Errorf("send: %w", err) + return } // Receive the channel closed response. packet, err = serverPipe.readPacket() if err != nil { - t.Fatalf("read packet: %v", err) + kDone <- fmt.Errorf("read packet: %w", err) + return } if packet[0] != msgChannelClose { - t.Fatalf("expected channel close") + kDone <- errors.New("expected channel close") + return } // Receive the keepalive response failure. packet, err = serverPipe.readPacket() if err != nil { - t.Fatalf("read packet: %v", err) + kDone <- fmt.Errorf("read packet: %w", err) + return } if packet[0] != msgChannelFailure { - t.Fatalf("expected channel close") + kDone <- errors.New("expected channel failure") + return } - kDone <- struct{}{} + kDone <- nil // Receive and respond to the keepalive to confirm the mux is // still processing requests. packet, err = serverPipe.readPacket() if err != nil { - t.Fatalf("read packet: %v", err) + kDone <- fmt.Errorf("read packet: %w", err) + return } if packet[0] != msgGlobalRequest { - t.Fatalf("expected global request") + kDone <- errors.New("expected global request") + return } err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{ Data: []byte{}, })) if err != nil { - t.Fatalf("failed to send failure msg: %v", err) + kDone <- fmt.Errorf("failed to send failure msg: %w", err) + return } close(kDone) diff --git a/ssh/session_test.go b/ssh/session_test.go index c4b9f0ea5b..521677f9b1 100644 --- a/ssh/session_test.go +++ b/ssh/session_test.go @@ -36,7 +36,8 @@ func dial(handler serverType, t *testing.T) *Client { conn, chans, reqs, err := NewServerConn(c1, &conf) if err != nil { - t.Fatalf("Unable to handshake: %v", err) + t.Errorf("Unable to handshake: %v", err) + return } go DiscardRequests(reqs) @@ -647,10 +648,12 @@ func TestSessionID(t *testing.T) { User: "user", } + srvErrCh := make(chan error, 1) go func() { conn, chans, reqs, err := NewServerConn(c1, serverConf) + srvErrCh <- err if err != nil { - t.Fatalf("server handshake: %v", err) + return } serverID <- conn.SessionID() go DiscardRequests(reqs) @@ -659,10 +662,12 @@ func TestSessionID(t *testing.T) { } }() + cliErrCh := make(chan error, 1) go func() { conn, chans, reqs, err := NewClientConn(c2, "", clientConf) + cliErrCh <- err if err != nil { - t.Fatalf("client handshake: %v", err) + return } clientID <- conn.SessionID() go DiscardRequests(reqs) @@ -671,6 +676,14 @@ func TestSessionID(t *testing.T) { } }() + if err := <-srvErrCh; err != nil { + t.Fatalf("server handshake: %v", err) + } + + if err := <-cliErrCh; err != nil { + t.Fatalf("client handshake: %v", err) + } + s := <-serverID c := <-clientID if bytes.Compare(s, c) != 0 { diff --git a/ssh/test/multi_auth_test.go b/ssh/test/multi_auth_test.go index 6c253a7547..403d7363ab 100644 --- a/ssh/test/multi_auth_test.go +++ b/ssh/test/multi_auth_test.go @@ -77,27 +77,27 @@ func (ctx *multiAuthTestCtx) kbdIntCb(user, instruction string, questions []stri func TestMultiAuth(t *testing.T) { testCases := []multiAuthTestCase{ // Test password,publickey authentication, assert that password callback is called 1 time - multiAuthTestCase{ + { authMethods: []string{"password", "publickey"}, expectedPasswordCbs: 1, }, // Test keyboard-interactive,publickey authentication, assert that keyboard-interactive callback is called 1 time - multiAuthTestCase{ + { authMethods: []string{"keyboard-interactive", "publickey"}, expectedKbdIntCbs: 1, }, // Test publickey,password authentication, assert that password callback is called 1 time - multiAuthTestCase{ + { authMethods: []string{"publickey", "password"}, expectedPasswordCbs: 1, }, // Test publickey,keyboard-interactive authentication, assert that keyboard-interactive callback is called 1 time - multiAuthTestCase{ + { authMethods: []string{"publickey", "keyboard-interactive"}, expectedKbdIntCbs: 1, }, // Test password,password authentication, assert that password callback is called 2 times - multiAuthTestCase{ + { authMethods: []string{"password", "password"}, expectedPasswordCbs: 2, },