Skip to content

Commit

Permalink
Update http.go
Browse files Browse the repository at this point in the history
- Better reverse proxy forwarding
  • Loading branch information
shoriwe committed May 28, 2022
1 parent 8954293 commit cee30c6
Showing 1 changed file with 52 additions and 33 deletions.
85 changes: 52 additions & 33 deletions internal/proxy/servers/reverse/http.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package reverse

import (
"bufio"
"crypto/tls"
"context"
"fmt"
"github.com/gorilla/websocket"
"github.com/shoriwe/fullproxy/v3/internal/common"
Expand All @@ -26,7 +25,6 @@ type (
HTTP struct {
Targets map[string]*Target
Dial servers.DialFunc
WebSocketDialer *websocket.Dialer
IncomingSniffer, OutgoingSniffer io.Writer
}
)
Expand Down Expand Up @@ -60,9 +58,6 @@ func (H *HTTP) Handle(_ net.Conn) error {

func (H *HTTP) SetDial(dialFunc servers.DialFunc) {
H.Dial = dialFunc
H.WebSocketDialer = &websocket.Dialer{
NetDial: dialFunc,
}
}

func (H *HTTP) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
Expand All @@ -88,6 +83,7 @@ func (H *HTTP) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
// TODO: Do something with the error
return
}
defer request.Body.Close()
newRequest.Header = request.Header.Clone()
// Inject Headers in request
for key, values := range target.RequestHeader {
Expand Down Expand Up @@ -115,7 +111,13 @@ func (H *HTTP) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
newRequest.Header.Del("Sec-Websocket-Key")
newRequest.Header.Del("Sec-Websocket-Version")
newRequest.Header.Del("Sec-Websocket-Extensions")
targetConnection, response, dialError := H.WebSocketDialer.Dial(
dialer := &websocket.Dialer{
NetDial: func(_, _ string) (net.Conn, error) {
return H.Dial(host.Network, host.Address)
},
TLSClientConfig: host.TLSConfig,
}
targetConnection, response, dialError := dialer.Dial(
u.String(),
newRequest.Header,
)
Expand Down Expand Up @@ -157,39 +159,56 @@ func (H *HTTP) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
return
}
// Prepare client
serverConnection, connectionError := H.Dial(host.Network, host.Address)
if connectionError != nil {
// TODO: Do something with the error
return
}
defer serverConnection.Close()
if host.TLSConfig != nil {
serverConnection = tls.Client(serverConnection, host.TLSConfig)
}
server := &common.Sniffer{
WriteSniffer: H.OutgoingSniffer,
ReadSniffer: H.IncomingSniffer,
Connection: serverConnection,
}
// Send request to server
sendRequestError := newRequest.Write(server)
if sendRequestError != nil {
client := http.Client{
Transport: &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return H.Dial(host.Network, host.Address)
},
TLSClientConfig: host.TLSConfig,
},
CheckRedirect: nil,
Jar: nil,
Timeout: 0,
}
newRequest.Body = &common.RequestSniffer{
HeaderDone: false,
Writer: H.OutgoingSniffer,
Request: request,
}
response, requestError := client.Do(newRequest)
if requestError != nil {
// TODO: Do something with the error
return
}
// Receive server response
serverResponse, readResponseError := http.ReadResponse(bufio.NewReader(server), newRequest)
if readResponseError != nil {
// TODO: Do something with the error
return
defer response.Body.Close()
newResponse := &http.Response{
Status: response.Status,
StatusCode: response.StatusCode,
Proto: response.Proto,
ProtoMajor: response.ProtoMajor,
ProtoMinor: response.ProtoMinor,
Header: response.Header.Clone(),
Body: &common.ResponseSniffer{
HeaderDone: false,
Writer: H.IncomingSniffer,
Response: response,
},
ContentLength: response.ContentLength,
TransferEncoding: response.TransferEncoding,
Close: response.Close,
Uncompressed: response.Uncompressed,
Trailer: response.Trailer.Clone(),
Request: response.Request,
TLS: response.TLS,
}
for key, values := range newResponse.Header {
writer.Header()[key] = values
}
defer serverResponse.Body.Close()
// Inject response headers
for key, values := range target.ResponseHeader {
writer.Header()[key] = values
}
writer.WriteHeader(serverResponse.StatusCode)
_, copyError := io.Copy(writer, serverResponse.Body)
writer.WriteHeader(newResponse.StatusCode)
_, copyError := io.Copy(writer, newResponse.Body)
if copyError != nil {
// TODO: Do something with the error
return
Expand Down

0 comments on commit cee30c6

Please sign in to comment.