diff --git a/cache.go b/cache.go index 37f6c15..927a0cc 100644 --- a/cache.go +++ b/cache.go @@ -23,7 +23,7 @@ func NewCachingResolver(parent *net.Resolver, options ...CacheOption) *net.Resol // NewCachingDialer adds caching to a net.Resolver.Dial function. func NewCachingDialer(parent DialFunc, options ...CacheOption) DialFunc { - var cache = cache{dial: parent} + var cache = cache{dial: parent, negative: true} for _, o := range options { o.apply(&cache) } @@ -47,10 +47,12 @@ type CacheOption interface { type maxEntriesOption int type maxTTLOption time.Duration type minTTLOption time.Duration +type negativeCacheOption bool -func (o maxEntriesOption) apply(c *cache) { c.maxEntries = int(o) } -func (o maxTTLOption) apply(c *cache) { c.maxTTL = time.Duration(o) } -func (o minTTLOption) apply(c *cache) { c.minTTL = time.Duration(o) } +func (o maxEntriesOption) apply(c *cache) { c.maxEntries = int(o) } +func (o maxTTLOption) apply(c *cache) { c.maxTTL = time.Duration(o) } +func (o minTTLOption) apply(c *cache) { c.minTTL = time.Duration(o) } +func (o negativeCacheOption) apply(c *cache) { c.negative = bool(o) } // MaxCacheEntries sets the maximum number of entries to cache. // If zero, DefaultMaxCacheEntries is used; negative means no limit. @@ -62,6 +64,9 @@ func MaxCacheTTL(d time.Duration) CacheOption { return maxTTLOption(d) } // MinCacheTTL sets the minimum time-to-live for entries in the cache. func MinCacheTTL(d time.Duration) CacheOption { return minTTLOption(d) } +// NegativeCache sets whether to cache negative responses. +func NegativeCache(b bool) CacheOption { return negativeCacheOption(b) } + type cache struct { sync.RWMutex @@ -71,6 +76,7 @@ type cache struct { maxEntries int maxTTL time.Duration minTTL time.Duration + negative bool } type cacheEntry struct { @@ -79,20 +85,13 @@ type cacheEntry struct { } func (c *cache) put(req string, res string) { - // ignore invalid/unmatched messages - if len(req) < 12 || len(res) < 12 { // header size - return - } - if req[0] != res[0] || req[1] != res[1] { // IDs match - return - } - if req[2] >= 0x7f || res[2] < 0x7f { // query, response - return - } - if req[2]&0x7a != 0 || res[2]&0x7a != 0 { // standard query, not truncated + // ignore uncacheable/unparseable answers + if invalid(req, res) { return } - if res[3]&0xf != 0 && res[3]&0xf != 3 { // no error, or name error + + // ignore errors (if requested) + if nameError(res) && !c.negative { return } @@ -169,6 +168,29 @@ func (c *cache) get(req string) (res string) { return "" } +func invalid(req string, res string) bool { + if len(req) < 12 || len(res) < 12 { // header size + return true + } + if req[0] != res[0] || req[1] != res[1] { // IDs match + return true + } + if req[2] >= 0x7f || res[2] < 0x7f { // query, response + return true + } + if req[2]&0x7a != 0 || res[2]&0x7a != 0 { // standard query, not truncated + return true + } + if res[3]&0xf != 0 && res[3]&0xf != 3 { // no error, or name error + return true + } + return false +} + +func nameError(res string) bool { + return res[3]&0xf == 3 +} + func getTTL(msg string) time.Duration { ttl := math.MaxInt32 diff --git a/fuzz_test.go b/fuzz_test.go new file mode 100644 index 0000000..d7073e3 --- /dev/null +++ b/fuzz_test.go @@ -0,0 +1,44 @@ +package dns + +import ( + "testing" + + "golang.org/x/net/dns/dnsmessage" +) + +func Fuzz_parsing(f *testing.F) { + f.Add("", "") + f.Add("00000000", "00000000") + + f.Fuzz(func(t *testing.T, req string, res string) { + var parser dnsmessage.Parser + + invalid := invalid(req, res) + hreq, ereq := parser.Start([]byte(req)) + hres, eres := parser.Start([]byte(res)) + + if !invalid { + if ereq != nil || eres != nil { // header size + t.Fail() + } + if hreq.ID != hres.ID { // IDs match + t.Fail() + } + if hreq.Response || !hres.Response { // query, response + t.Fail() + } + if hreq.OpCode != 0 || hres.OpCode != 0 { // standard query + t.Fail() + } + if hreq.Truncated || hres.Truncated { // not truncated + t.Fail() + } + if hres.RCode != 0 && hres.RCode != dnsmessage.RCodeNameError { // no error, or name error + t.Fail() + } + if nameError(res) != (hres.RCode == dnsmessage.RCodeNameError) { // name error + t.Fail() + } + } + }) +} diff --git a/go.mod b/go.mod index 8ad7371..d180b7d 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/ncruces/go-dns go 1.18 + +require golang.org/x/net v0.15.0 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..9746d02 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=