diff --git a/cmd/geth/attach_test.go b/cmd/geth/attach_test.go new file mode 100644 index 000000000000..7c5f951750fb --- /dev/null +++ b/cmd/geth/attach_test.go @@ -0,0 +1,83 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package main + +import ( + "fmt" + "net" + "net/http" + "sync/atomic" + "testing" +) + +type testHandler struct { + body func(http.ResponseWriter, *http.Request) +} + +func (t *testHandler) ServeHTTP(out http.ResponseWriter, in *http.Request) { + t.body(out, in) +} + +// TestAttachWithHeaders tests that 'geth attach' with custom headers works, i.e +// that custom headers are forwarded to the target. +func TestAttachWithHeaders(t *testing.T) { + t.Parallel() + ln, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + port := ln.Addr().(*net.TCPAddr).Port + testReceiveHeaders(t, ln, "attach", "-H", "first: one", "-H", "second: two", fmt.Sprintf("http://localhost:%d", port)) + // This way to do it fails due to flag ordering: + // + // testReceiveHeaders(t, ln, "-H", "first: one", "-H", "second: two", "attach", fmt.Sprintf("http://localhost:%d", port)) + // This is fixed in a follow-up PR. +} + +// TestAttachWithHeaders tests that 'geth db --remotedb' with custom headers works, i.e +// that custom headers are forwarded to the target. +func TestRemoteDbWithHeaders(t *testing.T) { + t.Parallel() + ln, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + port := ln.Addr().(*net.TCPAddr).Port + testReceiveHeaders(t, ln, "db", "metadata", "--remotedb", fmt.Sprintf("http://localhost:%d", port), "-H", "first: one", "-H", "second: two") +} + +func testReceiveHeaders(t *testing.T, ln net.Listener, gethArgs ...string) { + var ok uint32 + server := &http.Server{ + Addr: "localhost:0", + Handler: &testHandler{func(w http.ResponseWriter, r *http.Request) { + // We expect two headers + if have, want := r.Header.Get("first"), "one"; have != want { + t.Fatalf("missing header, have %v want %v", have, want) + } + if have, want := r.Header.Get("second"), "two"; have != want { + t.Fatalf("missing header, have %v want %v", have, want) + } + atomic.StoreUint32(&ok, 1) + }}} + go server.Serve(ln) + defer server.Close() + runGeth(t, gethArgs...).WaitExit() + if atomic.LoadUint32(&ok) != 1 { + t.Fatal("Test fail, expected invocation to succeed") + } +} diff --git a/cmd/geth/consolecmd.go b/cmd/geth/consolecmd.go index 87bbe24b977a..83c6b66a8a60 100644 --- a/cmd/geth/consolecmd.go +++ b/cmd/geth/consolecmd.go @@ -23,8 +23,6 @@ import ( "github.com/ethereum/go-ethereum/cmd/utils" "github.com/ethereum/go-ethereum/console" "github.com/ethereum/go-ethereum/internal/flags" - "github.com/ethereum/go-ethereum/node" - "github.com/ethereum/go-ethereum/rpc" "github.com/urfave/cli/v2" ) @@ -47,7 +45,7 @@ See https://geth.ethereum.org/docs/interface/javascript-console.`, Name: "attach", Usage: "Start an interactive JavaScript environment (connect to node)", ArgsUsage: "[endpoint]", - Flags: flags.Merge([]cli.Flag{utils.DataDirFlag}, consoleFlags), + Flags: flags.Merge([]cli.Flag{utils.DataDirFlag, utils.HttpHeaderFlag}, consoleFlags), Description: ` The Geth console is an interactive shell for the JavaScript runtime environment which exposes a node admin interface as well as the Ðapp JavaScript API. @@ -118,14 +116,13 @@ func remoteConsole(ctx *cli.Context) error { if ctx.Args().Len() > 1 { utils.Fatalf("invalid command-line: too many arguments") } - endpoint := ctx.Args().First() if endpoint == "" { cfg := defaultNodeConfig() utils.SetDataDir(ctx, &cfg) endpoint = cfg.IPCEndpoint() } - client, err := dialRPC(endpoint) + client, err := utils.DialRPCWithHeaders(endpoint, ctx.StringSlice(utils.HttpHeaderFlag.Name)) if err != nil { utils.Fatalf("Unable to attach to remote geth: %v", err) } @@ -164,17 +161,3 @@ func ephemeralConsole(ctx *cli.Context) error { geth --exec "%s" console`, b.String()) return nil } - -// dialRPC returns a RPC client which connects to the given endpoint. -// The check for empty endpoint implements the defaulting logic -// for "geth attach" with no argument. -func dialRPC(endpoint string) (*rpc.Client, error) { - if endpoint == "" { - endpoint = node.DefaultIPCEndpoint(clientIdentifier) - } else if strings.HasPrefix(endpoint, "rpc:") || strings.HasPrefix(endpoint, "ipc:") { - // Backwards compatibility with geth < 1.5 which required - // these prefixes. - endpoint = endpoint[4:] - } - return rpc.Dial(endpoint) -} diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 745b9f088eb3..ca6ded475668 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -18,10 +18,13 @@ package utils import ( + "context" "crypto/ecdsa" + "errors" "fmt" "math" "math/big" + "net/http" "os" "path/filepath" godebug "runtime/debug" @@ -976,6 +979,13 @@ var ( Value: metrics.DefaultConfig.InfluxDBOrganization, Category: flags.MetricsCategory, } + + HttpHeaderFlag = &cli.StringSliceFlag{ + Name: "header", + Aliases: []string{"H"}, + Usage: "Pass custom headers to the RPC server wheng using --" + RemoteDBFlag.Name + " or the geth attach console.", + Category: flags.NetworkingCategory, + } ) var ( @@ -995,6 +1005,7 @@ var ( DataDirFlag, AncientFlag, RemoteDBFlag, + HttpHeaderFlag, } ) @@ -2125,8 +2136,12 @@ func MakeChainDatabase(ctx *cli.Context, stack *node.Node, readonly bool) ethdb. ) switch { case ctx.IsSet(RemoteDBFlag.Name): - log.Info("Using remote db", "url", ctx.String(RemoteDBFlag.Name)) - chainDb, err = remotedb.New(ctx.String(RemoteDBFlag.Name)) + log.Info("Using remote db", "url", ctx.String(RemoteDBFlag.Name), "headers", len(ctx.StringSlice(HttpHeaderFlag.Name))) + client, err := DialRPCWithHeaders(ctx.String(RemoteDBFlag.Name), ctx.StringSlice(HttpHeaderFlag.Name)) + if err != nil { + break + } + chainDb = remotedb.New(client) case ctx.String(SyncModeFlag.Name) == "light": chainDb, err = stack.OpenDatabase("lightchaindata", cache, handles, "", readonly) default: @@ -2148,6 +2163,30 @@ func IsNetworkPreset(ctx *cli.Context) bool { return false } +func DialRPCWithHeaders(endpoint string, headers []string) (*rpc.Client, error) { + if endpoint == "" { + return nil, errors.New("endpoint must be specified") + } + if strings.HasPrefix(endpoint, "rpc:") || strings.HasPrefix(endpoint, "ipc:") { + // Backwards compatibility with geth < 1.5 which required + // these prefixes. + endpoint = endpoint[4:] + } + var opts []rpc.ClientOption + if len(headers) > 0 { + var customHeaders = make(http.Header) + for _, h := range headers { + kv := strings.Split(h, ":") + if len(kv) != 2 { + return nil, fmt.Errorf("invalid http header directive: %q", h) + } + customHeaders.Add(kv[0], kv[1]) + } + opts = append(opts, rpc.WithHeaders(customHeaders)) + } + return rpc.DialOptions(context.Background(), endpoint, opts...) +} + func MakeGenesis(ctx *cli.Context) *core.Genesis { var genesis *core.Genesis switch { diff --git a/ethdb/remotedb/remotedb.go b/ethdb/remotedb/remotedb.go index 59a570bb5e96..9ce657d78026 100644 --- a/ethdb/remotedb/remotedb.go +++ b/ethdb/remotedb/remotedb.go @@ -22,9 +22,6 @@ package remotedb import ( - "errors" - "strings" - "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/rpc" @@ -150,24 +147,8 @@ func (db *Database) Close() error { return nil } -func dialRPC(endpoint string) (*rpc.Client, error) { - if endpoint == "" { - return nil, errors.New("endpoint must be specified") - } - if strings.HasPrefix(endpoint, "rpc:") || strings.HasPrefix(endpoint, "ipc:") { - // Backwards compatibility with geth < 1.5 which required - // these prefixes. - endpoint = endpoint[4:] - } - return rpc.Dial(endpoint) -} - -func New(endpoint string) (ethdb.Database, error) { - client, err := dialRPC(endpoint) - if err != nil { - return nil, err - } +func New(client *rpc.Client) ethdb.Database { return &Database{ remote: client, - }, nil + } }