From 041f37ab0298110004e888e4c6b80c9069257bd9 Mon Sep 17 00:00:00 2001 From: "jaeseung.bae" Date: Tue, 8 Aug 2023 11:43:41 +0900 Subject: [PATCH 1/5] feat: add kms remote IP filter --- cmd/ostracon/commands/show_validator.go | 2 +- config/config.go | 6 +++ config/toml.go | 6 +++ node/node.go | 10 ++--- privval/signer_listener_endpoint.go | 22 ++++++++++ privval/signer_listener_endpoint_test.go | 55 ++++++++++++++++++++++++ privval/utils.go | 4 +- 7 files changed, 95 insertions(+), 10 deletions(-) diff --git a/cmd/ostracon/commands/show_validator.go b/cmd/ostracon/commands/show_validator.go index e19914b60..43b31d041 100644 --- a/cmd/ostracon/commands/show_validator.go +++ b/cmd/ostracon/commands/show_validator.go @@ -31,7 +31,7 @@ func showValidator(cmd *cobra.Command, args []string, config *cfg.Config) error if err != nil { return err } - pv, err = node.CreateAndStartPrivValidatorSocketClient(config.PrivValidatorListenAddr, chainID, logger) + pv, err = node.CreateAndStartPrivValidatorSocketClient(config, chainID, logger) if err != nil { return err } diff --git a/config/config.go b/config/config.go index a39a65ad9..a6be7ec27 100644 --- a/config/config.go +++ b/config/config.go @@ -242,8 +242,14 @@ type BaseConfig struct { //nolint: maligned // TCP or UNIX socket address for Ostracon to listen on for // connections from an external PrivValidator process + // example) 0.0.0.0:26659 PrivValidatorListenAddr string `mapstructure:"priv_validator_laddr"` + // Validator's remote address(without port) to allow a connection + // ostracon only allow a connection from this address + // example) 10.0.0.7 + PrivValidatorRemoteAddr string `mapstructure:"priv_validator_raddr"` + // A JSON file containing the private key to use for p2p authenticated encryption NodeKey string `mapstructure:"node_key_file"` diff --git a/config/toml.go b/config/toml.go index 2f52e4c2e..2c2df7021 100644 --- a/config/toml.go +++ b/config/toml.go @@ -156,8 +156,14 @@ priv_validator_state_file = "{{ js .BaseConfig.PrivValidatorState }}" # TCP or UNIX socket address for Ostracon to listen on for # connections from an external PrivValidator process +# example) 0.0.0.0:26659 priv_validator_laddr = "{{ .BaseConfig.PrivValidatorListenAddr }}" +# Validator's remote address to allow a connection +# ostracon only allow a connection from this address +# example) 10.0.0.7 +priv_validator_raddr = "{{ .BaseConfig.PrivValidatorRemoteAddr }}" + # Path to the JSON file containing the private key to use for node authentication in the p2p protocol node_key_file = "{{ js .BaseConfig.NodeKey }}" diff --git a/node/node.go b/node/node.go index 5db7dc43f..7bad342bd 100644 --- a/node/node.go +++ b/node/node.go @@ -794,7 +794,7 @@ func NewNode(config *cfg.Config, // external signing process. if config.PrivValidatorListenAddr != "" { // FIXME: we should start services inside OnStart - privValidator, err = CreateAndStartPrivValidatorSocketClient(config.PrivValidatorListenAddr, genDoc.ChainID, logger) + privValidator, err = CreateAndStartPrivValidatorSocketClient(config, genDoc.ChainID, logger) if err != nil { return nil, fmt.Errorf("error with private validator socket client: %w", err) } @@ -1523,12 +1523,8 @@ func saveGenesisDoc(db dbm.DB, genDoc *types.GenesisDoc) error { return nil } -func CreateAndStartPrivValidatorSocketClient( - listenAddr, - chainID string, - logger log.Logger, -) (types.PrivValidator, error) { - pve, err := privval.NewSignerListener(listenAddr, logger) +func CreateAndStartPrivValidatorSocketClient(config *cfg.Config, chainID string, logger log.Logger) (types.PrivValidator, error) { + pve, err := privval.NewSignerListener(logger, config.PrivValidatorListenAddr, config.PrivValidatorRemoteAddr) if err != nil { return nil, fmt.Errorf("failed to start private validator: %w", err) } diff --git a/privval/signer_listener_endpoint.go b/privval/signer_listener_endpoint.go index 030b4de4d..c9f06fa3d 100644 --- a/privval/signer_listener_endpoint.go +++ b/privval/signer_listener_endpoint.go @@ -3,6 +3,7 @@ package privval import ( "fmt" "net" + "strings" "time" privvalproto "github.com/tendermint/tendermint/proto/tendermint/privval" @@ -24,6 +25,13 @@ func SignerListenerEndpointTimeoutReadWrite(timeout time.Duration) SignerListene return func(sl *SignerListenerEndpoint) { sl.signerEndpoint.timeoutReadWrite = timeout } } +// SignerListenerEndpointAllowAddress sets the address to allow +// connections from the only allowed address +// +func SignerListenerEndpointAllowAddress(addr string) SignerListenerEndpointOption { + return func(sl *SignerListenerEndpoint) { sl.allowAddr = addr } +} + // SignerListenerEndpoint listens for an external process to dial in and keeps // the connection alive by dropping and reconnecting. // @@ -41,6 +49,8 @@ type SignerListenerEndpoint struct { pingInterval time.Duration instanceMtx tmsync.Mutex // Ensures instance public methods access, i.e. SendRequest + + allowAddr string } // NewSignerListenerEndpoint returns an instance of SignerListenerEndpoint. @@ -185,6 +195,13 @@ func (sl *SignerListenerEndpoint) serviceLoop() { case <-sl.connectRequestCh: { conn, err := sl.acceptNewConnection() + remoteAddr := conn.RemoteAddr() + if !sl.isAllowedAddr(remoteAddr) { + sl.Logger.Info(fmt.Sprintf("deny a connection request from remote address=%s", remoteAddr)) + conn.Close() + continue + } + if err == nil { sl.Logger.Info("SignerListener: Connected") @@ -207,6 +224,11 @@ func (sl *SignerListenerEndpoint) serviceLoop() { } } +func (sl *SignerListenerEndpoint) isAllowedAddr(addr net.Addr) bool { + addrOnly := addr.String()[:strings.Index(addr.String(), ":")] + return sl.allowAddr == addrOnly +} + func (sl *SignerListenerEndpoint) pingLoop() { for { select { diff --git a/privval/signer_listener_endpoint_test.go b/privval/signer_listener_endpoint_test.go index 27a707b74..dcadaa320 100644 --- a/privval/signer_listener_endpoint_test.go +++ b/privval/signer_listener_endpoint_test.go @@ -145,6 +145,61 @@ func TestRetryConnToRemoteSigner(t *testing.T) { } } +type addrStub struct { + address string +} + +func (a addrStub) Network() string { + return "" +} + +func (a addrStub) String() string { + return a.address +} + +func TestFilterRemoteConnectionByIP(t *testing.T) { + type fields struct { + allowIP string + remoteAddr net.Addr + expected bool + } + tests := []struct { + name string + fields fields + }{ + { + "should allow correct ip", + struct { + allowIP string + remoteAddr net.Addr + expected bool + }{"127.0.0.1", addrStub{"127.0.0.1:45678"}, true}, + }, + { + "should not allow different ip", + struct { + allowIP string + remoteAddr net.Addr + expected bool + }{"127.0.0.1", addrStub{"10.0.0.2:45678"}, false}, + }, + { + "empty allowIP should deny all", + struct { + allowIP string + remoteAddr net.Addr + expected bool + }{"", addrStub{"127.0.0.1:45678"}, false}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sl := &SignerListenerEndpoint{allowAddr: tt.fields.allowIP} + assert.Equalf(t, tt.fields.expected, sl.isAllowedAddr(tt.fields.remoteAddr), tt.name) + }) + } +} + func newSignerListenerEndpoint(logger log.Logger, addr string, timeoutReadWrite time.Duration) *SignerListenerEndpoint { proto, address := tmnet.ProtocolAndAddress(addr) diff --git a/privval/utils.go b/privval/utils.go index fe52ec7e2..cca41c050 100644 --- a/privval/utils.go +++ b/privval/utils.go @@ -26,7 +26,7 @@ func IsConnTimeout(err error) bool { } // NewSignerListener creates a new SignerListenerEndpoint using the corresponding listen address -func NewSignerListener(listenAddr string, logger log.Logger) (*SignerListenerEndpoint, error) { +func NewSignerListener(logger log.Logger, listenAddr, remoteAddr string) (*SignerListenerEndpoint, error) { var listener net.Listener protocol, address := tmnet.ProtocolAndAddress(listenAddr) @@ -47,7 +47,7 @@ func NewSignerListener(listenAddr string, logger log.Logger) (*SignerListenerEnd ) } - pve := NewSignerListenerEndpoint(logger.With("module", "privval"), listener) + pve := NewSignerListenerEndpoint(logger.With("module", "privval"), listener, SignerListenerEndpointAllowAddress(remoteAddr)) return pve, nil } From 59838a016933fe7c68570439a96f475044f28eb6 Mon Sep 17 00:00:00 2001 From: "jaeseung.bae" Date: Tue, 8 Aug 2023 18:04:59 +0900 Subject: [PATCH 2/5] fix: side-effect for testing by allowing any connection if remote address is empty --- cmd/ostracon/commands/show_validator_test.go | 2 ++ privval/signer_listener_endpoint.go | 25 ++++++++++++-------- privval/signer_listener_endpoint_test.go | 4 ++-- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/cmd/ostracon/commands/show_validator_test.go b/cmd/ostracon/commands/show_validator_test.go index 9b4c43bf7..5f708d428 100644 --- a/cmd/ostracon/commands/show_validator_test.go +++ b/cmd/ostracon/commands/show_validator_test.go @@ -3,6 +3,7 @@ package commands import ( "bytes" "os" + "strings" "sync" "testing" @@ -79,6 +80,7 @@ func TestShowValidatorWithKMS(t *testing.T) { } privval.WithMockKMS(t, dir, chainID, func(addr string, privKey crypto.PrivKey) { config.PrivValidatorListenAddr = addr + config.PrivValidatorRemoteAddr = addr[:strings.Index(addr, ":")] require.NoFileExists(t, config.PrivValidatorKeyFile()) output, err := captureStdout(func() { err := showValidator(ShowValidatorCmd, nil, config) diff --git a/privval/signer_listener_endpoint.go b/privval/signer_listener_endpoint.go index c9f06fa3d..981cdaedb 100644 --- a/privval/signer_listener_endpoint.go +++ b/privval/signer_listener_endpoint.go @@ -50,7 +50,7 @@ type SignerListenerEndpoint struct { instanceMtx tmsync.Mutex // Ensures instance public methods access, i.e. SendRequest - allowAddr string + allowAddr string // empty value allows all } // NewSignerListenerEndpoint returns an instance of SignerListenerEndpoint. @@ -195,14 +195,13 @@ func (sl *SignerListenerEndpoint) serviceLoop() { case <-sl.connectRequestCh: { conn, err := sl.acceptNewConnection() - remoteAddr := conn.RemoteAddr() - if !sl.isAllowedAddr(remoteAddr) { - sl.Logger.Info(fmt.Sprintf("deny a connection request from remote address=%s", remoteAddr)) - conn.Close() - continue - } - if err == nil { + remoteAddr := conn.RemoteAddr() + if !sl.isAllowedAddr(remoteAddr) { + sl.Logger.Info(fmt.Sprintf("deny a connection request from remote address=%s", remoteAddr)) + conn.Close() + continue + } sl.Logger.Info("SignerListener: Connected") // We have a good connection, wait for someone that needs one otherwise cancellation @@ -225,8 +224,14 @@ func (sl *SignerListenerEndpoint) serviceLoop() { } func (sl *SignerListenerEndpoint) isAllowedAddr(addr net.Addr) bool { - addrOnly := addr.String()[:strings.Index(addr.String(), ":")] - return sl.allowAddr == addrOnly + if len(sl.allowAddr) == 0 { + return true + } + if strings.Contains(addr.String(), ":") { + addrOnly := addr.String()[:strings.Index(addr.String(), ":")] + return sl.allowAddr == addrOnly + } + return sl.allowAddr == addr.String() } func (sl *SignerListenerEndpoint) pingLoop() { diff --git a/privval/signer_listener_endpoint_test.go b/privval/signer_listener_endpoint_test.go index dcadaa320..4bd620898 100644 --- a/privval/signer_listener_endpoint_test.go +++ b/privval/signer_listener_endpoint_test.go @@ -184,12 +184,12 @@ func TestFilterRemoteConnectionByIP(t *testing.T) { }{"127.0.0.1", addrStub{"10.0.0.2:45678"}, false}, }, { - "empty allowIP should deny all", + "empty allowIP should allow all", struct { allowIP string remoteAddr net.Addr expected bool - }{"", addrStub{"127.0.0.1:45678"}, false}, + }{"", addrStub{"127.0.0.1:45678"}, true}, }, } for _, tt := range tests { From 548562c18b0d8967831112fc941e454d4c5765d7 Mon Sep 17 00:00:00 2001 From: "jaeseung.bae" Date: Tue, 8 Aug 2023 18:28:33 +0900 Subject: [PATCH 3/5] chore: add some tests --- privval/signer_listener_endpoint_test.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/privval/signer_listener_endpoint_test.go b/privval/signer_listener_endpoint_test.go index 4bd620898..b118c28f2 100644 --- a/privval/signer_listener_endpoint_test.go +++ b/privval/signer_listener_endpoint_test.go @@ -174,6 +174,13 @@ func TestFilterRemoteConnectionByIP(t *testing.T) { remoteAddr net.Addr expected bool }{"127.0.0.1", addrStub{"127.0.0.1:45678"}, true}, + }, { + "should allow correct ip without port", + struct { + allowIP string + remoteAddr net.Addr + expected bool + }{"127.0.0.1", addrStub{"127.0.0.1"}, true}, }, { "should not allow different ip", @@ -200,6 +207,14 @@ func TestFilterRemoteConnectionByIP(t *testing.T) { } } +func TestSignerListenerEndpointAllowAddress(t *testing.T) { + expected := "192.168.0.1" + + cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress(expected)) + + assert.Equal(t, expected, cut.allowAddr) +} + func newSignerListenerEndpoint(logger log.Logger, addr string, timeoutReadWrite time.Duration) *SignerListenerEndpoint { proto, address := tmnet.ProtocolAndAddress(addr) From 8bdabf47c9ea5b0d8bc1493367479208193a6c30 Mon Sep 17 00:00:00 2001 From: "jaeseung.bae" Date: Fri, 11 Aug 2023 15:14:01 +0900 Subject: [PATCH 4/5] feat: apply denyAll if empty allow address, and null object pattern for "unix" connection type --- node/node_test.go | 6 ++ privval/internal/conn_filter.go | 8 ++ privval/internal/ip_filter.go | 44 +++++++++++ privval/internal/ip_filter_test.go | 86 +++++++++++++++++++++ privval/internal/null_object_filter.go | 19 +++++ privval/internal/null_object_filter_test.go | 35 +++++++++ privval/signer_listener_endpoint.go | 31 ++++---- privval/signer_listener_endpoint_test.go | 70 ----------------- privval/utils.go | 2 +- 9 files changed, 215 insertions(+), 86 deletions(-) create mode 100644 privval/internal/conn_filter.go create mode 100644 privval/internal/ip_filter.go create mode 100644 privval/internal/ip_filter_test.go create mode 100644 privval/internal/null_object_filter.go create mode 100644 privval/internal/null_object_filter_test.go diff --git a/node/node_test.go b/node/node_test.go index aba8fa313..f7d0e4f16 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -160,11 +160,17 @@ func TestNodeSetAppVersion(t *testing.T) { } func TestNodeSetPrivValTCP(t *testing.T) { + address := testFreeAddr(t) addr := "tcp://" + testFreeAddr(t) config := cfg.ResetTestRoot("node_priv_val_tcp_test") defer os.RemoveAll(config.RootDir) config.BaseConfig.PrivValidatorListenAddr = addr + addrPart, _, err := net.SplitHostPort(address) + if err != nil { + return + } + config.BaseConfig.PrivValidatorRemoteAddr = addrPart dialer := privval.DialTCPFn(addr, 100*time.Millisecond, ed25519.GenPrivKey()) dialerEndpoint := privval.NewSignerDialerEndpoint( diff --git a/privval/internal/conn_filter.go b/privval/internal/conn_filter.go new file mode 100644 index 000000000..405cb8b2b --- /dev/null +++ b/privval/internal/conn_filter.go @@ -0,0 +1,8 @@ +package internal + +import "net" + +type ConnectionFilter interface { + Filter(addr net.Addr) net.Addr + String() string +} diff --git a/privval/internal/ip_filter.go b/privval/internal/ip_filter.go new file mode 100644 index 000000000..e9f7cceb8 --- /dev/null +++ b/privval/internal/ip_filter.go @@ -0,0 +1,44 @@ +package internal + +import ( + "fmt" + "github.com/Finschia/ostracon/libs/log" + "net" +) + +type IpFilter struct { + allowAddr string + log log.Logger +} + +func NewIpFilter(addr string, l log.Logger) *IpFilter { + return &IpFilter{ + allowAddr: addr, + log: l, + } +} + +func (f *IpFilter) Filter(addr net.Addr) net.Addr { + if f.isAllowedAddr(addr) { + return addr + } + return nil +} + +func (f *IpFilter) String() string { + return f.allowAddr +} + +func (f *IpFilter) isAllowedAddr(addr net.Addr) bool { + if len(f.allowAddr) == 0 { + return false + } + hostAddr, _, err := net.SplitHostPort(addr.String()) + if err != nil { + if f.log != nil { + f.log.Error(fmt.Sprintf("IpFilter: can't split host and port from addr.String()=%s", addr.String())) + } + return false + } + return f.allowAddr == hostAddr +} diff --git a/privval/internal/ip_filter_test.go b/privval/internal/ip_filter_test.go new file mode 100644 index 000000000..b8b8dad8a --- /dev/null +++ b/privval/internal/ip_filter_test.go @@ -0,0 +1,86 @@ +package internal + +import ( + "github.com/stretchr/testify/assert" + "net" + "testing" +) + +type addrStub struct { + address string +} + +func (a addrStub) Network() string { + return "" +} + +func (a addrStub) String() string { + return a.address +} + +func TestFilterRemoteConnectionByIP(t *testing.T) { + type fields struct { + allowIP string + remoteAddr net.Addr + expected bool + } + tests := []struct { + name string + fields fields + }{ + { + "should allow correct ip", + struct { + allowIP string + remoteAddr net.Addr + expected bool + }{"127.0.0.1", addrStub{"127.0.0.1:45678"}, true}, + }, + { + "should not allow different ip", + struct { + allowIP string + remoteAddr net.Addr + expected bool + }{"127.0.0.1", addrStub{"10.0.0.2:45678"}, false}, + }, + { + "should works for IPv6 with correct ip", + struct { + allowIP string + remoteAddr net.Addr + expected bool + }{"2001:db8::1", addrStub{"[2001:db8::1]:80"}, true}, + }, + { + "should works for IPv6 with incorrect ip", + struct { + allowIP string + remoteAddr net.Addr + expected bool + }{"2001:db8::2", addrStub{"[2001:db8::1]:80"}, false}, + }, + { + "empty allowIP should deny all", + struct { + allowIP string + remoteAddr net.Addr + expected bool + }{"", addrStub{"127.0.0.1:45678"}, false}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cut := NewIpFilter(tt.fields.allowIP, nil) + assert.Equalf(t, tt.fields.expected, cut.isAllowedAddr(tt.fields.remoteAddr), tt.name) + }) + } +} + +func TestIpFilterShouldSetAllowAddress(t *testing.T) { + expected := "192.168.0.1" + + cut := NewIpFilter(expected, nil) + + assert.Equal(t, expected, cut.allowAddr) +} diff --git a/privval/internal/null_object_filter.go b/privval/internal/null_object_filter.go new file mode 100644 index 000000000..f3ee3367e --- /dev/null +++ b/privval/internal/null_object_filter.go @@ -0,0 +1,19 @@ +package internal + +import "net" + +// NullObject is null object pattern. It does nothing +type NullObject struct { +} + +func NewNullObject() NullObject { + return NullObject{} +} + +func (n NullObject) Filter(addr net.Addr) net.Addr { + return addr +} + +func (n NullObject) String() string { + return "NullObject" +} diff --git a/privval/internal/null_object_filter_test.go b/privval/internal/null_object_filter_test.go new file mode 100644 index 000000000..44dcbf691 --- /dev/null +++ b/privval/internal/null_object_filter_test.go @@ -0,0 +1,35 @@ +package internal + +import ( + "net" + "reflect" + "testing" +) + +func TestNullObject_filter(t *testing.T) { + stubInput := addrStub{} + tests := []struct { + name string + addr net.Addr + want net.Addr + }{ + { + name: "null object does nothing, returns what it receives", + addr: stubInput, + want: stubInput, + }, + { + name: "null object does nothing, returns nil it receives nil", + addr: nil, + want: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := &NullObject{} + if got := n.Filter(tt.addr); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Filter() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/privval/signer_listener_endpoint.go b/privval/signer_listener_endpoint.go index 981cdaedb..c2c96642d 100644 --- a/privval/signer_listener_endpoint.go +++ b/privval/signer_listener_endpoint.go @@ -2,8 +2,8 @@ package privval import ( "fmt" + "github.com/Finschia/ostracon/privval/internal" "net" - "strings" "time" privvalproto "github.com/tendermint/tendermint/proto/tendermint/privval" @@ -28,8 +28,14 @@ func SignerListenerEndpointTimeoutReadWrite(timeout time.Duration) SignerListene // SignerListenerEndpointAllowAddress sets the address to allow // connections from the only allowed address // -func SignerListenerEndpointAllowAddress(addr string) SignerListenerEndpointOption { - return func(sl *SignerListenerEndpoint) { sl.allowAddr = addr } +func SignerListenerEndpointAllowAddress(protocol string, addr string) SignerListenerEndpointOption { + return func(sl *SignerListenerEndpoint) { + if protocol == "tcp" || len(protocol) == 0 { + sl.connFilter = internal.NewIpFilter(addr, sl.Logger) + return + } + sl.connFilter = internal.NewNullObject() + } } // SignerListenerEndpoint listens for an external process to dial in and keeps @@ -49,8 +55,7 @@ type SignerListenerEndpoint struct { pingInterval time.Duration instanceMtx tmsync.Mutex // Ensures instance public methods access, i.e. SendRequest - - allowAddr string // empty value allows all + connFilter internal.ConnectionFilter } // NewSignerListenerEndpoint returns an instance of SignerListenerEndpoint. @@ -197,8 +202,8 @@ func (sl *SignerListenerEndpoint) serviceLoop() { conn, err := sl.acceptNewConnection() if err == nil { remoteAddr := conn.RemoteAddr() - if !sl.isAllowedAddr(remoteAddr) { - sl.Logger.Info(fmt.Sprintf("deny a connection request from remote address=%s", remoteAddr)) + if sl.filter(remoteAddr) == nil { + sl.Logger.Info(fmt.Sprintf("SignerListener: deny a connection request from remote address=%s, expected=%s", remoteAddr, sl.connFilter)) conn.Close() continue } @@ -223,15 +228,11 @@ func (sl *SignerListenerEndpoint) serviceLoop() { } } -func (sl *SignerListenerEndpoint) isAllowedAddr(addr net.Addr) bool { - if len(sl.allowAddr) == 0 { - return true - } - if strings.Contains(addr.String(), ":") { - addrOnly := addr.String()[:strings.Index(addr.String(), ":")] - return sl.allowAddr == addrOnly +func (sl *SignerListenerEndpoint) filter(addr net.Addr) net.Addr { + if sl.connFilter == nil { + return addr } - return sl.allowAddr == addr.String() + return sl.connFilter.Filter(addr) } func (sl *SignerListenerEndpoint) pingLoop() { diff --git a/privval/signer_listener_endpoint_test.go b/privval/signer_listener_endpoint_test.go index b118c28f2..27a707b74 100644 --- a/privval/signer_listener_endpoint_test.go +++ b/privval/signer_listener_endpoint_test.go @@ -145,76 +145,6 @@ func TestRetryConnToRemoteSigner(t *testing.T) { } } -type addrStub struct { - address string -} - -func (a addrStub) Network() string { - return "" -} - -func (a addrStub) String() string { - return a.address -} - -func TestFilterRemoteConnectionByIP(t *testing.T) { - type fields struct { - allowIP string - remoteAddr net.Addr - expected bool - } - tests := []struct { - name string - fields fields - }{ - { - "should allow correct ip", - struct { - allowIP string - remoteAddr net.Addr - expected bool - }{"127.0.0.1", addrStub{"127.0.0.1:45678"}, true}, - }, { - "should allow correct ip without port", - struct { - allowIP string - remoteAddr net.Addr - expected bool - }{"127.0.0.1", addrStub{"127.0.0.1"}, true}, - }, - { - "should not allow different ip", - struct { - allowIP string - remoteAddr net.Addr - expected bool - }{"127.0.0.1", addrStub{"10.0.0.2:45678"}, false}, - }, - { - "empty allowIP should allow all", - struct { - allowIP string - remoteAddr net.Addr - expected bool - }{"", addrStub{"127.0.0.1:45678"}, true}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - sl := &SignerListenerEndpoint{allowAddr: tt.fields.allowIP} - assert.Equalf(t, tt.fields.expected, sl.isAllowedAddr(tt.fields.remoteAddr), tt.name) - }) - } -} - -func TestSignerListenerEndpointAllowAddress(t *testing.T) { - expected := "192.168.0.1" - - cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress(expected)) - - assert.Equal(t, expected, cut.allowAddr) -} - func newSignerListenerEndpoint(logger log.Logger, addr string, timeoutReadWrite time.Duration) *SignerListenerEndpoint { proto, address := tmnet.ProtocolAndAddress(addr) diff --git a/privval/utils.go b/privval/utils.go index cca41c050..34607235c 100644 --- a/privval/utils.go +++ b/privval/utils.go @@ -47,7 +47,7 @@ func NewSignerListener(logger log.Logger, listenAddr, remoteAddr string) (*Signe ) } - pve := NewSignerListenerEndpoint(logger.With("module", "privval"), listener, SignerListenerEndpointAllowAddress(remoteAddr)) + pve := NewSignerListenerEndpoint(logger.With("module", "privval"), listener, SignerListenerEndpointAllowAddress(protocol, remoteAddr)) return pve, nil } From 44a8045a1f380764e02fdf528081572ca98d7e32 Mon Sep 17 00:00:00 2001 From: "jaeseung.bae" Date: Fri, 11 Aug 2023 18:37:57 +0900 Subject: [PATCH 5/5] chore: increase test coverage --- config/config.go | 2 +- config/toml.go | 6 ++--- privval/internal/ip_filter_test.go | 29 ++++++++++++--------- privval/internal/null_object_filter.go | 4 +-- privval/internal/null_object_filter_test.go | 7 ++++- privval/signer_listener_endpoint_test.go | 13 +++++++++ 6 files changed, 42 insertions(+), 19 deletions(-) diff --git a/config/config.go b/config/config.go index a6be7ec27..052f01a85 100644 --- a/config/config.go +++ b/config/config.go @@ -242,7 +242,7 @@ type BaseConfig struct { //nolint: maligned // TCP or UNIX socket address for Ostracon to listen on for // connections from an external PrivValidator process - // example) 0.0.0.0:26659 + // example) tcp://0.0.0.0:26659 PrivValidatorListenAddr string `mapstructure:"priv_validator_laddr"` // Validator's remote address(without port) to allow a connection diff --git a/config/toml.go b/config/toml.go index 2c2df7021..34dc59204 100644 --- a/config/toml.go +++ b/config/toml.go @@ -156,13 +156,13 @@ priv_validator_state_file = "{{ js .BaseConfig.PrivValidatorState }}" # TCP or UNIX socket address for Ostracon to listen on for # connections from an external PrivValidator process -# example) 0.0.0.0:26659 +# example) tcp://0.0.0.0:26659 priv_validator_laddr = "{{ .BaseConfig.PrivValidatorListenAddr }}" # Validator's remote address to allow a connection # ostracon only allow a connection from this address -# example) 10.0.0.7 -priv_validator_raddr = "{{ .BaseConfig.PrivValidatorRemoteAddr }}" +# example) 127.0.0.1 +priv_validator_raddr = "127.0.0.1" # Path to the JSON file containing the private key to use for node authentication in the p2p protocol node_key_file = "{{ js .BaseConfig.NodeKey }}" diff --git a/privval/internal/ip_filter_test.go b/privval/internal/ip_filter_test.go index b8b8dad8a..6257fa7ca 100644 --- a/privval/internal/ip_filter_test.go +++ b/privval/internal/ip_filter_test.go @@ -22,7 +22,7 @@ func TestFilterRemoteConnectionByIP(t *testing.T) { type fields struct { allowIP string remoteAddr net.Addr - expected bool + expected net.Addr } tests := []struct { name string @@ -33,46 +33,46 @@ func TestFilterRemoteConnectionByIP(t *testing.T) { struct { allowIP string remoteAddr net.Addr - expected bool - }{"127.0.0.1", addrStub{"127.0.0.1:45678"}, true}, + expected net.Addr + }{"127.0.0.1", addrStub{"127.0.0.1:45678"}, addrStub{"127.0.0.1:45678"}}, }, { "should not allow different ip", struct { allowIP string remoteAddr net.Addr - expected bool - }{"127.0.0.1", addrStub{"10.0.0.2:45678"}, false}, + expected net.Addr + }{"127.0.0.1", addrStub{"10.0.0.2:45678"}, nil}, }, { "should works for IPv6 with correct ip", struct { allowIP string remoteAddr net.Addr - expected bool - }{"2001:db8::1", addrStub{"[2001:db8::1]:80"}, true}, + expected net.Addr + }{"2001:db8::1", addrStub{"[2001:db8::1]:80"}, addrStub{"[2001:db8::1]:80"}}, }, { "should works for IPv6 with incorrect ip", struct { allowIP string remoteAddr net.Addr - expected bool - }{"2001:db8::2", addrStub{"[2001:db8::1]:80"}, false}, + expected net.Addr + }{"2001:db8::2", addrStub{"[2001:db8::1]:80"}, nil}, }, { "empty allowIP should deny all", struct { allowIP string remoteAddr net.Addr - expected bool - }{"", addrStub{"127.0.0.1:45678"}, false}, + expected net.Addr + }{"", addrStub{"127.0.0.1:45678"}, nil}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cut := NewIpFilter(tt.fields.allowIP, nil) - assert.Equalf(t, tt.fields.expected, cut.isAllowedAddr(tt.fields.remoteAddr), tt.name) + assert.Equalf(t, tt.fields.expected, cut.Filter(tt.fields.remoteAddr), tt.name) }) } } @@ -84,3 +84,8 @@ func TestIpFilterShouldSetAllowAddress(t *testing.T) { assert.Equal(t, expected, cut.allowAddr) } + +func TestIpFilterStringShouldReturnsIP(t *testing.T) { + expected := "127.0.0.1" + assert.Equal(t, expected, NewIpFilter(expected, nil).String()) +} diff --git a/privval/internal/null_object_filter.go b/privval/internal/null_object_filter.go index f3ee3367e..9914df73d 100644 --- a/privval/internal/null_object_filter.go +++ b/privval/internal/null_object_filter.go @@ -6,8 +6,8 @@ import "net" type NullObject struct { } -func NewNullObject() NullObject { - return NullObject{} +func NewNullObject() *NullObject { + return &NullObject{} } func (n NullObject) Filter(addr net.Addr) net.Addr { diff --git a/privval/internal/null_object_filter_test.go b/privval/internal/null_object_filter_test.go index 44dcbf691..6df8177f8 100644 --- a/privval/internal/null_object_filter_test.go +++ b/privval/internal/null_object_filter_test.go @@ -1,6 +1,7 @@ package internal import ( + "github.com/stretchr/testify/assert" "net" "reflect" "testing" @@ -26,10 +27,14 @@ func TestNullObject_filter(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - n := &NullObject{} + n := NewNullObject() if got := n.Filter(tt.addr); !reflect.DeepEqual(got, tt.want) { t.Errorf("Filter() = %v, want %v", got, tt.want) } }) } } + +func TestNullObjectString(t *testing.T) { + assert.Equal(t, "NullObject", NewNullObject().String()) +} diff --git a/privval/signer_listener_endpoint_test.go b/privval/signer_listener_endpoint_test.go index 27a707b74..317bab82f 100644 --- a/privval/signer_listener_endpoint_test.go +++ b/privval/signer_listener_endpoint_test.go @@ -1,6 +1,7 @@ package privval import ( + "github.com/Finschia/ostracon/privval/internal" "net" "testing" "time" @@ -213,3 +214,15 @@ func getMockEndpoints( return listenerEndpoint, dialerEndpoint } + +func TestSignerListenerEndpointAllowAddressSetIpFilterForTCP(t *testing.T) { + cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress("tcp", "127.0.0.1")) + _, ok := cut.connFilter.(*internal.IpFilter) + assert.True(t, ok) +} + +func TestSignerListenerEndpointAllowAddressSetNullObjectFilterForUDS(t *testing.T) { + cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress("unix", "/mnt/uds/sock01")) + _, ok := cut.connFilter.(*internal.NullObject) + assert.True(t, ok) +}