Skip to content

Commit

Permalink
feat: add remote IP filter to allow a connection from remote kms (#692)
Browse files Browse the repository at this point in the history
* feat: add kms remote IP filter

* fix: side-effect for testing by allowing any connection if remote address is empty

* chore: add some tests

* feat: apply denyAll if empty allow address, and null object pattern for "unix" connection type

* chore: increase test coverage
  • Loading branch information
jaeseung-bae committed Aug 17, 2023
1 parent 5a8209b commit 0e64a96
Show file tree
Hide file tree
Showing 14 changed files with 269 additions and 10 deletions.
2 changes: 1 addition & 1 deletion cmd/ostracon/commands/show_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func showValidator(cmd *cobra.Command, args []string, config *cfg.Config) error
if err != nil {
return err
}
pv, err = node.CreateAndStartPrivValidatorSocketClient(config.PrivValidatorListenAddr, chainID, logger)
pv, err = node.CreateAndStartPrivValidatorSocketClient(config, chainID, logger)
if err != nil {
return err
}
Expand Down
2 changes: 2 additions & 0 deletions cmd/ostracon/commands/show_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package commands
import (
"bytes"
"os"
"strings"
"sync"
"testing"

Expand Down Expand Up @@ -79,6 +80,7 @@ func TestShowValidatorWithKMS(t *testing.T) {
}
privval.WithMockKMS(t, dir, chainID, func(addr string, privKey crypto.PrivKey) {
config.PrivValidatorListenAddr = addr
config.PrivValidatorRemoteAddr = addr[:strings.Index(addr, ":")]
require.NoFileExists(t, config.PrivValidatorKeyFile())
output, err := captureStdout(func() {
err := showValidator(ShowValidatorCmd, nil, config)
Expand Down
6 changes: 6 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,14 @@ type BaseConfig struct { //nolint: maligned

// TCP or UNIX socket address for Ostracon to listen on for
// connections from an external PrivValidator process
// example) tcp://0.0.0.0:26659
PrivValidatorListenAddr string `mapstructure:"priv_validator_laddr"`

// Validator's remote address(without port) to allow a connection
// ostracon only allow a connection from this address
// example) 10.0.0.7
PrivValidatorRemoteAddr string `mapstructure:"priv_validator_raddr"`

// A JSON file containing the private key to use for p2p authenticated encryption
NodeKey string `mapstructure:"node_key_file"`

Expand Down
6 changes: 6 additions & 0 deletions config/toml.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,14 @@ priv_validator_state_file = "{{ js .BaseConfig.PrivValidatorState }}"
# TCP or UNIX socket address for Ostracon to listen on for
# connections from an external PrivValidator process
# example) tcp://0.0.0.0:26659
priv_validator_laddr = "{{ .BaseConfig.PrivValidatorListenAddr }}"
# Validator's remote address to allow a connection
# ostracon only allow a connection from this address
# example) 127.0.0.1
priv_validator_raddr = "127.0.0.1"
# Path to the JSON file containing the private key to use for node authentication in the p2p protocol
node_key_file = "{{ js .BaseConfig.NodeKey }}"
Expand Down
10 changes: 3 additions & 7 deletions node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ func NewNode(config *cfg.Config,
// external signing process.
if config.PrivValidatorListenAddr != "" {
// FIXME: we should start services inside OnStart
privValidator, err = CreateAndStartPrivValidatorSocketClient(config.PrivValidatorListenAddr, genDoc.ChainID, logger)
privValidator, err = CreateAndStartPrivValidatorSocketClient(config, genDoc.ChainID, logger)
if err != nil {
return nil, fmt.Errorf("error with private validator socket client: %w", err)
}
Expand Down Expand Up @@ -1523,12 +1523,8 @@ func saveGenesisDoc(db dbm.DB, genDoc *types.GenesisDoc) error {
return nil
}

func CreateAndStartPrivValidatorSocketClient(
listenAddr,
chainID string,
logger log.Logger,
) (types.PrivValidator, error) {
pve, err := privval.NewSignerListener(listenAddr, logger)
func CreateAndStartPrivValidatorSocketClient(config *cfg.Config, chainID string, logger log.Logger) (types.PrivValidator, error) {
pve, err := privval.NewSignerListener(logger, config.PrivValidatorListenAddr, config.PrivValidatorRemoteAddr)
if err != nil {
return nil, fmt.Errorf("failed to start private validator: %w", err)
}
Expand Down
6 changes: 6 additions & 0 deletions node/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,17 @@ func TestNodeSetAppVersion(t *testing.T) {
}

func TestNodeSetPrivValTCP(t *testing.T) {
address := testFreeAddr(t)
addr := "tcp://" + testFreeAddr(t)

config := cfg.ResetTestRoot("node_priv_val_tcp_test")
defer os.RemoveAll(config.RootDir)
config.BaseConfig.PrivValidatorListenAddr = addr
addrPart, _, err := net.SplitHostPort(address)
if err != nil {
return
}
config.BaseConfig.PrivValidatorRemoteAddr = addrPart

dialer := privval.DialTCPFn(addr, 100*time.Millisecond, ed25519.GenPrivKey())
dialerEndpoint := privval.NewSignerDialerEndpoint(
Expand Down
8 changes: 8 additions & 0 deletions privval/internal/conn_filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package internal

import "net"

type ConnectionFilter interface {
Filter(addr net.Addr) net.Addr
String() string
}
44 changes: 44 additions & 0 deletions privval/internal/ip_filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package internal

import (
"fmt"
"github.com/Finschia/ostracon/libs/log"
"net"
)

type IpFilter struct {
allowAddr string
log log.Logger
}

func NewIpFilter(addr string, l log.Logger) *IpFilter {
return &IpFilter{
allowAddr: addr,
log: l,
}
}

func (f *IpFilter) Filter(addr net.Addr) net.Addr {
if f.isAllowedAddr(addr) {
return addr
}
return nil
}

func (f *IpFilter) String() string {
return f.allowAddr
}

func (f *IpFilter) isAllowedAddr(addr net.Addr) bool {
if len(f.allowAddr) == 0 {
return false
}
hostAddr, _, err := net.SplitHostPort(addr.String())
if err != nil {
if f.log != nil {
f.log.Error(fmt.Sprintf("IpFilter: can't split host and port from addr.String()=%s", addr.String()))
}
return false
}
return f.allowAddr == hostAddr
}
91 changes: 91 additions & 0 deletions privval/internal/ip_filter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package internal

import (
"github.com/stretchr/testify/assert"
"net"
"testing"
)

type addrStub struct {
address string
}

func (a addrStub) Network() string {
return ""
}

func (a addrStub) String() string {
return a.address
}

func TestFilterRemoteConnectionByIP(t *testing.T) {
type fields struct {
allowIP string
remoteAddr net.Addr
expected net.Addr
}
tests := []struct {
name string
fields fields
}{
{
"should allow correct ip",
struct {
allowIP string
remoteAddr net.Addr
expected net.Addr
}{"127.0.0.1", addrStub{"127.0.0.1:45678"}, addrStub{"127.0.0.1:45678"}},
},
{
"should not allow different ip",
struct {
allowIP string
remoteAddr net.Addr
expected net.Addr
}{"127.0.0.1", addrStub{"10.0.0.2:45678"}, nil},
},
{
"should works for IPv6 with correct ip",
struct {
allowIP string
remoteAddr net.Addr
expected net.Addr
}{"2001:db8::1", addrStub{"[2001:db8::1]:80"}, addrStub{"[2001:db8::1]:80"}},
},
{
"should works for IPv6 with incorrect ip",
struct {
allowIP string
remoteAddr net.Addr
expected net.Addr
}{"2001:db8::2", addrStub{"[2001:db8::1]:80"}, nil},
},
{
"empty allowIP should deny all",
struct {
allowIP string
remoteAddr net.Addr
expected net.Addr
}{"", addrStub{"127.0.0.1:45678"}, nil},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cut := NewIpFilter(tt.fields.allowIP, nil)
assert.Equalf(t, tt.fields.expected, cut.Filter(tt.fields.remoteAddr), tt.name)
})
}
}

func TestIpFilterShouldSetAllowAddress(t *testing.T) {
expected := "192.168.0.1"

cut := NewIpFilter(expected, nil)

assert.Equal(t, expected, cut.allowAddr)
}

func TestIpFilterStringShouldReturnsIP(t *testing.T) {
expected := "127.0.0.1"
assert.Equal(t, expected, NewIpFilter(expected, nil).String())
}
19 changes: 19 additions & 0 deletions privval/internal/null_object_filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package internal

import "net"

// NullObject is null object pattern. It does nothing
type NullObject struct {
}

func NewNullObject() *NullObject {
return &NullObject{}
}

func (n NullObject) Filter(addr net.Addr) net.Addr {
return addr
}

func (n NullObject) String() string {
return "NullObject"
}
40 changes: 40 additions & 0 deletions privval/internal/null_object_filter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package internal

import (
"github.com/stretchr/testify/assert"
"net"
"reflect"
"testing"
)

func TestNullObject_filter(t *testing.T) {
stubInput := addrStub{}
tests := []struct {
name string
addr net.Addr
want net.Addr
}{
{
name: "null object does nothing, returns what it receives",
addr: stubInput,
want: stubInput,
},
{
name: "null object does nothing, returns nil it receives nil",
addr: nil,
want: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
n := NewNullObject()
if got := n.Filter(tt.addr); !reflect.DeepEqual(got, tt.want) {
t.Errorf("Filter() = %v, want %v", got, tt.want)
}
})
}
}

func TestNullObjectString(t *testing.T) {
assert.Equal(t, "NullObject", NewNullObject().String())
}
28 changes: 28 additions & 0 deletions privval/signer_listener_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package privval

import (
"fmt"
"github.com/Finschia/ostracon/privval/internal"
"net"
"time"

Expand All @@ -24,6 +25,19 @@ func SignerListenerEndpointTimeoutReadWrite(timeout time.Duration) SignerListene
return func(sl *SignerListenerEndpoint) { sl.signerEndpoint.timeoutReadWrite = timeout }
}

// SignerListenerEndpointAllowAddress sets the address to allow
// connections from the only allowed address
//
func SignerListenerEndpointAllowAddress(protocol string, addr string) SignerListenerEndpointOption {
return func(sl *SignerListenerEndpoint) {
if protocol == "tcp" || len(protocol) == 0 {
sl.connFilter = internal.NewIpFilter(addr, sl.Logger)
return
}
sl.connFilter = internal.NewNullObject()
}
}

// SignerListenerEndpoint listens for an external process to dial in and keeps
// the connection alive by dropping and reconnecting.
//
Expand All @@ -41,6 +55,7 @@ type SignerListenerEndpoint struct {
pingInterval time.Duration

instanceMtx tmsync.Mutex // Ensures instance public methods access, i.e. SendRequest
connFilter internal.ConnectionFilter
}

// NewSignerListenerEndpoint returns an instance of SignerListenerEndpoint.
Expand Down Expand Up @@ -186,6 +201,12 @@ func (sl *SignerListenerEndpoint) serviceLoop() {
{
conn, err := sl.acceptNewConnection()
if err == nil {
remoteAddr := conn.RemoteAddr()
if sl.filter(remoteAddr) == nil {
sl.Logger.Info(fmt.Sprintf("SignerListener: deny a connection request from remote address=%s, expected=%s", remoteAddr, sl.connFilter))
conn.Close()
continue
}
sl.Logger.Info("SignerListener: Connected")

// We have a good connection, wait for someone that needs one otherwise cancellation
Expand All @@ -207,6 +228,13 @@ func (sl *SignerListenerEndpoint) serviceLoop() {
}
}

func (sl *SignerListenerEndpoint) filter(addr net.Addr) net.Addr {
if sl.connFilter == nil {
return addr
}
return sl.connFilter.Filter(addr)
}

func (sl *SignerListenerEndpoint) pingLoop() {
for {
select {
Expand Down
13 changes: 13 additions & 0 deletions privval/signer_listener_endpoint_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package privval

import (
"github.com/Finschia/ostracon/privval/internal"
"net"
"testing"
"time"
Expand Down Expand Up @@ -213,3 +214,15 @@ func getMockEndpoints(

return listenerEndpoint, dialerEndpoint
}

func TestSignerListenerEndpointAllowAddressSetIpFilterForTCP(t *testing.T) {
cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress("tcp", "127.0.0.1"))
_, ok := cut.connFilter.(*internal.IpFilter)
assert.True(t, ok)
}

func TestSignerListenerEndpointAllowAddressSetNullObjectFilterForUDS(t *testing.T) {
cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress("unix", "/mnt/uds/sock01"))
_, ok := cut.connFilter.(*internal.NullObject)
assert.True(t, ok)
}
Loading

0 comments on commit 0e64a96

Please sign in to comment.