Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add remote IP filter to allow a connection from remote kms #692

Merged
merged 5 commits into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/ostracon/commands/show_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func showValidator(cmd *cobra.Command, args []string, config *cfg.Config) error
if err != nil {
return err
}
pv, err = node.CreateAndStartPrivValidatorSocketClient(config.PrivValidatorListenAddr, chainID, logger)
pv, err = node.CreateAndStartPrivValidatorSocketClient(config, chainID, logger)
if err != nil {
return err
}
Expand Down
2 changes: 2 additions & 0 deletions cmd/ostracon/commands/show_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package commands
import (
"bytes"
"os"
"strings"
"sync"
"testing"

Expand Down Expand Up @@ -79,6 +80,7 @@ func TestShowValidatorWithKMS(t *testing.T) {
}
privval.WithMockKMS(t, dir, chainID, func(addr string, privKey crypto.PrivKey) {
config.PrivValidatorListenAddr = addr
config.PrivValidatorRemoteAddr = addr[:strings.Index(addr, ":")]
require.NoFileExists(t, config.PrivValidatorKeyFile())
output, err := captureStdout(func() {
err := showValidator(ShowValidatorCmd, nil, config)
Expand Down
6 changes: 6 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,14 @@ type BaseConfig struct { //nolint: maligned

// TCP or UNIX socket address for Ostracon to listen on for
// connections from an external PrivValidator process
// example) 0.0.0.0:26659
PrivValidatorListenAddr string `mapstructure:"priv_validator_laddr"`

// Validator's remote address(without port) to allow a connection
// ostracon only allow a connection from this address
// example) 10.0.0.7
PrivValidatorRemoteAddr string `mapstructure:"priv_validator_raddr"`
ulbqb marked this conversation as resolved.
Show resolved Hide resolved

// A JSON file containing the private key to use for p2p authenticated encryption
NodeKey string `mapstructure:"node_key_file"`

Expand Down
6 changes: 6 additions & 0 deletions config/toml.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,14 @@ priv_validator_state_file = "{{ js .BaseConfig.PrivValidatorState }}"

# TCP or UNIX socket address for Ostracon to listen on for
# connections from an external PrivValidator process
# example) 0.0.0.0:26659
priv_validator_laddr = "{{ .BaseConfig.PrivValidatorListenAddr }}"

# Validator's remote address to allow a connection
# ostracon only allow a connection from this address
# example) 10.0.0.7
priv_validator_raddr = "{{ .BaseConfig.PrivValidatorRemoteAddr }}"

# Path to the JSON file containing the private key to use for node authentication in the p2p protocol
node_key_file = "{{ js .BaseConfig.NodeKey }}"

Expand Down
10 changes: 3 additions & 7 deletions node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ func NewNode(config *cfg.Config,
// external signing process.
if config.PrivValidatorListenAddr != "" {
// FIXME: we should start services inside OnStart
privValidator, err = CreateAndStartPrivValidatorSocketClient(config.PrivValidatorListenAddr, genDoc.ChainID, logger)
privValidator, err = CreateAndStartPrivValidatorSocketClient(config, genDoc.ChainID, logger)
if err != nil {
return nil, fmt.Errorf("error with private validator socket client: %w", err)
}
Expand Down Expand Up @@ -1523,12 +1523,8 @@ func saveGenesisDoc(db dbm.DB, genDoc *types.GenesisDoc) error {
return nil
}

func CreateAndStartPrivValidatorSocketClient(
listenAddr,
chainID string,
logger log.Logger,
) (types.PrivValidator, error) {
pve, err := privval.NewSignerListener(listenAddr, logger)
func CreateAndStartPrivValidatorSocketClient(config *cfg.Config, chainID string, logger log.Logger) (types.PrivValidator, error) {
pve, err := privval.NewSignerListener(logger, config.PrivValidatorListenAddr, config.PrivValidatorRemoteAddr)
if err != nil {
return nil, fmt.Errorf("failed to start private validator: %w", err)
}
Expand Down
27 changes: 27 additions & 0 deletions privval/signer_listener_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package privval
import (
"fmt"
"net"
"strings"
"time"

privvalproto "github.com/tendermint/tendermint/proto/tendermint/privval"
Expand All @@ -24,6 +25,13 @@ func SignerListenerEndpointTimeoutReadWrite(timeout time.Duration) SignerListene
return func(sl *SignerListenerEndpoint) { sl.signerEndpoint.timeoutReadWrite = timeout }
}

// SignerListenerEndpointAllowAddress sets the address to allow
// connections from the only allowed address
//
func SignerListenerEndpointAllowAddress(addr string) SignerListenerEndpointOption {
return func(sl *SignerListenerEndpoint) { sl.allowAddr = addr }
}

// SignerListenerEndpoint listens for an external process to dial in and keeps
// the connection alive by dropping and reconnecting.
//
Expand All @@ -41,6 +49,8 @@ type SignerListenerEndpoint struct {
pingInterval time.Duration

instanceMtx tmsync.Mutex // Ensures instance public methods access, i.e. SendRequest

allowAddr string // empty value allows all
}

// NewSignerListenerEndpoint returns an instance of SignerListenerEndpoint.
Expand Down Expand Up @@ -186,6 +196,12 @@ func (sl *SignerListenerEndpoint) serviceLoop() {
{
conn, err := sl.acceptNewConnection()
if err == nil {
remoteAddr := conn.RemoteAddr()
if !sl.isAllowedAddr(remoteAddr) {
sl.Logger.Info(fmt.Sprintf("deny a connection request from remote address=%s", remoteAddr))
jaeseung-bae marked this conversation as resolved.
Show resolved Hide resolved
conn.Close()
continue
}
sl.Logger.Info("SignerListener: Connected")

// We have a good connection, wait for someone that needs one otherwise cancellation
Expand All @@ -207,6 +223,17 @@ func (sl *SignerListenerEndpoint) serviceLoop() {
}
}

func (sl *SignerListenerEndpoint) isAllowedAddr(addr net.Addr) bool {
if len(sl.allowAddr) == 0 {
return true
}
if strings.Contains(addr.String(), ":") {
addrOnly := addr.String()[:strings.Index(addr.String(), ":")]
jaeseung-bae marked this conversation as resolved.
Show resolved Hide resolved
return sl.allowAddr == addrOnly
}
return sl.allowAddr == addr.String()
}

func (sl *SignerListenerEndpoint) pingLoop() {
for {
select {
Expand Down
70 changes: 70 additions & 0 deletions privval/signer_listener_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,76 @@ func TestRetryConnToRemoteSigner(t *testing.T) {
}
}

type addrStub struct {
address string
}

func (a addrStub) Network() string {
return ""
}

func (a addrStub) String() string {
return a.address
}

func TestFilterRemoteConnectionByIP(t *testing.T) {
type fields struct {
allowIP string
remoteAddr net.Addr
expected bool
}
tests := []struct {
name string
fields fields
}{
{
"should allow correct ip",
struct {
allowIP string
remoteAddr net.Addr
expected bool
}{"127.0.0.1", addrStub{"127.0.0.1:45678"}, true},
}, {
"should allow correct ip without port",
struct {
allowIP string
remoteAddr net.Addr
expected bool
}{"127.0.0.1", addrStub{"127.0.0.1"}, true},
},
{
"should not allow different ip",
struct {
allowIP string
remoteAddr net.Addr
expected bool
}{"127.0.0.1", addrStub{"10.0.0.2:45678"}, false},
},
{
"empty allowIP should allow all",
struct {
allowIP string
remoteAddr net.Addr
expected bool
}{"", addrStub{"127.0.0.1:45678"}, true},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sl := &SignerListenerEndpoint{allowAddr: tt.fields.allowIP}
assert.Equalf(t, tt.fields.expected, sl.isAllowedAddr(tt.fields.remoteAddr), tt.name)
})
}
}

func TestSignerListenerEndpointAllowAddress(t *testing.T) {
expected := "192.168.0.1"

cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress(expected))

assert.Equal(t, expected, cut.allowAddr)
}

func newSignerListenerEndpoint(logger log.Logger, addr string, timeoutReadWrite time.Duration) *SignerListenerEndpoint {
proto, address := tmnet.ProtocolAndAddress(addr)

Expand Down
4 changes: 2 additions & 2 deletions privval/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func IsConnTimeout(err error) bool {
}

// NewSignerListener creates a new SignerListenerEndpoint using the corresponding listen address
func NewSignerListener(listenAddr string, logger log.Logger) (*SignerListenerEndpoint, error) {
func NewSignerListener(logger log.Logger, listenAddr, remoteAddr string) (*SignerListenerEndpoint, error) {
var listener net.Listener

protocol, address := tmnet.ProtocolAndAddress(listenAddr)
Expand All @@ -47,7 +47,7 @@ func NewSignerListener(listenAddr string, logger log.Logger) (*SignerListenerEnd
)
}

pve := NewSignerListenerEndpoint(logger.With("module", "privval"), listener)
pve := NewSignerListenerEndpoint(logger.With("module", "privval"), listener, SignerListenerEndpointAllowAddress(remoteAddr))

return pve, nil
}
Expand Down