Skip to content

Commit

Permalink
Fix #10.
Browse files Browse the repository at this point in the history
  • Loading branch information
ncruces committed Sep 11, 2023
1 parent 82a063f commit 7f548ce
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 16 deletions.
54 changes: 38 additions & 16 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func NewCachingResolver(parent *net.Resolver, options ...CacheOption) *net.Resol

// NewCachingDialer adds caching to a net.Resolver.Dial function.
func NewCachingDialer(parent DialFunc, options ...CacheOption) DialFunc {
var cache = cache{dial: parent}
var cache = cache{dial: parent, negative: true}
for _, o := range options {
o.apply(&cache)
}
Expand All @@ -47,10 +47,12 @@ type CacheOption interface {
type maxEntriesOption int
type maxTTLOption time.Duration
type minTTLOption time.Duration
type negativeCacheOption bool

func (o maxEntriesOption) apply(c *cache) { c.maxEntries = int(o) }
func (o maxTTLOption) apply(c *cache) { c.maxTTL = time.Duration(o) }
func (o minTTLOption) apply(c *cache) { c.minTTL = time.Duration(o) }
func (o maxEntriesOption) apply(c *cache) { c.maxEntries = int(o) }
func (o maxTTLOption) apply(c *cache) { c.maxTTL = time.Duration(o) }
func (o minTTLOption) apply(c *cache) { c.minTTL = time.Duration(o) }
func (o negativeCacheOption) apply(c *cache) { c.negative = bool(o) }

// MaxCacheEntries sets the maximum number of entries to cache.
// If zero, DefaultMaxCacheEntries is used; negative means no limit.
Expand All @@ -62,6 +64,9 @@ func MaxCacheTTL(d time.Duration) CacheOption { return maxTTLOption(d) }
// MinCacheTTL sets the minimum time-to-live for entries in the cache.
func MinCacheTTL(d time.Duration) CacheOption { return minTTLOption(d) }

// NegativeCache sets whether to cache negative responses.
func NegativeCache(b bool) CacheOption { return negativeCacheOption(b) }

type cache struct {
sync.RWMutex

Expand All @@ -71,6 +76,7 @@ type cache struct {
maxEntries int
maxTTL time.Duration
minTTL time.Duration
negative bool
}

type cacheEntry struct {
Expand All @@ -79,20 +85,13 @@ type cacheEntry struct {
}

func (c *cache) put(req string, res string) {
// ignore invalid/unmatched messages
if len(req) < 12 || len(res) < 12 { // header size
return
}
if req[0] != res[0] || req[1] != res[1] { // IDs match
return
}
if req[2] >= 0x7f || res[2] < 0x7f { // query, response
return
}
if req[2]&0x7a != 0 || res[2]&0x7a != 0 { // standard query, not truncated
// ignore uncacheable/unparseable answers
if invalid(req, res) {
return
}
if res[3]&0xf != 0 && res[3]&0xf != 3 { // no error, or name error

// ignore errors (if requested)
if nameError(res) && !c.negative {
return
}

Expand Down Expand Up @@ -169,6 +168,29 @@ func (c *cache) get(req string) (res string) {
return ""
}

func invalid(req string, res string) bool {
if len(req) < 12 || len(res) < 12 { // header size
return true
}
if req[0] != res[0] || req[1] != res[1] { // IDs match
return true
}
if req[2] >= 0x7f || res[2] < 0x7f { // query, response
return true
}
if req[2]&0x7a != 0 || res[2]&0x7a != 0 { // standard query, not truncated
return true
}
if res[3]&0xf != 0 && res[3]&0xf != 3 { // no error, or name error
return true
}
return false
}

func nameError(res string) bool {
return res[3]&0xf == 3
}

func getTTL(msg string) time.Duration {
ttl := math.MaxInt32

Expand Down
44 changes: 44 additions & 0 deletions fuzz_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package dns

import (
"testing"

"golang.org/x/net/dns/dnsmessage"
)

func Fuzz_parsing(f *testing.F) {
f.Add("", "")
f.Add("00000000", "00000000")

f.Fuzz(func(t *testing.T, req string, res string) {
var parser dnsmessage.Parser

invalid := invalid(req, res)
hreq, ereq := parser.Start([]byte(req))
hres, eres := parser.Start([]byte(res))

if !invalid {
if ereq != nil || eres != nil { // header size
t.Fail()
}
if hreq.ID != hres.ID { // IDs match
t.Fail()
}
if hreq.Response || !hres.Response { // query, response
t.Fail()
}
if hreq.OpCode != 0 || hres.OpCode != 0 { // standard query
t.Fail()
}
if hreq.Truncated || hres.Truncated { // not truncated
t.Fail()
}
if hres.RCode != 0 && hres.RCode != dnsmessage.RCodeNameError { // no error, or name error
t.Fail()
}
if nameError(res) != (hres.RCode == dnsmessage.RCodeNameError) { // name error
t.Fail()
}
}
})
}
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
module github.com/ncruces/go-dns

go 1.18

require golang.org/x/net v0.15.0
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=

0 comments on commit 7f548ce

Please sign in to comment.