Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Default validator support #525

Merged
merged 7 commits into from
Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 65 additions & 21 deletions validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ type validation struct {
// mx protects the validator map
mx sync.Mutex
// topicVals tracks per topic validators
topicVals map[string]*topicVal
topicVals map[string]*validatorImpl

// defaultVals tracks default validators applicable to all topics
defaultVals []*validatorImpl

// validateQ is the front-end to the validation pipeline
validateQ chan *validateReq
Expand All @@ -84,13 +87,13 @@ type validation struct {

// validation requests
type validateReq struct {
vals []*topicVal
vals []*validatorImpl
src peer.ID
msg *Message
}

// representation of topic validators
type topicVal struct {
type validatorImpl struct {
topic string
validate ValidatorEx
validateTimeout time.Duration
Expand All @@ -117,7 +120,7 @@ type rmValReq struct {
// newValidation creates a new validation pipeline
func newValidation() *validation {
return &validation{
topicVals: make(map[string]*topicVal),
topicVals: make(map[string]*validatorImpl),
validateQ: make(chan *validateReq, defaultValidateQueueSize),
validateThrottle: make(chan struct{}, defaultValidateThrottle),
validateWorkers: runtime.NumCPU(),
Expand All @@ -136,17 +139,28 @@ func (v *validation) Start(p *PubSub) {

// AddValidator adds a new validator
func (v *validation) AddValidator(req *addValReq) {
val, err := v.makeValidator(req)
if err != nil {
req.resp <- err
return
}

v.mx.Lock()
defer v.mx.Unlock()

topic := req.topic
topic := val.topic

_, ok := v.topicVals[topic]
if ok {
req.resp <- fmt.Errorf("duplicate validator for topic %s", topic)
return
}

v.topicVals[topic] = val
req.resp <- nil
}

func (v *validation) makeValidator(req *addValReq) (*validatorImpl, error) {
makeValidatorEx := func(v Validator) ValidatorEx {
return func(ctx context.Context, p peer.ID, msg *Message) ValidationResult {
if v(ctx, p, msg) {
Expand All @@ -170,12 +184,15 @@ func (v *validation) AddValidator(req *addValReq) {
validator = v

default:
req.resp <- fmt.Errorf("unknown validator type for topic %s; must be an instance of Validator or ValidatorEx", topic)
return
topic := req.topic
if req.topic == "" {
topic = "(default)"
}
return nil, fmt.Errorf("unknown validator type for topic %s; must be an instance of Validator or ValidatorEx", topic)
}

val := &topicVal{
topic: topic,
val := &validatorImpl{
topic: req.topic,
validate: validator,
validateTimeout: 0,
validateThrottle: make(chan struct{}, defaultValidateConcurrency),
Expand All @@ -190,8 +207,7 @@ func (v *validation) AddValidator(req *addValReq) {
val.validateThrottle = make(chan struct{}, req.throttle)
}

v.topicVals[topic] = val
req.resp <- nil
return val, nil
}

// RemoveValidator removes an existing validator
Expand Down Expand Up @@ -244,18 +260,21 @@ func (v *validation) Push(src peer.ID, msg *Message) bool {
}

// getValidators returns all validators that apply to a given message
func (v *validation) getValidators(msg *Message) []*topicVal {
func (v *validation) getValidators(msg *Message) []*validatorImpl {
v.mx.Lock()
defer v.mx.Unlock()

var vals []*validatorImpl
vals = append(vals, v.defaultVals...)

topic := msg.GetTopic()

val, ok := v.topicVals[topic]
if !ok {
return nil
return vals
}

return []*topicVal{val}
return append(vals, val)
}

// validateWorker is an active goroutine performing inline validation
Expand All @@ -271,7 +290,7 @@ func (v *validation) validateWorker() {
}

// validate performs validation and only sends the message if all validators succeed
func (v *validation) validate(vals []*topicVal, src peer.ID, msg *Message, synchronous bool) error {
func (v *validation) validate(vals []*validatorImpl, src peer.ID, msg *Message, synchronous bool) error {
// If signature verification is enabled, but signing is disabled,
// the Signature is required to be nil upon receiving the message in PubSub.pushMsg.
if msg.Signature != nil {
Expand All @@ -292,7 +311,7 @@ func (v *validation) validate(vals []*topicVal, src peer.ID, msg *Message, synch
v.tracer.ValidateMessage(msg)
}

var inline, async []*topicVal
var inline, async []*validatorImpl
for _, val := range vals {
if val.validateInline || synchronous {
inline = append(inline, val)
Expand Down Expand Up @@ -360,7 +379,7 @@ func (v *validation) validateSignature(msg *Message) bool {
return true
}

func (v *validation) doValidateTopic(vals []*topicVal, src peer.ID, msg *Message, r ValidationResult) {
func (v *validation) doValidateTopic(vals []*validatorImpl, src peer.ID, msg *Message, r ValidationResult) {
result := v.validateTopic(vals, src, msg)

if result == ValidationAccept && r != ValidationAccept {
Expand Down Expand Up @@ -388,7 +407,7 @@ func (v *validation) doValidateTopic(vals []*topicVal, src peer.ID, msg *Message
}
}

func (v *validation) validateTopic(vals []*topicVal, src peer.ID, msg *Message) ValidationResult {
func (v *validation) validateTopic(vals []*validatorImpl, src peer.ID, msg *Message) ValidationResult {
if len(vals) == 1 {
return v.validateSingleTopic(vals[0], src, msg)
}
Expand All @@ -404,7 +423,7 @@ func (v *validation) validateTopic(vals []*topicVal, src peer.ID, msg *Message)

select {
case val.validateThrottle <- struct{}{}:
go func(val *topicVal) {
go func(val *validatorImpl) {
rch <- val.validateMsg(ctx, src, msg)
<-val.validateThrottle
}(val)
Expand Down Expand Up @@ -438,7 +457,7 @@ loop:
}

// fast path for single topic validation that avoids the extra goroutine
func (v *validation) validateSingleTopic(val *topicVal, src peer.ID, msg *Message) ValidationResult {
func (v *validation) validateSingleTopic(val *validatorImpl, src peer.ID, msg *Message) ValidationResult {
select {
case val.validateThrottle <- struct{}{}:
res := val.validateMsg(v.p.ctx, src, msg)
Expand All @@ -451,7 +470,7 @@ func (v *validation) validateSingleTopic(val *topicVal, src peer.ID, msg *Messag
}
}

func (val *topicVal) validateMsg(ctx context.Context, src peer.ID, msg *Message) ValidationResult {
func (val *validatorImpl) validateMsg(ctx context.Context, src peer.ID, msg *Message) ValidationResult {
start := time.Now()
defer func() {
log.Debugf("validation done; took %s", time.Since(start))
Expand Down Expand Up @@ -479,6 +498,31 @@ func (val *topicVal) validateMsg(ctx context.Context, src peer.ID, msg *Message)
}

// / Options
// WithDefaultValidator adds a validator that applies to all topics by default; it can be used
// more than once and add multiple validators. Having a defult validator does not inhibit registering
// a per topic validator.
func WithDefaultValidator(val interface{}, opts ...ValidatorOpt) Option {
return func(ps *PubSub) error {
addVal := &addValReq{
validate: val,
}

for _, opt := range opts {
err := opt(addVal)
if err != nil {
return err
}
}

val, err := ps.val.makeValidator(addVal)
if err != nil {
return err
}

ps.val.defaultVals = append(ps.val.defaultVals, val)
return nil
}
}

// WithValidateQueueSize sets the buffer of validate queue. Defaults to 32.
// When queue is full, validation is throttled and new messages are dropped.
Expand Down
95 changes: 95 additions & 0 deletions validation_builtin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package pubsub

import (
"context"
"encoding/binary"
"sync"

"github.com/libp2p/go-libp2p/core/peer"
)

// PeerMetadataStore is an interface for storing and retrieving per peer metadata
type PeerMetadataStore interface {
// Get retrieves the metadata associated with a peer;
// It should return nil if there is no metadata associated with the peer and not an error.
Get(context.Context, peer.ID) ([]byte, error)
// Put sets the metadata associated with a peer.
Put(context.Context, peer.ID, []byte) error
}

// BasicSeqnoValidator is a basic validator, usable as a default validator, that ignores replayed
// messages outside the seen cache window. The validator uses the message seqno as a peer-specific
// nonce to decide whether the message should be propagated, comparing to the maximal nonce store
// in the peer metadata store. This is useful to ensure that there can be no infinitely propagating
// messages in the network regardless of the seen cache span and network diameter.
// It requires that pubsub is instantiated with a strict message signing policy and that seqnos
// are not disabled, ie it doesn't support anonymous mode.
vyzo marked this conversation as resolved.
Show resolved Hide resolved
type BasicSeqnoValidator struct {
mx sync.RWMutex
meta PeerMetadataStore
}

// NewBasicSeqnoValidator constructs a BasicSeqnoValidator using the givven PeerMetadataStore.
func NewBasicSeqnoValidator(meta PeerMetadataStore) *BasicSeqnoValidator {
return &BasicSeqnoValidator{
meta: meta,
}
}

func (v *BasicSeqnoValidator) validate(ctx context.Context, _ peer.ID, m *Message) ValidationResult {
p := m.GetFrom()

v.mx.RLock()
nonceBytes, err := v.meta.Get(ctx, p)
v.mx.RUnlock()

if err != nil {
log.Warn("error retrieving peer nonce: %s", err)
return ValidationIgnore
}

var nonce uint64
if len(nonceBytes) > 0 {
nonce = binary.BigEndian.Uint64(nonceBytes)
}

var seqno uint64
seqnoBytes := m.GetSeqno()
if len(seqnoBytes) > 0 {
seqno = binary.BigEndian.Uint64(seqnoBytes)
}

// compare against the largest seen nonce
if seqno <= nonce {
return ValidationIgnore
}

// get the nonce and compare again with an exclusive lock before commiting (cf concurrent validation)
v.mx.Lock()
defer v.mx.Unlock()

nonceBytes, err = v.meta.Get(ctx, p)
if err != nil {
log.Warn("error retrieving peer nonce: %s", err)
return ValidationIgnore
}

if len(nonceBytes) > 0 {
nonce = binary.BigEndian.Uint64(nonceBytes)
}

if seqno <= nonce {
return ValidationIgnore
}

// update the nonce
nonceBytes = make([]byte, 8)
binary.BigEndian.PutUint64(nonceBytes, seqno)

err = v.meta.Put(ctx, p, nonceBytes)
if err != nil {
log.Warn("error storing peer nonce: %s", err)
}

return ValidationAccept
}