Skip to content

Commit

Permalink
[ws-proxy] SSH gateway support full channel type
Browse files Browse the repository at this point in the history
  • Loading branch information
iQQBot authored and roboquat committed May 28, 2022
1 parent 2823606 commit 09d34a5
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 38 deletions.
34 changes: 17 additions & 17 deletions components/ws-proxy/pkg/sshproxy/forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,29 @@ import (
"golang.org/x/net/context"
)

func (s *Server) ChannelForward(ctx context.Context, session *Session, client *ssh.Client, newChannel ssh.NewChannel) {
workspaceChan, workspaceReqs, err := client.OpenChannel(newChannel.ChannelType(), newChannel.ExtraData())
func (s *Server) ChannelForward(ctx context.Context, session *Session, targetConn ssh.Conn, originChannel ssh.NewChannel) {
targetChan, targetReqs, err := targetConn.OpenChannel(originChannel.ChannelType(), originChannel.ExtraData())
if err != nil {
log.WithFields(log.OWI("", session.WorkspaceID, session.InstanceID)).Error("open workspace channel error")
newChannel.Reject(ssh.ConnectionFailed, "open workspace channel error")
log.WithFields(log.OWI("", session.WorkspaceID, session.InstanceID)).Error("open target channel error")
originChannel.Reject(ssh.ConnectionFailed, "open target channel error")
return
}
defer workspaceChan.Close()
defer targetChan.Close()

clientChan, clientReqs, err := newChannel.Accept()
originChan, originReqs, err := originChannel.Accept()
if err != nil {
log.WithFields(log.OWI("", session.WorkspaceID, session.InstanceID)).Error("accept new channel failed")
log.WithFields(log.OWI("", session.WorkspaceID, session.InstanceID)).Error("accept origin channel failed")
return
}
if newChannel.ChannelType() == "session" {
clientChan = startHeartbeatingChannel(clientChan, s.Heartbeater, session.InstanceID)
if originChannel.ChannelType() == "session" {
originChan = startHeartbeatingChannel(originChan, s.Heartbeater, session.InstanceID)
}
defer clientChan.Close()
defer originChan.Close()

maskedReqs := make(chan *ssh.Request, 1)

go func() {
for req := range clientReqs {
for req := range originReqs {
switch req.Type {
case "pty-req", "shell":
log.WithFields(log.OWI("", session.WorkspaceID, session.InstanceID)).Debugf("forwarding %s request", req.Type)
Expand All @@ -51,13 +51,13 @@ func (s *Server) ChannelForward(ctx context.Context, session *Session, client *s
}()

go func() {
io.Copy(workspaceChan, clientChan)
workspaceChan.CloseWrite()
io.Copy(targetChan, originChan)
targetChan.CloseWrite()
}()

go func() {
io.Copy(clientChan, workspaceChan)
clientChan.CloseWrite()
io.Copy(originChan, targetChan)
originChan.CloseWrite()
}()

wg := sync.WaitGroup{}
Expand All @@ -82,8 +82,8 @@ func (s *Server) ChannelForward(ctx context.Context, session *Session, client *s
}

wg.Add(2)
go forward(maskedReqs, workspaceChan)
go forward(workspaceReqs, clientChan)
go forward(maskedReqs, targetChan)
go forward(targetReqs, originChan)

wg.Wait()
log.WithFields(log.OWI("", session.WorkspaceID, session.InstanceID)).Debug("session forward stop")
Expand Down
72 changes: 51 additions & 21 deletions components/ws-proxy/pkg/sshproxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package sshproxy

import (
"context"
"fmt"
"net"
"strings"
"time"
Expand Down Expand Up @@ -168,20 +167,29 @@ func ReportSSHAttemptMetrics(err error) {
SSHAttemptTotal.WithLabelValues("failed", errorType).Inc()
}

func (s *Server) RequestForward(reqs <-chan *ssh.Request, targetConn ssh.Conn) {
for req := range reqs {
result, payload, err := targetConn.SendRequest(req.Type, req.WantReply, req.Payload)
if err != nil {
continue
}
_ = req.Reply(result, payload)
}
}

func (s *Server) HandleConn(c net.Conn) {
sshConn, chans, reqs, err := ssh.NewServerConn(c, s.sshConfig)
clientConn, clientChans, clientReqs, err := ssh.NewServerConn(c, s.sshConfig)
if err != nil {
c.Close()
ReportSSHAttemptMetrics(err)
return
}
defer sshConn.Close()
defer clientConn.Close()

go ssh.DiscardRequests(reqs)
if sshConn.Permissions == nil || sshConn.Permissions.Extensions == nil || sshConn.Permissions.Extensions["workspaceId"] == "" {
if clientConn.Permissions == nil || clientConn.Permissions.Extensions == nil || clientConn.Permissions.Extensions["workspaceId"] == "" {
return
}
workspaceId := sshConn.Permissions.Extensions["workspaceId"]
workspaceId := clientConn.Permissions.Extensions["workspaceId"]
wsInfo := s.workspaceInfoProvider.WorkspaceInfo(workspaceId)
if wsInfo == nil {
ReportSSHAttemptMetrics(ErrWorkspaceNotFound)
Expand All @@ -199,7 +207,7 @@ func (s *Server) HandleConn(c net.Conn) {
cancel()

session := &Session{
Conn: sshConn,
Conn: clientConn,
WorkspaceID: workspaceId,
InstanceID: wsInfo.InstanceID,
WorkspacePrivateKey: key,
Expand All @@ -214,7 +222,7 @@ func (s *Server) HandleConn(c net.Conn) {
}
defer conn.Close()

clientConn, clientChans, clientReqs, err := ssh.NewClientConn(conn, remoteAddr, &ssh.ClientConfig{
workspaceConn, workspaceChans, workspaceReqs, err := ssh.NewClientConn(conn, remoteAddr, &ssh.ClientConfig{
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
User: GitpodUsername,
Auth: []ssh.AuthMethod{
Expand All @@ -231,29 +239,51 @@ func (s *Server) HandleConn(c net.Conn) {
return
}
s.Heartbeater.SendHeartbeat(wsInfo.InstanceID, false)
client := ssh.NewClient(clientConn, clientChans, clientReqs)
ctx, cancel = context.WithCancel(context.Background())

s.TrackSSHConnection(wsInfo, "connect", nil)
SSHConnectionCount.Inc()
ReportSSHAttemptMetrics(nil)

forwardRequests := func(reqs <-chan *ssh.Request, targetConn ssh.Conn) {
for req := range reqs {
result, payload, err := targetConn.SendRequest(req.Type, req.WantReply, req.Payload)
if err != nil {
continue
}
_ = req.Reply(result, payload)
}
}
// client -> workspace global request forward
go forwardRequests(clientReqs, workspaceConn)
// workspce -> client global request forward
go forwardRequests(workspaceReqs, clientConn)

go func() {
client.Wait()
cancel()
defer SSHConnectionCount.Dec()
for newChannel := range workspaceChans {
go s.ChannelForward(ctx, session, clientConn, newChannel)
}
}()

for newChannel := range chans {
switch newChannel.ChannelType() {
case "session", "direct-tcpip":
go s.ChannelForward(ctx, session, client, newChannel)
case "tcpip-forward":
newChannel.Reject(ssh.UnknownChannelType, "Gitpod SSH Gateway cannot remote forward ports")
default:
newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("Gitpod SSH Gateway cannot handle %s channel types", newChannel.ChannelType()))
go func() {
for newChannel := range clientChans {
go s.ChannelForward(ctx, session, workspaceConn, newChannel)
}
}
}()

go func() {
clientConn.Wait()
cancel()
}()
go func() {
workspaceConn.Wait()
cancel()
}()
<-ctx.Done()
SSHConnectionCount.Dec()
workspaceConn.Close()
clientConn.Close()
cancel()
}

func (s *Server) Authenticator(workspaceId, ownerToken string) (*p.WorkspaceInfo, error) {
Expand Down

0 comments on commit 09d34a5

Please sign in to comment.