From 9d84158642223749a5c2458af989b7b34db74cc9 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Sat, 24 Jun 2023 19:48:46 +0200 Subject: [PATCH] ssh: fix call to Fatalf from a non-test goroutine also fix some redundant type declarations --- ssh/agent/client_test.go | 7 +- ssh/agent/server_test.go | 8 +- ssh/benchmark_test.go | 7 +- ssh/common_test.go | 10 +-- ssh/handshake_test.go | 32 ++++---- ssh/mux_test.go | 148 +++++++++++++++++++++++++----------- ssh/session_test.go | 23 +++++- ssh/test/multi_auth_test.go | 10 +-- 8 files changed, 165 insertions(+), 80 deletions(-) diff --git a/ssh/agent/client_test.go b/ssh/agent/client_test.go index 8ffaca7491..0b2f042fb4 100644 --- a/ssh/agent/client_test.go +++ b/ssh/agent/client_test.go @@ -363,10 +363,12 @@ func TestAuth(t *testing.T) { return nil, errors.New("pubkey rejected") } + errorCh := make(chan error, 1) go func() { conn, _, _, err := ssh.NewServerConn(a, &serverConf) + errorCh <- err if err != nil { - t.Fatalf("Server: %v", err) + return } conn.Close() }() @@ -380,6 +382,9 @@ func TestAuth(t *testing.T) { t.Fatalf("NewClientConn: %v", err) } conn.Close() + if err := <-errorCh; err != nil { + t.Fatalf("Server: %v", err) + } } func TestLockOpenSSHAgent(t *testing.T) { diff --git a/ssh/agent/server_test.go b/ssh/agent/server_test.go index 038018ebb1..e5846d1a98 100644 --- a/ssh/agent/server_test.go +++ b/ssh/agent/server_test.go @@ -51,10 +51,12 @@ func TestSetupForwardAgent(t *testing.T) { } serverConf.AddHostKey(testSigners["rsa"]) incoming := make(chan *ssh.ServerConn, 1) + errorCh := make(chan error, 1) go func() { conn, _, _, err := ssh.NewServerConn(a, &serverConf) + errorCh <- err if err != nil { - t.Fatalf("Server: %v", err) + return } incoming <- conn }() @@ -71,7 +73,9 @@ func TestSetupForwardAgent(t *testing.T) { if err := ForwardToRemote(client, socket); err != nil { t.Fatalf("SetupForwardAgent: %v", err) } - + if err := <-errorCh; err != nil { + t.Fatalf("Server: %v", err) + } server := <-incoming ch, reqs, err := server.OpenChannel(channelType, nil) if err != nil { diff --git a/ssh/benchmark_test.go b/ssh/benchmark_test.go index a13235d743..b356330b46 100644 --- a/ssh/benchmark_test.go +++ b/ssh/benchmark_test.go @@ -6,6 +6,7 @@ package ssh import ( "errors" + "fmt" "io" "net" "testing" @@ -90,16 +91,16 @@ func BenchmarkEndToEnd(b *testing.B) { go func() { newCh, err := server.Accept() if err != nil { - b.Fatalf("Client: %v", err) + panic(fmt.Sprintf("Client: %v", err)) } ch, incoming, err := newCh.Accept() if err != nil { - b.Fatalf("Accept: %v", err) + panic(fmt.Sprintf("Accept: %v", err)) } go DiscardRequests(incoming) for i := 0; i < b.N; i++ { if _, err := io.ReadFull(ch, output); err != nil { - b.Fatalf("ReadFull: %v", err) + panic(fmt.Sprintf("ReadFull: %v", err)) } } ch.Close() diff --git a/ssh/common_test.go b/ssh/common_test.go index 96744dcf0f..a7beee8e88 100644 --- a/ssh/common_test.go +++ b/ssh/common_test.go @@ -82,11 +82,11 @@ func TestFindAgreedAlgorithms(t *testing.T) { } cases := []testcase{ - testcase{ + { name: "standard", }, - testcase{ + { name: "no common hostkey", serverIn: kexInitMsg{ ServerHostKeyAlgos: []string{"hostkey2"}, @@ -94,7 +94,7 @@ func TestFindAgreedAlgorithms(t *testing.T) { wantErr: true, }, - testcase{ + { name: "no common kex", serverIn: kexInitMsg{ KexAlgos: []string{"kex2"}, @@ -102,7 +102,7 @@ func TestFindAgreedAlgorithms(t *testing.T) { wantErr: true, }, - testcase{ + { name: "no common cipher", serverIn: kexInitMsg{ CiphersClientServer: []string{"cipher2"}, @@ -110,7 +110,7 @@ func TestFindAgreedAlgorithms(t *testing.T) { wantErr: true, }, - testcase{ + { name: "client decides cipher", serverIn: kexInitMsg{ CiphersClientServer: []string{"cipher1", "cipher2"}, diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go index f190cbfa91..879143a6b5 100644 --- a/ssh/handshake_test.go +++ b/ssh/handshake_test.go @@ -148,6 +148,7 @@ func TestHandshakeBasic(t *testing.T) { clientDone := make(chan int, 0) gotHalf := make(chan int, 0) const N = 20 + errorCh := make(chan error, 1) go func() { defer close(clientDone) @@ -158,7 +159,9 @@ func TestHandshakeBasic(t *testing.T) { for i := 0; i < N; i++ { p := []byte{msgRequestSuccess, byte(i)} if err := trC.writePacket(p); err != nil { - t.Fatalf("sendPacket: %v", err) + errorCh <- err + trC.Close() + return } if (i % 10) == 5 { <-gotHalf @@ -177,16 +180,15 @@ func TestHandshakeBasic(t *testing.T) { checker.waitCall <- 1 } } + errorCh <- nil }() // Server checks that client messages come in cleanly i := 0 - err = nil for ; i < N; i++ { - var p []byte - p, err = trS.readPacket() - if err != nil { - break + p, err := trS.readPacket() + if err != nil && err != io.EOF { + t.Fatalf("server error: %v", err) } if (i % 10) == 5 { gotHalf <- 1 @@ -198,8 +200,8 @@ func TestHandshakeBasic(t *testing.T) { } } <-clientDone - if err != nil && err != io.EOF { - t.Fatalf("server error: %v", err) + if err := <-errorCh; err != nil { + t.Fatalf("sendPacket: %v", err) } if i != N { t.Errorf("received %d messages, want 10.", i) @@ -345,16 +347,16 @@ func TestHandshakeAutoRekeyRead(t *testing.T) { // While we read out the packet, a key change will be // initiated. - done := make(chan int, 1) + errorCh := make(chan error, 1) go func() { - defer close(done) - if _, err := trC.readPacket(); err != nil { - t.Fatalf("readPacket(client): %v", err) - } - + _, err := trC.readPacket() + errorCh <- err }() - <-done + if err := <-errorCh; err != nil { + t.Fatalf("readPacket(client): %v", err) + } + <-sync.called } diff --git a/ssh/mux_test.go b/ssh/mux_test.go index 393017c08c..bcec7ae0ca 100644 --- a/ssh/mux_test.go +++ b/ssh/mux_test.go @@ -5,6 +5,8 @@ package ssh import ( + "errors" + "fmt" "io" "sync" "testing" @@ -26,18 +28,23 @@ func channelPair(t *testing.T) (*channel, *channel, *mux) { c, s := muxPair() res := make(chan *channel, 1) + errorCh := make(chan error, 1) go func() { newCh, ok := <-s.incomingChannels if !ok { - t.Fatalf("No incoming channel") + errorCh <- errors.New("no incoming channel") + return } if newCh.ChannelType() != "chan" { - t.Fatalf("got type %q want chan", newCh.ChannelType()) + errorCh <- fmt.Errorf("got type %q want chan", newCh.ChannelType()) + return } ch, _, err := newCh.Accept() if err != nil { - t.Fatalf("Accept %v", err) + errorCh <- fmt.Errorf("accept: %w", err) + return } + close(errorCh) res <- ch.(*channel) }() @@ -45,6 +52,9 @@ func channelPair(t *testing.T) (*channel, *channel, *mux) { if err != nil { t.Fatalf("OpenChannel: %v", err) } + if err := <-errorCh; err != nil { + t.Fatal(err) + } return <-res, ch, c } @@ -57,7 +67,7 @@ func TestMuxChannelExtendedThreadSafety(t *testing.T) { defer reader.Close() defer mux.Close() - var wr, rd sync.WaitGroup + var wr sync.WaitGroup magic := "hello world" wr.Add(2) @@ -70,25 +80,32 @@ func TestMuxChannelExtendedThreadSafety(t *testing.T) { wr.Done() }() - rd.Add(2) + errs := make(chan error, 2) go func() { c, err := io.ReadAll(reader) if string(c) != magic { - t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err) + errs <- fmt.Errorf("stdout read got %q, want %q (error %w)", c, magic, err) + return } - rd.Done() + errs <- nil }() go func() { c, err := io.ReadAll(reader.Stderr()) if string(c) != magic { - t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err) + errs <- fmt.Errorf("stderr read got %q, want %q (error %w)", c, magic, err) + return } - rd.Done() + errs <- nil }() wr.Wait() writer.CloseWrite() - rd.Wait() + for i := 0; i < 2; i++ { + err := <-errs + if err != nil { + t.Fatal(err) + } + } } func TestMuxReadWrite(t *testing.T) { @@ -99,15 +116,12 @@ func TestMuxReadWrite(t *testing.T) { magic := "hello world" magicExt := "hello stderr" + errs := make(chan error, 2) go func() { _, err := s.Write([]byte(magic)) - if err != nil { - t.Fatalf("Write: %v", err) - } + errs <- err _, err = s.Extended(1).Write([]byte(magicExt)) - if err != nil { - t.Fatalf("Write: %v", err) - } + errs <- err }() var buf [1024]byte @@ -129,6 +143,11 @@ func TestMuxReadWrite(t *testing.T) { if got != magicExt { t.Fatalf("server: got %q want %q", got, magic) } + for i := 0; i < 2; i++ { + if err := <-errs; err != nil { + t.Fatalf("write error: %v", err) + } + } } func TestMuxChannelOverflow(t *testing.T) { @@ -212,15 +231,19 @@ func TestMuxReject(t *testing.T) { defer server.Close() defer client.Close() + errorCh := make(chan error, 1) go func() { ch, ok := <-server.incomingChannels if !ok { - t.Fatalf("Accept") + errorCh <- errors.New("cannot accept channel") + return } if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" { - t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData()) + errorCh <- fmt.Errorf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData()) + return } ch.Reject(RejectionReason(42), "message") + close(errorCh) }() ch, err := client.openChannel("ch", []byte("extra")) @@ -228,6 +251,10 @@ func TestMuxReject(t *testing.T) { t.Fatal("openChannel not rejected") } + if err := <-errorCh; err != nil { + t.Fatal(err) + } + ocf, ok := err.(*OpenChannelError) if !ok { t.Errorf("got %#v want *OpenChannelError", err) @@ -294,7 +321,7 @@ func TestMuxUnknownChannelRequests(t *testing.T) { defer serverPipe.Close() defer client.Close() - kDone := make(chan struct{}) + kDone := make(chan error, 1) go func() { // Ignore unknown channel messages that don't want a reply. err := serverPipe.writePacket(Marshal(channelRequestMsg{ @@ -304,7 +331,8 @@ func TestMuxUnknownChannelRequests(t *testing.T) { RequestSpecificData: []byte{}, })) if err != nil { - t.Fatalf("send: %v", err) + kDone <- fmt.Errorf("send: %w", err) + return } // Send a keepalive, which should get a channel failure message @@ -316,44 +344,53 @@ func TestMuxUnknownChannelRequests(t *testing.T) { RequestSpecificData: []byte{}, })) if err != nil { - t.Fatalf("send: %v", err) + kDone <- fmt.Errorf("send: %w", err) + return } packet, err := serverPipe.readPacket() if err != nil { - t.Fatalf("read packet: %v", err) + kDone <- fmt.Errorf("read packet: %w", err) + return } decoded, err := decode(packet) if err != nil { - t.Fatalf("decode failed: %v", err) + kDone <- fmt.Errorf("decode failed: %w", err) + return } switch msg := decoded.(type) { case *channelRequestFailureMsg: if msg.PeersID != 2 { - t.Fatalf("received response to wrong message: %v", msg) + kDone <- fmt.Errorf("received response to wrong message: %v", msg) + return + } default: - t.Fatalf("unexpected channel message: %v", msg) + kDone <- fmt.Errorf("unexpected channel message: %v", msg) + return } - kDone <- struct{}{} + kDone <- nil // Receive and respond to the keepalive to confirm the mux is // still processing requests. packet, err = serverPipe.readPacket() if err != nil { - t.Fatalf("read packet: %v", err) + kDone <- fmt.Errorf("read packet: %w", err) + return } if packet[0] != msgGlobalRequest { - t.Fatalf("expected global request") + kDone <- errors.New("expected global request") + return } err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{ Data: []byte{}, })) if err != nil { - t.Fatalf("failed to send failure msg: %v", err) + kDone <- fmt.Errorf("failed to send failure msg: %w", err) + return } close(kDone) @@ -362,7 +399,10 @@ func TestMuxUnknownChannelRequests(t *testing.T) { // Wait for the server to send the keepalive message and receive back a // response. select { - case <-kDone: + case err := <-kDone: + if err != nil { + t.Fatal(err) + } case <-time.After(10 * time.Second): t.Fatalf("server never received ack") } @@ -373,7 +413,10 @@ func TestMuxUnknownChannelRequests(t *testing.T) { } select { - case <-kDone: + case err := <-kDone: + if err != nil { + t.Fatal(err) + } case <-time.After(10 * time.Second): t.Fatalf("server never shut down") } @@ -385,20 +428,23 @@ func TestMuxClosedChannel(t *testing.T) { defer serverPipe.Close() defer client.Close() - kDone := make(chan struct{}) + kDone := make(chan error, 1) go func() { // Open the channel. packet, err := serverPipe.readPacket() if err != nil { - t.Fatalf("read packet: %v", err) + kDone <- fmt.Errorf("read packet: %w", err) + return } if packet[0] != msgChannelOpen { - t.Fatalf("expected chan open") + kDone <- errors.New("expected chan open") + return } var openMsg channelOpenMsg if err := Unmarshal(packet, &openMsg); err != nil { - t.Fatalf("unmarshal: %v", err) + kDone <- fmt.Errorf("unmarshal: %w", err) + return } // Send back the opened channel confirmation. @@ -409,7 +455,8 @@ func TestMuxClosedChannel(t *testing.T) { MaxPacketSize: channelMaxPacket, })) if err != nil { - t.Fatalf("send: %v", err) + kDone <- fmt.Errorf("send: %w", err) + return } // Close the channel. @@ -417,7 +464,8 @@ func TestMuxClosedChannel(t *testing.T) { PeersID: openMsg.PeersID, })) if err != nil { - t.Fatalf("send: %v", err) + kDone <- fmt.Errorf("send: %w", err) + return } // Send a keepalive message on the channel we just closed. @@ -428,43 +476,51 @@ func TestMuxClosedChannel(t *testing.T) { RequestSpecificData: []byte{}, })) if err != nil { - t.Fatalf("send: %v", err) + kDone <- fmt.Errorf("send: %w", err) + return } // Receive the channel closed response. packet, err = serverPipe.readPacket() if err != nil { - t.Fatalf("read packet: %v", err) + kDone <- fmt.Errorf("read packet: %w", err) + return } if packet[0] != msgChannelClose { - t.Fatalf("expected channel close") + kDone <- errors.New("expected channel close") + return } // Receive the keepalive response failure. packet, err = serverPipe.readPacket() if err != nil { - t.Fatalf("read packet: %v", err) + kDone <- fmt.Errorf("read packet: %w", err) + return } if packet[0] != msgChannelFailure { - t.Fatalf("expected channel close") + kDone <- errors.New("expected channel failure") + return } - kDone <- struct{}{} + kDone <- nil // Receive and respond to the keepalive to confirm the mux is // still processing requests. packet, err = serverPipe.readPacket() if err != nil { - t.Fatalf("read packet: %v", err) + kDone <- fmt.Errorf("read packet: %w", err) + return } if packet[0] != msgGlobalRequest { - t.Fatalf("expected global request") + kDone <- errors.New("expected global request") + return } err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{ Data: []byte{}, })) if err != nil { - t.Fatalf("failed to send failure msg: %v", err) + kDone <- fmt.Errorf("failed to send failure msg: %w", err) + return } close(kDone) diff --git a/ssh/session_test.go b/ssh/session_test.go index c4b9f0ea5b..8525cb5e95 100644 --- a/ssh/session_test.go +++ b/ssh/session_test.go @@ -27,6 +27,7 @@ func dial(handler serverType, t *testing.T) *Client { t.Fatalf("netPipe: %v", err) } + errorCh := make(chan error, 1) go func() { defer c1.Close() conf := ServerConfig{ @@ -35,8 +36,9 @@ func dial(handler serverType, t *testing.T) *Client { conf.AddHostKey(testSigners["rsa"]) conn, chans, reqs, err := NewServerConn(c1, &conf) + errorCh <- err if err != nil { - t.Fatalf("Unable to handshake: %v", err) + return } go DiscardRequests(reqs) @@ -69,6 +71,9 @@ func dial(handler serverType, t *testing.T) *Client { if err != nil { t.Fatalf("unable to dial remote side: %v", err) } + if err := <-errorCh; err != nil { + t.Fatalf("Unable to handshake: %v", err) + } return NewClient(conn, chans, reqs) } @@ -647,10 +652,12 @@ func TestSessionID(t *testing.T) { User: "user", } + srvErrCh := make(chan error, 1) go func() { conn, chans, reqs, err := NewServerConn(c1, serverConf) + srvErrCh <- err if err != nil { - t.Fatalf("server handshake: %v", err) + return } serverID <- conn.SessionID() go DiscardRequests(reqs) @@ -659,10 +666,12 @@ func TestSessionID(t *testing.T) { } }() + cliErrCh := make(chan error, 1) go func() { conn, chans, reqs, err := NewClientConn(c2, "", clientConf) + cliErrCh <- err if err != nil { - t.Fatalf("client handshake: %v", err) + return } clientID <- conn.SessionID() go DiscardRequests(reqs) @@ -671,6 +680,14 @@ func TestSessionID(t *testing.T) { } }() + if err := <-srvErrCh; err != nil { + t.Fatalf("server handshake: %v", err) + } + + if err := <-cliErrCh; err != nil { + t.Fatalf("client handshake: %v", err) + } + s := <-serverID c := <-clientID if bytes.Compare(s, c) != 0 { diff --git a/ssh/test/multi_auth_test.go b/ssh/test/multi_auth_test.go index 6c253a7547..403d7363ab 100644 --- a/ssh/test/multi_auth_test.go +++ b/ssh/test/multi_auth_test.go @@ -77,27 +77,27 @@ func (ctx *multiAuthTestCtx) kbdIntCb(user, instruction string, questions []stri func TestMultiAuth(t *testing.T) { testCases := []multiAuthTestCase{ // Test password,publickey authentication, assert that password callback is called 1 time - multiAuthTestCase{ + { authMethods: []string{"password", "publickey"}, expectedPasswordCbs: 1, }, // Test keyboard-interactive,publickey authentication, assert that keyboard-interactive callback is called 1 time - multiAuthTestCase{ + { authMethods: []string{"keyboard-interactive", "publickey"}, expectedKbdIntCbs: 1, }, // Test publickey,password authentication, assert that password callback is called 1 time - multiAuthTestCase{ + { authMethods: []string{"publickey", "password"}, expectedPasswordCbs: 1, }, // Test publickey,keyboard-interactive authentication, assert that keyboard-interactive callback is called 1 time - multiAuthTestCase{ + { authMethods: []string{"publickey", "keyboard-interactive"}, expectedKbdIntCbs: 1, }, // Test password,password authentication, assert that password callback is called 2 times - multiAuthTestCase{ + { authMethods: []string{"password", "password"}, expectedPasswordCbs: 2, },