Skip to content

Commit

Permalink
feat: add kms remote IP filter
Browse files Browse the repository at this point in the history
  • Loading branch information
jaeseung-bae committed Aug 8, 2023
1 parent d971aa4 commit 041f37a
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 10 deletions.
2 changes: 1 addition & 1 deletion cmd/ostracon/commands/show_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func showValidator(cmd *cobra.Command, args []string, config *cfg.Config) error
if err != nil {
return err
}
pv, err = node.CreateAndStartPrivValidatorSocketClient(config.PrivValidatorListenAddr, chainID, logger)
pv, err = node.CreateAndStartPrivValidatorSocketClient(config, chainID, logger)
if err != nil {
return err
}
Expand Down
6 changes: 6 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,14 @@ type BaseConfig struct { //nolint: maligned

// TCP or UNIX socket address for Ostracon to listen on for
// connections from an external PrivValidator process
// example) 0.0.0.0:26659
PrivValidatorListenAddr string `mapstructure:"priv_validator_laddr"`

// Validator's remote address(without port) to allow a connection
// ostracon only allow a connection from this address
// example) 10.0.0.7
PrivValidatorRemoteAddr string `mapstructure:"priv_validator_raddr"`

// A JSON file containing the private key to use for p2p authenticated encryption
NodeKey string `mapstructure:"node_key_file"`

Expand Down
6 changes: 6 additions & 0 deletions config/toml.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,14 @@ priv_validator_state_file = "{{ js .BaseConfig.PrivValidatorState }}"
# TCP or UNIX socket address for Ostracon to listen on for
# connections from an external PrivValidator process
# example) 0.0.0.0:26659
priv_validator_laddr = "{{ .BaseConfig.PrivValidatorListenAddr }}"
# Validator's remote address to allow a connection
# ostracon only allow a connection from this address
# example) 10.0.0.7
priv_validator_raddr = "{{ .BaseConfig.PrivValidatorRemoteAddr }}"
# Path to the JSON file containing the private key to use for node authentication in the p2p protocol
node_key_file = "{{ js .BaseConfig.NodeKey }}"
Expand Down
10 changes: 3 additions & 7 deletions node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ func NewNode(config *cfg.Config,
// external signing process.
if config.PrivValidatorListenAddr != "" {
// FIXME: we should start services inside OnStart
privValidator, err = CreateAndStartPrivValidatorSocketClient(config.PrivValidatorListenAddr, genDoc.ChainID, logger)
privValidator, err = CreateAndStartPrivValidatorSocketClient(config, genDoc.ChainID, logger)
if err != nil {
return nil, fmt.Errorf("error with private validator socket client: %w", err)
}
Expand Down Expand Up @@ -1523,12 +1523,8 @@ func saveGenesisDoc(db dbm.DB, genDoc *types.GenesisDoc) error {
return nil
}

func CreateAndStartPrivValidatorSocketClient(
listenAddr,
chainID string,
logger log.Logger,
) (types.PrivValidator, error) {
pve, err := privval.NewSignerListener(listenAddr, logger)
func CreateAndStartPrivValidatorSocketClient(config *cfg.Config, chainID string, logger log.Logger) (types.PrivValidator, error) {
pve, err := privval.NewSignerListener(logger, config.PrivValidatorListenAddr, config.PrivValidatorRemoteAddr)
if err != nil {
return nil, fmt.Errorf("failed to start private validator: %w", err)
}
Expand Down
22 changes: 22 additions & 0 deletions privval/signer_listener_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package privval
import (
"fmt"
"net"
"strings"
"time"

privvalproto "github.com/tendermint/tendermint/proto/tendermint/privval"
Expand All @@ -24,6 +25,13 @@ func SignerListenerEndpointTimeoutReadWrite(timeout time.Duration) SignerListene
return func(sl *SignerListenerEndpoint) { sl.signerEndpoint.timeoutReadWrite = timeout }
}

// SignerListenerEndpointAllowAddress sets the address to allow
// connections from the only allowed address
//
func SignerListenerEndpointAllowAddress(addr string) SignerListenerEndpointOption {
return func(sl *SignerListenerEndpoint) { sl.allowAddr = addr }
}

// SignerListenerEndpoint listens for an external process to dial in and keeps
// the connection alive by dropping and reconnecting.
//
Expand All @@ -41,6 +49,8 @@ type SignerListenerEndpoint struct {
pingInterval time.Duration

instanceMtx tmsync.Mutex // Ensures instance public methods access, i.e. SendRequest

allowAddr string
}

// NewSignerListenerEndpoint returns an instance of SignerListenerEndpoint.
Expand Down Expand Up @@ -185,6 +195,13 @@ func (sl *SignerListenerEndpoint) serviceLoop() {
case <-sl.connectRequestCh:
{
conn, err := sl.acceptNewConnection()
remoteAddr := conn.RemoteAddr()
if !sl.isAllowedAddr(remoteAddr) {
sl.Logger.Info(fmt.Sprintf("deny a connection request from remote address=%s", remoteAddr))
conn.Close()
continue
}

if err == nil {
sl.Logger.Info("SignerListener: Connected")

Expand All @@ -207,6 +224,11 @@ func (sl *SignerListenerEndpoint) serviceLoop() {
}
}

func (sl *SignerListenerEndpoint) isAllowedAddr(addr net.Addr) bool {
addrOnly := addr.String()[:strings.Index(addr.String(), ":")]
return sl.allowAddr == addrOnly
}

func (sl *SignerListenerEndpoint) pingLoop() {
for {
select {
Expand Down
55 changes: 55 additions & 0 deletions privval/signer_listener_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,61 @@ func TestRetryConnToRemoteSigner(t *testing.T) {
}
}

type addrStub struct {
address string
}

func (a addrStub) Network() string {
return ""
}

func (a addrStub) String() string {
return a.address
}

func TestFilterRemoteConnectionByIP(t *testing.T) {
type fields struct {
allowIP string
remoteAddr net.Addr
expected bool
}
tests := []struct {
name string
fields fields
}{
{
"should allow correct ip",
struct {
allowIP string
remoteAddr net.Addr
expected bool
}{"127.0.0.1", addrStub{"127.0.0.1:45678"}, true},
},
{
"should not allow different ip",
struct {
allowIP string
remoteAddr net.Addr
expected bool
}{"127.0.0.1", addrStub{"10.0.0.2:45678"}, false},
},
{
"empty allowIP should deny all",
struct {
allowIP string
remoteAddr net.Addr
expected bool
}{"", addrStub{"127.0.0.1:45678"}, false},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sl := &SignerListenerEndpoint{allowAddr: tt.fields.allowIP}
assert.Equalf(t, tt.fields.expected, sl.isAllowedAddr(tt.fields.remoteAddr), tt.name)
})
}
}

func newSignerListenerEndpoint(logger log.Logger, addr string, timeoutReadWrite time.Duration) *SignerListenerEndpoint {
proto, address := tmnet.ProtocolAndAddress(addr)

Expand Down
4 changes: 2 additions & 2 deletions privval/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func IsConnTimeout(err error) bool {
}

// NewSignerListener creates a new SignerListenerEndpoint using the corresponding listen address
func NewSignerListener(listenAddr string, logger log.Logger) (*SignerListenerEndpoint, error) {
func NewSignerListener(logger log.Logger, listenAddr, remoteAddr string) (*SignerListenerEndpoint, error) {
var listener net.Listener

protocol, address := tmnet.ProtocolAndAddress(listenAddr)
Expand All @@ -47,7 +47,7 @@ func NewSignerListener(listenAddr string, logger log.Logger) (*SignerListenerEnd
)
}

pve := NewSignerListenerEndpoint(logger.With("module", "privval"), listener)
pve := NewSignerListenerEndpoint(logger.With("module", "privval"), listener, SignerListenerEndpointAllowAddress(remoteAddr))

return pve, nil
}
Expand Down

0 comments on commit 041f37a

Please sign in to comment.