diff --git a/floodsub_test.go b/floodsub_test.go
index 62552ceb..4534adaf 100644
--- a/floodsub_test.go
+++ b/floodsub_test.go
@@ -115,6 +115,18 @@ func getPubsubs(ctx context.Context, hs []host.Host, opts ...Option) []*PubSub {
 	return psubs
 }
 
+func getPubsubsWithOptionC(ctx context.Context, hs []host.Host, cons ...func(int) Option) []*PubSub {
+	var psubs []*PubSub
+	for _, h := range hs {
+		var opts []Option
+		for i, c := range cons {
+			opts = append(opts, c(i))
+		}
+		psubs = append(psubs, getPubsub(ctx, h, opts...))
+	}
+	return psubs
+}
+
 func assertReceive(t *testing.T, ch *Subscription, exp []byte) {
 	select {
 	case msg := <-ch.ch:
@@ -175,7 +187,6 @@ func TestBasicFloodsub(t *testing.T) {
 			}
 		}
 	}
-
 }
 
 func TestMultihops(t *testing.T) {
diff --git a/validation.go b/validation.go
index 0c11fc24..1044d5d8 100644
--- a/validation.go
+++ b/validation.go
@@ -70,7 +70,10 @@ type validation struct {
 	// mx protects the validator map
 	mx sync.Mutex
 	// topicVals tracks per topic validators
-	topicVals map[string]*topicVal
+	topicVals map[string]*validatorImpl
+
+	// defaultVals tracks default validators applicable to all topics
+	defaultVals []*validatorImpl
 
 	// validateQ is the front-end to the validation pipeline
 	validateQ chan *validateReq
@@ -84,13 +87,13 @@ type validation struct {
 
 // validation requests
 type validateReq struct {
-	vals []*topicVal
+	vals []*validatorImpl
 	src  peer.ID
 	msg  *Message
 }
 
 // representation of topic validators
-type topicVal struct {
+type validatorImpl struct {
 	topic            string
 	validate         ValidatorEx
 	validateTimeout  time.Duration
@@ -117,7 +120,7 @@ type rmValReq struct {
 // newValidation creates a new validation pipeline
 func newValidation() *validation {
 	return &validation{
-		topicVals:        make(map[string]*topicVal),
+		topicVals:        make(map[string]*validatorImpl),
 		validateQ:        make(chan *validateReq, defaultValidateQueueSize),
 		validateThrottle: make(chan struct{}, defaultValidateThrottle),
 		validateWorkers:  runtime.NumCPU(),
@@ -136,10 +139,16 @@ func (v *validation) Start(p *PubSub) {
 
 // AddValidator adds a new validator
 func (v *validation) AddValidator(req *addValReq) {
+	val, err := v.makeValidator(req)
+	if err != nil {
+		req.resp <- err
+		return
+	}
+
 	v.mx.Lock()
 	defer v.mx.Unlock()
 
-	topic := req.topic
+	topic := val.topic
 
 	_, ok := v.topicVals[topic]
 	if ok {
@@ -147,6 +156,11 @@
 		return
 	}
 
+	v.topicVals[topic] = val
+	req.resp <- nil
+}
+
+func (v *validation) makeValidator(req *addValReq) (*validatorImpl, error) {
 	makeValidatorEx := func(v Validator) ValidatorEx {
 		return func(ctx context.Context, p peer.ID, msg *Message) ValidationResult {
 			if v(ctx, p, msg) {
@@ -170,12 +184,15 @@
 		validator = v
 
 	default:
-		req.resp <- fmt.Errorf("unknown validator type for topic %s; must be an instance of Validator or ValidatorEx", topic)
-		return
+		topic := req.topic
+		if req.topic == "" {
+			topic = "(default)"
+		}
+		return nil, fmt.Errorf("unknown validator type for topic %s; must be an instance of Validator or ValidatorEx", topic)
 	}
 
-	val := &topicVal{
-		topic:            topic,
+	val := &validatorImpl{
+		topic:            req.topic,
 		validate:         validator,
 		validateTimeout:  0,
 		validateThrottle: make(chan struct{}, defaultValidateConcurrency),
@@ -190,8 +207,7 @@
 		val.validateThrottle = make(chan struct{}, req.throttle)
 	}
 
-	v.topicVals[topic] = val
-	req.resp <- nil
+	return val, nil
 }
 
 // RemoveValidator removes an existing validator
@@ -244,18 +260,21 @@ func (v *validation) Push(src peer.ID, msg *Message) bool {
 }
 
 // getValidators returns all validators that apply to a given message
-func (v *validation) getValidators(msg *Message) []*topicVal {
+func (v *validation) getValidators(msg *Message) []*validatorImpl {
 	v.mx.Lock()
 	defer v.mx.Unlock()
 
+	var vals []*validatorImpl
+	vals = append(vals, v.defaultVals...)
+
 	topic := msg.GetTopic()
 
 	val, ok := v.topicVals[topic]
 	if !ok {
-		return nil
+		return vals
 	}
 
-	return []*topicVal{val}
+	return append(vals, val)
 }
 
 // validateWorker is an active goroutine performing inline validation
@@ -271,7 +290,7 @@ func (v *validation) validateWorker() {
 }
 
 // validate performs validation and only sends the message if all validators succeed
-func (v *validation) validate(vals []*topicVal, src peer.ID, msg *Message, synchronous bool) error {
+func (v *validation) validate(vals []*validatorImpl, src peer.ID, msg *Message, synchronous bool) error {
 	// If signature verification is enabled, but signing is disabled,
 	// the Signature is required to be nil upon receiving the message in PubSub.pushMsg.
 	if msg.Signature != nil {
@@ -292,7 +311,7 @@ func (v *validation) validate(vals []*topicVal, src peer.ID, msg *Message, synch
 		v.tracer.ValidateMessage(msg)
 	}
 
-	var inline, async []*topicVal
+	var inline, async []*validatorImpl
 	for _, val := range vals {
 		if val.validateInline || synchronous {
 			inline = append(inline, val)
@@ -360,7 +379,7 @@ func (v *validation) validateSignature(msg *Message) bool {
 	return true
 }
 
-func (v *validation) doValidateTopic(vals []*topicVal, src peer.ID, msg *Message, r ValidationResult) {
+func (v *validation) doValidateTopic(vals []*validatorImpl, src peer.ID, msg *Message, r ValidationResult) {
 	result := v.validateTopic(vals, src, msg)
 
 	if result == ValidationAccept && r != ValidationAccept {
@@ -388,7 +407,7 @@ func (v *validation) doValidateTopic(vals []*topicVal, src peer.ID, msg *Message
 	}
 }
 
-func (v *validation) validateTopic(vals []*topicVal, src peer.ID, msg *Message) ValidationResult {
+func (v *validation) validateTopic(vals []*validatorImpl, src peer.ID, msg *Message) ValidationResult {
 	if len(vals) == 1 {
 		return v.validateSingleTopic(vals[0], src, msg)
 	}
@@ -404,7 +423,7 @@ func (v *validation) validateTopic(vals []*topicVal, src peer.ID, msg *Message)
 
 		select {
 		case val.validateThrottle <- struct{}{}:
-			go func(val *topicVal) {
+			go func(val *validatorImpl) {
 				rch <- val.validateMsg(ctx, src, msg)
 				<-val.validateThrottle
 			}(val)
@@ -438,7 +457,7 @@ loop:
 }
 
 // fast path for single topic validation that avoids the extra goroutine
-func (v *validation) validateSingleTopic(val *topicVal, src peer.ID, msg *Message) ValidationResult {
+func (v *validation) validateSingleTopic(val *validatorImpl, src peer.ID, msg *Message) ValidationResult {
 	select {
 	case val.validateThrottle <- struct{}{}:
 		res := val.validateMsg(v.p.ctx, src, msg)
@@ -451,7 +470,7 @@ func (v *validation) validateSingleTopic(val *topicVal, src peer.ID, msg *Messag
 	}
 }
 
-func (val *topicVal) validateMsg(ctx context.Context, src peer.ID, msg *Message) ValidationResult {
+func (val *validatorImpl) validateMsg(ctx context.Context, src peer.ID, msg *Message) ValidationResult {
 	start := time.Now()
 	defer func() {
 		log.Debugf("validation done; took %s", time.Since(start))
@@ -479,6 +498,31 @@ func (val *topicVal) validateMsg(ctx context.Context, src peer.ID, msg *Message)
 }
 
 /// Options
+// WithDefaultValidator adds a validator that applies to all topics by default; it can be used
+// more than once to add multiple validators.
+// Having a default validator does not inhibit registering a per topic validator.
+func WithDefaultValidator(val interface{}, opts ...ValidatorOpt) Option {
+	return func(ps *PubSub) error {
+		addVal := &addValReq{
+			validate: val,
+		}
+
+		for _, opt := range opts {
+			err := opt(addVal)
+			if err != nil {
+				return err
+			}
+		}
+
+		val, err := ps.val.makeValidator(addVal)
+		if err != nil {
+			return err
+		}
+
+		ps.val.defaultVals = append(ps.val.defaultVals, val)
+		return nil
+	}
+}
 
 // WithValidateQueueSize sets the buffer of validate queue. Defaults to 32.
 // When queue is full, validation is throttled and new messages are dropped.
diff --git a/validation_builtin.go b/validation_builtin.go
new file mode 100644
index 00000000..6a2e3ccd
--- /dev/null
+++ b/validation_builtin.go
@@ -0,0 +1,101 @@
+package pubsub
+
+import (
+	"context"
+	"encoding/binary"
+	"sync"
+
+	"github.com/libp2p/go-libp2p/core/peer"
+)
+
+// PeerMetadataStore is an interface for storing and retrieving per peer metadata
+type PeerMetadataStore interface {
+	// Get retrieves the metadata associated with a peer;
+	// it should return nil, not an error, if there is no metadata associated with the peer.
+	Get(context.Context, peer.ID) ([]byte, error)
+	// Put sets the metadata associated with a peer.
+	Put(context.Context, peer.ID, []byte) error
+}
+
+// BasicSeqnoValidator is a basic validator, usable as a default validator, that ignores replayed
+// messages outside the seen cache window. The validator uses the message seqno as a peer-specific
+// nonce to decide whether the message should be propagated, comparing it to the maximal nonce stored
+// in the peer metadata store. This is useful to ensure that there can be no infinitely propagating
+// messages in the network regardless of the seen cache span and network diameter.
+// It requires that pubsub is instantiated with a strict message signing policy and that seqnos
+// are not disabled, i.e. it doesn't support anonymous mode.
+//
+// Warning: See https://github.com/libp2p/rust-libp2p/issues/3453
+// TL;DR: rust is currently violating the spec by issuing a random seqno, which creates an
+// interoperability hazard. We expect this issue to be addressed in the not so distant future,
+// but keep this in mind if you are in a mixed environment with (older) rust nodes.
+type BasicSeqnoValidator struct {
+	mx   sync.RWMutex
+	meta PeerMetadataStore
+}
+
+// NewBasicSeqnoValidator constructs a BasicSeqnoValidator using the given PeerMetadataStore.
+func NewBasicSeqnoValidator(meta PeerMetadataStore) ValidatorEx {
+	val := &BasicSeqnoValidator{
+		meta: meta,
+	}
+	return val.validate
+}
+
+func (v *BasicSeqnoValidator) validate(ctx context.Context, _ peer.ID, m *Message) ValidationResult {
+	p := m.GetFrom()
+
+	v.mx.RLock()
+	nonceBytes, err := v.meta.Get(ctx, p)
+	v.mx.RUnlock()
+
+	if err != nil {
+		log.Warnf("error retrieving peer nonce: %s", err)
+		return ValidationIgnore
+	}
+
+	var nonce uint64
+	if len(nonceBytes) > 0 {
+		nonce = binary.BigEndian.Uint64(nonceBytes)
+	}
+
+	var seqno uint64
+	seqnoBytes := m.GetSeqno()
+	if len(seqnoBytes) > 0 {
+		seqno = binary.BigEndian.Uint64(seqnoBytes)
+	}
+
+	// compare against the largest seen nonce
+	if seqno <= nonce {
+		return ValidationIgnore
+	}
+
+	// get the nonce and compare again with an exclusive lock before committing (cf. concurrent validation)
+	v.mx.Lock()
+	defer v.mx.Unlock()
+
+	nonceBytes, err = v.meta.Get(ctx, p)
+	if err != nil {
+		log.Warnf("error retrieving peer nonce: %s", err)
+		return ValidationIgnore
+	}
+
+	if len(nonceBytes) > 0 {
+		nonce = binary.BigEndian.Uint64(nonceBytes)
+	}
+
+	if seqno <= nonce {
+		return ValidationIgnore
+	}
+
+	// update the nonce
+	nonceBytes = make([]byte, 8)
+	binary.BigEndian.PutUint64(nonceBytes, seqno)
+
+	err = v.meta.Put(ctx, p, nonceBytes)
+	if err != nil {
+		log.Warnf("error storing peer nonce: %s", err)
+	}
+
+	return ValidationAccept
+}
diff --git a/validation_builtin_test.go b/validation_builtin_test.go
new file mode 100644
index 00000000..df406f26
--- /dev/null
+++ b/validation_builtin_test.go
@@ -0,0 +1,278 @@
+package pubsub
+
+import (
+	"bytes"
+	"context"
+	"encoding/binary"
+	"fmt"
+	"math/rand"
+	"sync"
+	"testing"
+	"time"
+
+	pool "github.com/libp2p/go-buffer-pool"
+	"github.com/libp2p/go-libp2p/core/host"
+	"github.com/libp2p/go-libp2p/core/network"
+	"github.com/libp2p/go-libp2p/core/peer"
+	"github.com/libp2p/go-msgio"
+	"github.com/multiformats/go-varint"
+
+	pb "github.com/libp2p/go-libp2p-pubsub/pb"
+)
+
+var rng *rand.Rand
+
+func init() {
+	rng = rand.New(rand.NewSource(314159))
+}
+
+func TestBasicSeqnoValidator1(t *testing.T) {
+	testBasicSeqnoValidator(t, time.Minute)
+}
+
+func TestBasicSeqnoValidator2(t *testing.T) {
+	testBasicSeqnoValidator(t, time.Nanosecond)
+}
+
+func testBasicSeqnoValidator(t *testing.T, ttl time.Duration) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	hosts := getNetHosts(t, ctx, 20)
+	psubs := getPubsubsWithOptionC(ctx, hosts,
+		func(i int) Option {
+			return WithDefaultValidator(NewBasicSeqnoValidator(newMockPeerMetadataStore()))
+		},
+		func(i int) Option {
+			return WithSeenMessagesTTL(ttl)
+		},
+	)
+
+	var msgs []*Subscription
+	for _, ps := range psubs {
+		subch, err := ps.Subscribe("foobar")
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		msgs = append(msgs, subch)
+	}
+
+	// connectAll(t, hosts)
+	sparseConnect(t, hosts)
+
+	time.Sleep(time.Millisecond * 100)
+
+	for i := 0; i < 100; i++ {
+		msg := []byte(fmt.Sprintf("%d the flooooooood %d", i, i))
+
+		owner := rng.Intn(len(psubs))
+
+		psubs[owner].Publish("foobar", msg)
+
+		for _, sub := range msgs {
+			got, err := sub.Next(ctx)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if !bytes.Equal(msg, got.Data) {
+				t.Fatal("got wrong message!")
+			}
+		}
+	}
+}
+
+func TestBasicSeqnoValidatorReplay(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	hosts := getNetHosts(t, ctx, 20)
+	psubs := getPubsubsWithOptionC(ctx, hosts[:19],
+		func(i int) Option {
+			return WithDefaultValidator(NewBasicSeqnoValidator(newMockPeerMetadataStore()))
+		},
+		func(i int) Option {
+			return WithSeenMessagesTTL(time.Nanosecond)
+		},
+	)
+	_ = newReplayActor(t, ctx, hosts[19])
+
+	var msgs []*Subscription
+	for _, ps := range psubs {
+		subch, err := ps.Subscribe("foobar")
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		msgs = append(msgs, subch)
+	}
+
+	sparseConnect(t, hosts)
+
+	time.Sleep(time.Millisecond * 100)
+
+	for i := 0; i < 10; i++ {
+		msg := []byte(fmt.Sprintf("%d the flooooooood %d", i, i))
+
+		owner := rng.Intn(len(psubs))
+
+		psubs[owner].Publish("foobar", msg)
+
+		for _, sub := range msgs {
+			got, err := sub.Next(ctx)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if !bytes.Equal(msg, got.Data) {
+				t.Fatal("got wrong message!")
+			}
+		}
+	}
+
+	for _, sub := range msgs {
+		assertNeverReceives(t, sub, time.Second)
+	}
+}
+
+type mockPeerMetadataStore struct {
+	meta map[peer.ID][]byte
+}
+
+func newMockPeerMetadataStore() *mockPeerMetadataStore {
+	return &mockPeerMetadataStore{
+		meta: make(map[peer.ID][]byte),
+	}
+}
+
+func (m *mockPeerMetadataStore) Get(ctx context.Context, p peer.ID) ([]byte, error) {
+	v, ok := m.meta[p]
+	if !ok {
+		return nil, nil
+	}
+	return v, nil
+}
+
+func (m *mockPeerMetadataStore) Put(ctx context.Context, p peer.ID, v []byte) error {
+	m.meta[p] = v
+	return nil
+}
+
+type replayActor struct {
+	t *testing.T
+
+	ctx context.Context
+	h   host.Host
+
+	mx  sync.Mutex
+	out map[peer.ID]network.Stream
+}
+
+func newReplayActor(t *testing.T, ctx context.Context, h host.Host) *replayActor {
+	replay := &replayActor{t: t, ctx: ctx, h: h, out: make(map[peer.ID]network.Stream)}
+	h.SetStreamHandler(FloodSubID, replay.handleStream)
+	h.Network().Notify(&network.NotifyBundle{ConnectedF: replay.connected})
+	return replay
+}
+
+func (r *replayActor) handleStream(s network.Stream) {
+	defer s.Close()
+
+	p := s.Conn().RemotePeer()
+
+	rd := msgio.NewVarintReaderSize(s, 65536)
+	for {
+		msgbytes, err := rd.ReadMsg()
+		if err != nil {
+			s.Reset()
+			rd.ReleaseMsg(msgbytes)
+			return
+		}
+
+		rpc := new(pb.RPC)
+		err = rpc.Unmarshal(msgbytes)
+		rd.ReleaseMsg(msgbytes)
+		if err != nil {
+			s.Reset()
+			return
+		}
+
+		// subscribe to the same topics as our peer
+		subs := rpc.GetSubscriptions()
+		if len(subs) != 0 {
+			go r.send(p, &pb.RPC{Subscriptions: subs})
+		}
+
+		// replay all received messages
+		for _, pmsg := range rpc.GetPublish() {
+			go r.replay(pmsg)
+		}
+	}
+}
+
+func (r *replayActor) send(p peer.ID, rpc *pb.RPC) {
+	r.mx.Lock()
+	defer r.mx.Unlock()
+
+	s, ok := r.out[p]
+	if !ok {
+		r.t.Logf("cannot send message to %s: no stream", p)
+		return
+	}
+
+	size := uint64(rpc.Size())
+
+	buf := pool.Get(varint.UvarintSize(size) + int(size))
+	defer pool.Put(buf)
+
+	n := binary.PutUvarint(buf, size)
+
+	_, err := rpc.MarshalTo(buf[n:])
+	if err != nil {
+		r.t.Logf("replay: error marshalling message: %s", err)
+		return
+	}
+
+	_, err = s.Write(buf)
+	if err != nil {
+		r.t.Logf("replay: error sending message: %s", err)
+	}
+}
+
+func (r *replayActor) replay(msg *pb.Message) {
+	// replay the message 10 times to a random subset of peers
+	for i := 0; i < 10; i++ {
+		delay := time.Duration(1+rng.Intn(20)) * time.Millisecond
+		time.Sleep(delay)
+
+		var peers []peer.ID
+		r.mx.Lock()
+		for p := range r.out {
+			if rng.Intn(2) > 0 {
+				peers = append(peers, p)
+			}
+		}
+		r.mx.Unlock()
+
+		rpc := &pb.RPC{Publish: []*pb.Message{msg}}
+		r.t.Logf("replaying msg to %d peers", len(peers))
+		for _, p := range peers {
+			r.send(p, rpc)
+		}
+	}
+}
+
+func (r *replayActor) handleConnected(p peer.ID) {
+	s, err := r.h.NewStream(r.ctx, p, FloodSubID)
+	if err != nil {
+		r.t.Logf("replay: error opening stream: %s", err)
+		return
+	}
+
+	r.mx.Lock()
+	defer r.mx.Unlock()
+	r.out[p] = s
+}
+
+func (r *replayActor) connected(_ network.Network, conn network.Conn) {
+	go r.handleConnected(conn.RemotePeer())
+}
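
Usage sketch (not part of the diff): wiring BasicSeqnoValidator into a node through the new WithDefaultValidator option. The in-memory memoryMetadataStore below is illustrative only; any PeerMetadataStore implementation works. Note that the validator relies on the default strict signing policy and on seqnos being enabled, so no signing options are changed here.

package main

import (
	"context"
	"fmt"
	"sync"

	"github.com/libp2p/go-libp2p"
	pubsub "github.com/libp2p/go-libp2p-pubsub"
	"github.com/libp2p/go-libp2p/core/peer"
)

// memoryMetadataStore is an illustrative in-memory PeerMetadataStore;
// it is not part of the library.
type memoryMetadataStore struct {
	mx   sync.Mutex
	meta map[peer.ID][]byte
}

func (m *memoryMetadataStore) Get(_ context.Context, p peer.ID) ([]byte, error) {
	m.mx.Lock()
	defer m.mx.Unlock()
	// nil, not an error, when there is no metadata for the peer
	return m.meta[p], nil
}

func (m *memoryMetadataStore) Put(_ context.Context, p peer.ID, v []byte) error {
	m.mx.Lock()
	defer m.mx.Unlock()
	m.meta[p] = v
	return nil
}

func main() {
	ctx := context.Background()

	h, err := libp2p.New()
	if err != nil {
		panic(err)
	}

	store := &memoryMetadataStore{meta: make(map[peer.ID][]byte)}

	// Strict signing is the default; BasicSeqnoValidator requires it,
	// along with seqnos (i.e. no anonymous mode).
	ps, err := pubsub.NewGossipSub(ctx, h,
		pubsub.WithDefaultValidator(pubsub.NewBasicSeqnoValidator(store)),
	)
	if err != nil {
		panic(err)
	}

	fmt.Println("pubsub node with seqno replay protection:", ps != nil)
}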
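
For nonce persistence across restarts, the PeerMetadataStore can be backed by a datastore. A minimal sketch follows, assuming the go-datastore API (github.com/ipfs/go-datastore); the type name and key layout are illustrative, not part of this change.

package main

import (
	"context"

	ds "github.com/ipfs/go-datastore"
	"github.com/libp2p/go-libp2p/core/peer"
)

// dsMetadataStore is a hypothetical PeerMetadataStore backed by a go-datastore.
type dsMetadataStore struct {
	ds ds.Datastore
}

func (s *dsMetadataStore) key(p peer.ID) ds.Key {
	// illustrative key layout
	return ds.NewKey("/pubsub/seqno/" + p.String())
}

func (s *dsMetadataStore) Get(ctx context.Context, p peer.ID) ([]byte, error) {
	v, err := s.ds.Get(ctx, s.key(p))
	if err == ds.ErrNotFound {
		// per the PeerMetadataStore contract, missing metadata is (nil, nil), not an error
		return nil, nil
	}
	return v, err
}

func (s *dsMetadataStore) Put(ctx context.Context, p peer.ID, v []byte) error {
	return s.ds.Put(ctx, s.key(p), v)
}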