diff --git a/batcher/batcher.go b/batcher/batcher.go index 37f586f..199e168 100644 --- a/batcher/batcher.go +++ b/batcher/batcher.go @@ -1,7 +1,10 @@ package batcher import ( + "sync/atomic" "time" + + "github.com/DmitriyVTitov/size" ) // FlushCallback is the callback function that will be called when the batcher is full or the flush interval is reached @@ -9,7 +12,11 @@ type FlushCallback[T any] func([]T) // Batcher is a batcher for any type of data type Batcher[T any] struct { - maxCapacity int + maxCapacity int + maxSize int32 + + currentSize atomic.Int32 + flushInterval *time.Duration flushCallback FlushCallback[T] @@ -29,6 +36,13 @@ func WithMaxCapacity[T any](maxCapacity int) BatcherOption[T] { } } +// WithMaxSize sets the max size of the batcher +func WithMaxSize[T any](maxSize int32) BatcherOption[T] { + return func(b *Batcher[T]) { + b.maxSize = maxSize + } +} + // WithFlushInterval sets the optional flush interval of the batcher func WithFlushInterval[T any](flushInterval time.Duration) BatcherOption[T] { return func(b *Batcher[T]) { @@ -53,6 +67,9 @@ func New[T any](opts ...BatcherOption[T]) *Batcher[T] { for _, opt := range opts { opt(batcher) } + if batcher.maxSize > 0 { + batcher.currentSize = atomic.Int32{} + } batcher.incomingData = make(chan T, batcher.maxCapacity) if batcher.flushCallback == nil { panic("batcher: flush callback is required") @@ -66,11 +83,22 @@ func New[T any](opts ...BatcherOption[T]) *Batcher[T] { // Append appends data to the batcher func (b *Batcher[T]) Append(d ...T) { for _, item := range d { + sizeofItem := size.Of(item) + currentSize := b.currentSize.Load() + + if b.maxSize > 0 && currentSize+int32(sizeofItem) > int32(b.maxSize) { + b.full <- true + b.incomingData <- item + b.currentSize.Add(int32(sizeofItem)) + continue + } + if !b.put(item) { // will wait until space available b.full <- true b.incomingData <- item } + b.currentSize.Add(int32(sizeofItem)) } } @@ -148,6 +176,7 @@ func (b *Batcher[T]) doCallback() { for item := range b.incomingData { items[k] = item k++ + b.currentSize.Add(-int32(size.Of(item))) if k >= n { break } diff --git a/batcher/batcher_test.go b/batcher/batcher_test.go index 704880b..335d2cf 100644 --- a/batcher/batcher_test.go +++ b/batcher/batcher_test.go @@ -1,6 +1,7 @@ package batcher import ( + "crypto/rand" "testing" "time" @@ -74,3 +75,45 @@ func TestBatcherWithInterval(t *testing.T) { require.Equal(t, wanted, got) require.True(t, minWantedBatches <= gotBatches) } + +type exampleBatcherStruct struct { + Value []byte +} + +func TestBatcherWithSizeLimit(t *testing.T) { + var ( + batchSize = 100 + maxSize = 1000 + wanted = 10 + gotBatches int + ) + var failedIteration bool + + callback := func(ta []exampleBatcherStruct) { + gotBatches++ + + if len(ta) != 5 { + failedIteration = true + } + } + bat := New[exampleBatcherStruct]( + WithMaxCapacity[exampleBatcherStruct](batchSize), + WithMaxSize[exampleBatcherStruct](int32(maxSize)), + WithFlushCallback[exampleBatcherStruct](callback), + ) + + bat.Run() + + for i := 0; i < wanted; i++ { + randData := make([]byte, 200) + _, _ = rand.Read(randData) + bat.Append(exampleBatcherStruct{Value: randData}) + } + + bat.Stop() + + bat.WaitDone() + + require.Equal(t, 2, gotBatches) + require.False(t, failedIteration) +} diff --git a/go.mod b/go.mod index 3674fa5..f446899 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/projectdiscovery/utils go 1.21 require ( + github.com/DmitriyVTitov/size v1.5.0 github.com/Masterminds/semver/v3 v3.2.1 github.com/Mzack9999/gcache v0.0.0-20230410081825-519e28eab057 github.com/andybalholm/brotli v1.0.6 diff --git a/go.sum b/go.sum index f6582a1..d4d4b28 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ aead.dev/minisign v0.2.0 h1:kAWrq/hBRu4AARY6AlciO83xhNnW9UaC8YipS2uhLPk= aead.dev/minisign v0.2.0/go.mod h1:zdq6LdSd9TbuSxchxwhpA9zEb9YXcVGoE8JakuiGaIQ= cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +github.com/DmitriyVTitov/size v1.5.0 h1:/PzqxYrOyOUX1BXj6J9OuVRVGe+66VL4D9FlUaW515g= +github.com/DmitriyVTitov/size v1.5.0/go.mod h1:le6rNI4CoLQV1b9gzp1+3d7hMAD/uu2QcJ+aYbNgiU0= github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0= github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= github.com/Mzack9999/gcache v0.0.0-20230410081825-519e28eab057 h1:KFac3SiGbId8ub47e7kd2PLZeACxc1LkiiNoDOFRClE= @@ -71,6 +73,8 @@ github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiU github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/gofrs/uuid v3.3.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= diff --git a/reflect/reflectutil.go b/reflect/reflectutil.go index 478e577..9c7b6b6 100644 --- a/reflect/reflectutil.go +++ b/reflect/reflectutil.go @@ -94,3 +94,138 @@ func setUnexportedField(field reflect.Value, value interface{}) { Elem(). Set(reflect.ValueOf(value)) } + +// SizeOf returns the size of 'v' in bytes. +// If there is an error during calculation, Of returns -1. +// +// Implementation is taken from https://github.com/DmitriyVTitov/size/blob/v1.5.0/size.go#L14 which +// in turn is inspired from binary.Size of stdlib +func SizeOf(v interface{}) int { + // Cache with every visited pointer so we don't count two pointers + // to the same memory twice. + cache := make(map[uintptr]bool) + return sizeOf(reflect.Indirect(reflect.ValueOf(v)), cache) +} + +// sizeOf returns the number of bytes the actual data represented by v occupies in memory. +// If there is an error, sizeOf returns -1. +func sizeOf(v reflect.Value, cache map[uintptr]bool) int { + switch v.Kind() { + + case reflect.Array: + sum := 0 + for i := 0; i < v.Len(); i++ { + s := sizeOf(v.Index(i), cache) + if s < 0 { + return -1 + } + sum += s + } + + return sum + (v.Cap()-v.Len())*int(v.Type().Elem().Size()) + + case reflect.Slice: + // return 0 if this node has been visited already + if cache[v.Pointer()] { + return 0 + } + cache[v.Pointer()] = true + + sum := 0 + for i := 0; i < v.Len(); i++ { + s := sizeOf(v.Index(i), cache) + if s < 0 { + return -1 + } + sum += s + } + + sum += (v.Cap() - v.Len()) * int(v.Type().Elem().Size()) + + return sum + int(v.Type().Size()) + + case reflect.Struct: + sum := 0 + for i, n := 0, v.NumField(); i < n; i++ { + s := sizeOf(v.Field(i), cache) + if s < 0 { + return -1 + } + sum += s + } + + // Look for struct padding. + padding := int(v.Type().Size()) + for i, n := 0, v.NumField(); i < n; i++ { + padding -= int(v.Field(i).Type().Size()) + } + + return sum + padding + + case reflect.String: + s := v.String() + hdr := (*reflect.StringHeader)(unsafe.Pointer(&s)) + if cache[hdr.Data] { + return int(v.Type().Size()) + } + cache[hdr.Data] = true + return len(s) + int(v.Type().Size()) + + case reflect.Ptr: + // return Ptr size if this node has been visited already (infinite recursion) + if cache[v.Pointer()] { + return int(v.Type().Size()) + } + cache[v.Pointer()] = true + if v.IsNil() { + return int(reflect.New(v.Type()).Type().Size()) + } + s := sizeOf(reflect.Indirect(v), cache) + if s < 0 { + return -1 + } + return s + int(v.Type().Size()) + + case reflect.Bool, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Int, reflect.Uint, + reflect.Chan, + reflect.Uintptr, + reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128, + reflect.Func: + return int(v.Type().Size()) + + case reflect.Map: + // return 0 if this node has been visited already (infinite recursion) + if cache[v.Pointer()] { + return 0 + } + cache[v.Pointer()] = true + sum := 0 + keys := v.MapKeys() + for i := range keys { + val := v.MapIndex(keys[i]) + // calculate size of key and value separately + sv := sizeOf(val, cache) + if sv < 0 { + return -1 + } + sum += sv + sk := sizeOf(keys[i], cache) + if sk < 0 { + return -1 + } + sum += sk + } + // Include overhead due to unused map buckets. 10.79 comes + // from https://golang.org/src/runtime/map.go. + return sum + int(v.Type().Size()) + int(float64(len(keys))*10.79) + + case reflect.Interface: + return sizeOf(v.Elem(), cache) + int(v.Type().Size()) + + } + + return -1 +} diff --git a/reflect/reflectutil_test.go b/reflect/reflectutil_test.go index 34bfaf0..3683cfc 100644 --- a/reflect/reflectutil_test.go +++ b/reflect/reflectutil_test.go @@ -49,3 +49,69 @@ func TestUnexportedField(t *testing.T) { value := GetUnexportedField(testStruct, "unexported") require.Equal(t, value, "test") } + +// Test taken from https://github.com/DmitriyVTitov/size/blob/v1.5.0/size_test.go +func TestSizeOf(t *testing.T) { + tests := []struct { + name string + v interface{} + want int + }{ + { + name: "Array", + v: [3]int32{1, 2, 3}, // 3 * 4 = 12 + want: 12, + }, + { + name: "Slice", + v: make([]int64, 2, 5), // 5 * 8 + 24 = 64 + want: 64, + }, + { + name: "String", + v: "ABCdef", // 6 + 16 = 22 + want: 22, + }, + { + name: "Map", + // (8 + 3 + 16) + (8 + 4 + 16) = 55 + // 55 + 8 + 10.79 * 2 = 84 + v: map[int64]string{0: "ABC", 1: "DEFG"}, + want: 84, + }, + { + name: "Struct", + v: struct { + slice []int64 + array [2]bool + structure struct { + i int8 + s string + } + }{ + slice: []int64{12345, 67890}, // 2 * 8 + 24 = 40 + array: [2]bool{true, false}, // 2 * 1 = 2 + structure: struct { + i int8 + s string + }{ + i: 5, // 1 + s: "abc", // 3 * 1 + 16 = 19 + }, // 20 + 7 (padding) = 27 + }, // 40 + 2 + 27 = 69 + 6 (padding) = 75 + want: 75, + }, + { + name: "Pointer", + v: new(int64), // 8 + want: 8, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := SizeOf(tt.v); got != tt.want { + t.Errorf("Of() = %v, want %v", got, tt.want) + } + }) + } +}