Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Build validator copy cache on write #31

Merged
merged 7 commits into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 14 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,11 @@ validator, err := protovalidate.New(
)
```

Lazy mode requires usage of a mutex to keep the validator thread-safe, which
results in about 50% of CPU time spent obtaining a read lock. While [performance](#performance)
is sub-microsecond, the mutex overhead can be further reduced by disabling lazy
mode with the `WithDisableLazy` option. Note that all expected messages must be
provided during initialization of the validator:
Lazy mode uses a copy-on-write cache strategy to reduce the required locking.
While [performance](#performance) is sub-microsecond, the overhead can be
further reduced by disabling lazy mode with the `WithDisableLazy` option.
Note that all expected messages must be provided during initialization of the
validator:

```go
validator, err := protovalidate.New(
Expand Down Expand Up @@ -200,25 +200,17 @@ initial cold start, validation on a message is sub-microsecond
and only allocates in the event of a validation error.

```
[circa 15 May 2023]
[circa 24 August 2023]
goos: darwin
goarch: arm64
pkg: github.com/bufbuild/protovalidate-go
BenchmarkValidator
BenchmarkValidator/ColdStart
BenchmarkValidator/ColdStart-10 4372 276457 ns/op 470780 B/op 9255 allocs/op
BenchmarkValidator/Lazy/Valid
BenchmarkValidator/Lazy/Valid-10 9022392 134.1 ns/op 0 B/op 0 allocs/op
BenchmarkValidator/Lazy/Invalid
BenchmarkValidator/Lazy/Invalid-10 3416996 355.9 ns/op 632 B/op 14 allocs/op
BenchmarkValidator/Lazy/FailFast
BenchmarkValidator/Lazy/FailFast-10 6751131 172.6 ns/op 168 B/op 3 allocs/op
BenchmarkValidator/PreWarmed/Valid
BenchmarkValidator/PreWarmed/Valid-10 17557560 69.10 ns/op 0 B/op 0 allocs/op
BenchmarkValidator/PreWarmed/Invalid
BenchmarkValidator/PreWarmed/Invalid-10 3621961 332.9 ns/op 632 B/op 14 allocs/op
BenchmarkValidator/PreWarmed/FailFast
BenchmarkValidator/PreWarmed/FailFast-10 13960359 92.22 ns/op 168 B/op 3 allocs/op
BenchmarkValidator/ColdStart-8 5294 219906 ns/op 431759 B/op 5803 allocs/op
BenchmarkValidator/Lazy/Valid-8 9725028 114.7 ns/op 0 B/op 0 allocs/op
BenchmarkValidator/Lazy/Invalid-8 3060620 383.5 ns/op 649 B/op 15 allocs/op
BenchmarkValidator/Lazy/FailFast-8 11999664 98.17 ns/op 168 B/op 3 allocs/op
BenchmarkValidator/PreWarmed/Valid-8 11031498 112.0 ns/op 0 B/op 0 allocs/op
BenchmarkValidator/PreWarmed/Invalid-8 3132213 391.1 ns/op 649 B/op 15 allocs/op
BenchmarkValidator/PreWarmed/FailFast-8 12277747 99.36 ns/op 168 B/op 3 allocs/op
PASS
```

Expand All @@ -240,4 +232,4 @@ Offered under the [Apache 2 license][license].
[cel-spec]: https://github.com/google/cel-spec
[pv-cc]: https://github.com/bufbuild/protovalidate-cc
[pv-java]: https://github.com/bufbuild/protovalidate-java
[pv-python]: https://github.com/bufbuild/protovalidate-python
[pv-python]: https://github.com/bufbuild/protovalidate-python
96 changes: 67 additions & 29 deletions internal/evaluator/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package evaluator

import (
"sync"
"sync/atomic"

"buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate"
"github.com/bufbuild/protovalidate-go/internal/constraints"
Expand All @@ -29,11 +30,8 @@ import (
// Builder is a build-through cache of message evaluators keyed off the provided
// descriptor.
type Builder struct {
// TODO: (TCN-1708) based on benchmarks, about 50% of CPU time is spent obtaining a read
// lock on this mutex. Ideally, this can be reworked to be thread-safe while
// minimizing the need to obtain a lock.
mtx sync.RWMutex
cache map[protoreflect.MessageDescriptor]*message
mtx sync.Mutex // serializes cache writes.
cache atomic.Pointer[MessageCache] // copy-on-write cache.
env *cel.Env
constraints constraints.Cache
resolver StandardConstraintResolver
Expand All @@ -48,7 +46,6 @@ func NewBuilder(
seedDesc ...protoreflect.MessageDescriptor,
) *Builder {
bldr := &Builder{
cache: map[protoreflect.MessageDescriptor]*message{},
env: env,
constraints: constraints.NewCache(),
resolver: res,
Expand All @@ -60,18 +57,19 @@ func NewBuilder(
bldr.Load = bldr.loadOrBuild
}

cache := make(MessageCache, len(seedDesc))
for _, desc := range seedDesc {
bldr.build(desc)
bldr.build(desc, cache)
}

bldr.cache.Store(&cache)
return bldr
}

// load returns a pre-cached MessageEvaluator for the given descriptor or, if
// the descriptor is unknown, returns an evaluator that always resolves to an
// errors.CompilationError.
func (bldr *Builder) load(desc protoreflect.MessageDescriptor) MessageEvaluator {
if eval, ok := bldr.cache[desc]; ok {
if eval, ok := (*bldr.cache.Load())[desc]; ok {
return eval
}
return unknownMessage{desc: desc}
Expand All @@ -81,29 +79,38 @@ func (bldr *Builder) load(desc protoreflect.MessageDescriptor) MessageEvaluator
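// descriptor, or lazily constructs a new one. This method is thread-safe:
// lookups read an atomically loaded snapshot of the cache without locking,
// and a cache miss takes the mutex before building and publishing a new copy.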
func (bldr *Builder) loadOrBuild(desc protoreflect.MessageDescriptor) MessageEvaluator {
bldr.mtx.RLock()
if eval, ok := bldr.cache[desc]; ok {
bldr.mtx.RUnlock()
if eval, ok := (*bldr.cache.Load())[desc]; ok {
return eval
}
bldr.mtx.RUnlock()

bldr.mtx.Lock()
defer bldr.mtx.Unlock()
return bldr.build(desc)
cache := *bldr.cache.Load()
if eval, ok := cache[desc]; ok {
return eval
}
newCache := cache.Clone()
msgEval := bldr.build(desc, newCache)
bldr.cache.Store(&newCache)
return msgEval
}

func (bldr *Builder) build(desc protoreflect.MessageDescriptor) *message {
if eval, ok := bldr.cache[desc]; ok {
func (bldr *Builder) build(
desc protoreflect.MessageDescriptor,
cache MessageCache,
) *message {
if eval, ok := cache[desc]; ok {
return eval
}
msgEval := &message{}
bldr.cache[desc] = msgEval
bldr.buildMessage(desc, msgEval)
cache[desc] = msgEval
bldr.buildMessage(desc, msgEval, cache)
return msgEval
}

func (bldr *Builder) buildMessage(desc protoreflect.MessageDescriptor, msgEval *message) {
func (bldr *Builder) buildMessage(
desc protoreflect.MessageDescriptor, msgEval *message,
cache MessageCache,
) {
msgConstraints := bldr.resolver.ResolveMessageConstraints(desc)
if msgConstraints.GetDisabled() {
return
Expand All @@ -113,14 +120,15 @@ func (bldr *Builder) buildMessage(desc protoreflect.MessageDescriptor, msgEval *
desc protoreflect.MessageDescriptor,
msgConstraints *validate.MessageConstraints,
msg *message,
cache MessageCache,
){
bldr.processMessageExpressions,
bldr.processOneofConstraints,
bldr.processFields,
}

for _, step := range steps {
if step(desc, msgConstraints, msgEval); msgEval.Err != nil {
if step(desc, msgConstraints, msgEval, cache); msgEval.Err != nil {
break
}
}
Expand All @@ -130,6 +138,7 @@ func (bldr *Builder) processMessageExpressions(
desc protoreflect.MessageDescriptor,
msgConstraints *validate.MessageConstraints,
msgEval *message,
_ MessageCache,
) {
compiledExprs, err := expression.Compile(
msgConstraints.GetCel(),
Expand All @@ -149,6 +158,7 @@ func (bldr *Builder) processOneofConstraints(
desc protoreflect.MessageDescriptor,
_ *validate.MessageConstraints,
msgEval *message,
_ MessageCache,
) {
oneofs := desc.Oneofs()
for i := 0; i < oneofs.Len(); i++ {
Expand All @@ -166,12 +176,13 @@ func (bldr *Builder) processFields(
desc protoreflect.MessageDescriptor,
_ *validate.MessageConstraints,
msgEval *message,
cache MessageCache,
) {
fields := desc.Fields()
for i := 0; i < fields.Len(); i++ {
fdesc := fields.Get(i)
fieldConstraints := bldr.resolver.ResolveFieldConstraints(fdesc)
fldEval, err := bldr.buildField(fdesc, fieldConstraints)
fldEval, err := bldr.buildField(fdesc, fieldConstraints, cache)
if err != nil {
msgEval.Err = err
return
Expand All @@ -183,13 +194,14 @@ func (bldr *Builder) processFields(
func (bldr *Builder) buildField(
fieldDescriptor protoreflect.FieldDescriptor,
fieldConstraints *validate.FieldConstraints,
cache MessageCache,
) (field, error) {
fld := field{
Descriptor: fieldDescriptor,
Required: fieldConstraints.GetRequired(),
Optional: fieldDescriptor.HasPresence(),
}
err := bldr.buildValue(fieldDescriptor, fieldConstraints, false, &fld.Value)
err := bldr.buildValue(fieldDescriptor, fieldConstraints, false, &fld.Value, cache)
return fld, err
}

Expand All @@ -198,13 +210,15 @@ func (bldr *Builder) buildValue(
constraints *validate.FieldConstraints,
forItems bool,
valEval *value,
cache MessageCache,
) (err error) {
valEval.IgnoreEmpty = constraints.GetIgnoreEmpty()
steps := []func(
fdesc protoreflect.FieldDescriptor,
fieldConstraints *validate.FieldConstraints,
forItems bool,
valEval *value,
cache MessageCache,
) error{
bldr.processZeroValue,
bldr.processFieldExpressions,
Expand All @@ -218,7 +232,7 @@ func (bldr *Builder) buildValue(
}

for _, step := range steps {
if err = step(fdesc, constraints, forItems, valEval); err != nil {
if err = step(fdesc, constraints, forItems, valEval, cache); err != nil {
return err
}
}
Expand All @@ -230,6 +244,7 @@ func (bldr *Builder) processZeroValue(
_ *validate.FieldConstraints,
forItems bool,
val *value,
_ MessageCache,
) error {
val.Zero = fdesc.Default()
if forItems && fdesc.IsList() {
Expand All @@ -244,6 +259,7 @@ func (bldr *Builder) processFieldExpressions(
fieldConstraints *validate.FieldConstraints,
_ bool,
eval *value,
_ MessageCache,
) error {
exprs := fieldConstraints.GetCel()
if len(exprs) == 0 {
Expand Down Expand Up @@ -275,14 +291,15 @@ func (bldr *Builder) processEmbeddedMessage(
rules *validate.FieldConstraints,
forItems bool,
valEval *value,
cache MessageCache,
) error {
if fdesc.Kind() != protoreflect.MessageKind ||
rules.GetSkipped() ||
fdesc.IsMap() || (fdesc.IsList() && !forItems) {
return nil
}

embedEval := bldr.build(fdesc.Message())
embedEval := bldr.build(fdesc.Message(), cache)
if err := embedEval.Err; err != nil {
return errors.NewCompilationErrorf(
"failed to compile embedded type %s for %s: %w",
Expand All @@ -298,6 +315,7 @@ func (bldr *Builder) processWrapperConstraints(
rules *validate.FieldConstraints,
forItems bool,
valEval *value,
cache MessageCache,
) error {
if fdesc.Kind() != protoreflect.MessageKind ||
rules.GetSkipped() ||
Expand All @@ -310,7 +328,7 @@ func (bldr *Builder) processWrapperConstraints(
return nil
}
var unwrapped value
err := bldr.buildValue(fdesc.Message().Fields().ByName("value"), rules, true, &unwrapped)
err := bldr.buildValue(fdesc.Message().Fields().ByName("value"), rules, true, &unwrapped, cache)
if err != nil {
return err
}
Expand All @@ -323,6 +341,7 @@ func (bldr *Builder) processStandardConstraints(
constraints *validate.FieldConstraints,
forItems bool,
valEval *value,
_ MessageCache,
) error {
stdConstraints, err := bldr.constraints.Build(
bldr.env,
Expand All @@ -342,6 +361,7 @@ func (bldr *Builder) processAnyConstraints(
fieldConstraints *validate.FieldConstraints,
forItems bool,
valEval *value,
_ MessageCache,
) error {
if (fdesc.IsList() && !forItems) ||
fdesc.Kind() != protoreflect.MessageKind ||
Expand All @@ -364,6 +384,7 @@ func (bldr *Builder) processEnumConstraints(
fieldConstraints *validate.FieldConstraints,
_ bool,
valEval *value,
_ MessageCache,
) error {
if fdesc.Kind() != protoreflect.EnumKind {
return nil
Expand All @@ -379,6 +400,7 @@ func (bldr *Builder) processMapConstraints(
constraints *validate.FieldConstraints,
_ bool,
valEval *value,
cache MessageCache,
) error {
if !fieldDesc.IsMap() {
return nil
Expand All @@ -390,7 +412,8 @@ func (bldr *Builder) processMapConstraints(
fieldDesc.MapKey(),
constraints.GetMap().GetKeys(),
true,
&mapEval.KeyConstraints)
&mapEval.KeyConstraints,
cache)
if err != nil {
return errors.NewCompilationErrorf(
"failed to compile key constraints for map %s: %w",
Expand All @@ -401,7 +424,8 @@ func (bldr *Builder) processMapConstraints(
fieldDesc.MapValue(),
constraints.GetMap().GetValues(),
true,
&mapEval.ValueConstraints)
&mapEval.ValueConstraints,
cache)
if err != nil {
return errors.NewCompilationErrorf(
"failed to compile value constraints for map %s: %w",
Expand All @@ -417,13 +441,14 @@ func (bldr *Builder) processRepeatedConstraints(
fieldConstraints *validate.FieldConstraints,
forItems bool,
valEval *value,
cache MessageCache,
) error {
if !fdesc.IsList() || forItems {
return nil
}

var listEval listItems
err := bldr.buildValue(fdesc, fieldConstraints.GetRepeated().GetItems(), true, &listEval.ItemConstraints)
err := bldr.buildValue(fdesc, fieldConstraints.GetRepeated().GetItems(), true, &listEval.ItemConstraints, cache)
if err != nil {
return errors.NewCompilationErrorf(
"failed to compile items constraints for repeated %v: %w", fdesc.FullName(), err)
Expand All @@ -432,3 +457,16 @@ func (bldr *Builder) processRepeatedConstraints(
valEval.Append(listEval)
return nil
}

// MessageCache maps a message descriptor to the evaluator built for it.
type MessageCache map[protoreflect.MessageDescriptor]*message

// Clone returns a new MessageCache holding the same entries as c. The
// evaluator pointers are shared with c rather than deep-copied. The new map is
// sized for one entry more than c currently holds.
func (c MessageCache) Clone() MessageCache {
	dup := make(MessageCache, len(c)+1)
	for desc, eval := range c {
		dup[desc] = eval
	}
	return dup
}
// SyncTo copies every entry of c into other, replacing any entry in other
// that is already stored under the same descriptor. c itself is not modified.
func (c MessageCache) SyncTo(other MessageCache) {
	for desc, eval := range c {
		other[desc] = eval
	}
}
Loading