diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index b45f167..79731af 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -8,7 +8,7 @@ jobs: name: Build strategy: matrix: - go-version: [1.23.0-rc.1] + go-version: [1.23, 1.24] platform: [ubuntu-latest, macos-latest, windows-latest] runs-on: ${{ matrix.platform }} steps: diff --git a/.gitignore b/.gitignore index fe3fa4a..b7cb412 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,10 @@ *.out *.sw? *.test + +# Claude Code session files +.claude/ +CLAUDE.md + +# Test databases that shouldn't be committed +*.mmdb diff --git a/.golangci.yml b/.golangci.yml index 86dab88..a2c4c80 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -21,6 +21,7 @@ linters: - interfacebloat - mnd - nlreturn + - noinlineerr - nonamedreturns - paralleltest - testpackage @@ -28,6 +29,7 @@ - varnamelen - wrapcheck - wsl + - wsl_v5 settings: errorlint: errorf: true @@ -60,6 +62,9 @@ linters: gosec: excludes: - G115 + # Potential file inclusion via variable - we only open files requested by + # the user of the API. + - G304 govet: disable: - shadow @@ -135,22 +140,13 @@ linters: unparam: check-exported: true exclusions: - generated: lax - presets: - - comments - - common-false-positives - - legacy - - std-error-handling + warn-unused: true rules: - linters: - govet - revive path: _test.go text: 'fieldalignment:' - paths: - - third_party$ - - builtin$ - - examples$ formatters: enable: - gci - prefix(github.com/oschwald/maxminddb-golang) gofumpt: extra-rules: true - exclusions: - generated: lax - paths: - - third_party$ - - builtin$ - - examples$ diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..ce23df1 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,209 @@ +# Changes + +## 2.0.0-beta.4 + +- **BREAKING CHANGE**: Removed experimental `deserializer` interface and + supporting code. Applications using this interface should migrate to the + `Unmarshaler` interface by implementing `UnmarshalMaxMindDB(d *Decoder) error` + instead. +- `Open` and `FromBytes` now accept options. +- **BREAKING CHANGE**: `IncludeNetworksWithoutData` and `IncludeAliasedNetworks` + now return a `NetworksOption` rather than being one themselves. These must now + be called as functions: `Networks(IncludeAliasedNetworks())` instead of + `Networks(IncludeAliasedNetworks)`. This was done to improve the documentation + organization. +- Added `Unmarshaler` interface to allow custom decoding implementations for + performance-critical applications. Types implementing + `UnmarshalMaxMindDB(d *Decoder) error` will automatically use custom decoding + logic instead of reflection, following the same pattern as + `json.Unmarshaler`. +- Added public `Decoder` type and `Kind` constants in `mmdbdata` package for + manual decoding. `Decoder` provides methods like `ReadMap()`, `ReadSlice()`, + `ReadString()`, `ReadUInt32()`, `PeekKind()`, etc. `Kind` type includes + helper methods `String()`, `IsContainer()`, and `IsScalar()` for type + introspection. The main `maxminddb` package re-exports these types for + backward compatibility. `NewDecoder()` supports an options pattern for + future extensibility. +- Enhanced `UnmarshalMaxMindDB` to work with nested struct fields, slice + elements, and map values. The custom unmarshaler is now called recursively + for any type that implements the `Unmarshaler` interface, similar to + `encoding/json`.
+- Improved error messages to include byte offset information and, for the
+  reflection-based API, path information for nested structures using JSON
+  Pointer format. For example, errors may now show "at offset 1234, path
+  /city/names/en" or "at offset 1234, path /list/0/name" instead of just the
+  underlying error message.
+- **PERFORMANCE**: Added string interning optimization that reduces allocations
+  while maintaining thread safety. Reduces allocation count from 33 to 10 per
+  operation in downstream libraries. Uses a fixed 512-entry cache with per-entry
+  mutexes for bounded memory usage (~8KB) while minimizing lock contention.
+
+## 2.0.0-beta.3 - 2025-02-16
+
+- `Open` will now fall back to loading the database in memory if the
+  filesystem does not support `mmap`. Pull request by database64128. GitHub
+  #163.
+- Made significant improvements to the Windows memory-map handling. GitHub
+  #162.
+- Fix an integer overflow on large databases when using a 32-bit architecture.
+  See ipinfo/mmdbctl#33.
+
+## 2.0.0-beta.2 - 2024-11-14
+
+- Allow negative indexes for arrays when using `DecodePath`. #152
+- Add `IncludeNetworksWithoutData` option for `Networks` and `NetworksWithin`.
+  #155 and #156
+
+## 2.0.0-beta.1 - 2024-08-18
+
+This is the first beta of the v2 releases. Go 1.23 is required. I don't expect
+to do a final release until Go 1.24 is available. See #141 for the v2 roadmap.
+
+Notable changes:
+
+- `(*Reader).Lookup` now takes only the IP address and returns a `Result`.
+  `Lookup(ip, &rec)` now becomes `Lookup(ip).Decode(&rec)`.
+- `(*Reader).LookupNetwork` has been removed. To get the network for a result,
+  use `(Result).Prefix()`.
+- `(*Reader).LookupOffset` now _takes_ an offset and returns a `Result`.
+  `Result` has an `Offset()` method that returns the offset value.
+  `(*Reader).Decode` has been removed.
+- Use of `net.IP` and `*net.IPNet` has been replaced with `netip.Addr` and
+  `netip.Prefix`.
+- You may now decode a particular path within a database record using
+  `(Result).DecodePath`. For instance, to decode just the country code in
+  GeoLite2 Country to a string called `code`, you might do something like
+  `Lookup(ip).DecodePath(&code, "country", "iso_code")`. Strings should be used
+  for map keys and ints for array indexes.
+- `(*Reader).Networks` and `(*Reader).NetworksWithin` now return a Go 1.23
+  iterator of `Result` values. Aliased networks are now skipped by default. If
+  you wish to include them, use the `IncludeAliasedNetworks` option.
+
+## 1.13.1 - 2024-06-28
+
+- Return the `*net.IPNet` in canonical form when using `NetworksWithin` to look
+  up a network more specific than the one in the database. Previously, the `IP`
+  field on the `*net.IPNet` would be set to the IP from the lookup network
+  rather than the first IP of the network.
+- `NetworksWithin` will now correctly handle a `*net.IPNet` parameter that is
+  not in canonical form. This issue would only occur if the `*net.IPNet` was
+  manually constructed, as `net.ParseCIDR` returns the value in canonical form
+  even if the input string is not.
+
+## 1.13.0 - 2024-06-03
+
+- Go 1.21 or greater is now required.
+- The error messages when decoding have been improved. #119
+
+## 1.12.0 - 2023-08-01
+
+- The `wasi` target is now built without memory-mapping support. Pull request
+  by Alex Kashintsev. GitHub #114.
+- When decoding to a map of non-scalar, non-interface types such as a
+  `map[string]map[string]any`, the decoder failed to zero out the value for the
+  map elements, which could result in incorrect decoding. Reported by JT Olio.
+  GitHub #115.
+
+## 1.11.0 - 2023-06-18
+
+- `wasm` and `wasip1` targets are now built without memory-mapping support.
+  Pull request by Randy Reddig. GitHub #110.
+
+**Full Changelog**:
+https://github.com/oschwald/maxminddb-golang/compare/v1.10.0...v1.11.0
+
+## 1.10.0 - 2022-08-07
+
+- Set Go version in go.mod file to 1.18.
+
+## 1.9.0 - 2022-03-26
+
+- Set the minimum Go version in the go.mod file to 1.17.
+- Updated dependencies.
+- Minor performance improvements to the custom deserializer feature added in
+  1.8.0.
+
+## 1.8.0 - 2020-11-23
+
+- Added `maxminddb.SkipAliasedNetworks` option to `Networks` and
+  `NetworksWithin` methods. When set, this option will cause the iterator to
+  skip networks that are aliases of the IPv4 tree.
+- Added experimental custom deserializer support. This allows much more control
+  over the deserialization. The API is subject to change; use it at your own
+  risk.
+
+## 1.7.0 - 2020-06-13
+
+- Add `NetworksWithin` method. This returns an iterator that traverses all
+  networks in the database that are contained in the given network. Pull
+  request by Olaf Alders. GitHub #65.
+
+## 1.6.0 - 2019-12-25
+
+- This module now uses Go modules. Requested by Matthew Rothenberg. GitHub #49.
+- Plan 9 is now supported. Pull request by Jacob Moody. GitHub #61.
+- Documentation fixes. Pull request by Olaf Alders. GitHub #62.
+- Thread-safety is now mentioned in the documentation. Requested by Ken
+  Sedgwick. GitHub #39.
+- Fix off-by-one error in file offset safety check. Reported by Will Storey.
+  GitHub #63.
+
+## 1.5.0 - 2019-09-11
+
+- Drop support for Go 1.7 and 1.8.
+- Minor performance improvements.
+
+## 1.4.0 - 2019-08-28
+
+- Add the method `LookupNetwork`. This returns the network that the record
+  belongs to as well as a boolean indicating whether there was a record for the
+  IP address in the database. GitHub #59.
+- Improve performance.
+
+## 1.3.1 - 2019-08-28
+
+- Fix issue with the finalizer running too early on Go 1.12 when using the
+  Verify method. Reported by Robert-André Mauchin. GitHub #55.
+- Remove unnecessary call to reflect.ValueOf. PR by SenseyeDeveloper. GitHub
+  #53.
+
+## 1.3.0 - 2018-02-25
+
+- The methods on the `maxminddb.Reader` struct now return an error if called on
+  a closed database reader. Previously, this could cause a segmentation
+  violation when using a memory-mapped file.
+- The `Close` method on the `maxminddb.Reader` struct now sets the underlying
+  buffer to nil, even when using `FromBytes` or `Open` on Google App Engine.
+- No longer uses constants from `syscall`.
+
+## 1.2.1 - 2018-01-03
+
+- Fix incorrect index being used when decoding into anonymous struct fields. PR
+  #42 by Andy Bursavich.
+
+## 1.2.0 - 2017-05-05
+
+- The database decoder now does bounds checking when decoding data from the
+  database. This is to help ensure that the reader does not panic when given a
+  corrupt database to decode. Closes #37.
+- The reader will now return an error on a data structure with a depth greater
+  than 512. This is done to prevent the possibility of a stack overflow on a
+  cyclic data structure in a corrupt database. This matches the maximum depth
+  allowed by `libmaxminddb`. All MaxMind databases currently have a depth of
+  less than five.
+
+## 1.1.0 - 2016-12-31
+
+- Added appengine build tag for Windows. When enabled, memory-mapping will be
+  disabled in the Windows build as it is for the non-Windows build. Pull
+  request #35 by Ingo Oeser.
+- SetFinalizer is now used to unmap files if the user fails to close the
+  reader. Using `r.Close()` is still recommended for most use cases.
+- Previously, an unsafe conversion between `[]byte` and string was used to
+  avoid unnecessary allocations when decoding struct keys. The decoder now
+  relies on a compiler optimization on `string([]byte)` map lookups to achieve
+  this rather than using `unsafe`.
+
+## 1.0.0 - 2016-11-09
+
+New release for those using tagged releases.
diff --git a/README.md b/README.md
index c99991b..83e3d26 100644
--- a/README.md
+++ b/README.md
@@ -1,36 +1,256 @@
-# MaxMind DB Reader for Go #
+# MaxMind DB Reader for Go [![Go Reference](https://pkg.go.dev/badge/github.com/oschwald/maxminddb-golang/v2.svg)](https://pkg.go.dev/github.com/oschwald/maxminddb-golang/v2)
 
 This is a Go reader for the MaxMind DB format. Although this can be used to
-read [GeoLite2](http://dev.maxmind.com/geoip/geoip2/geolite2/) and
-[GeoIP2](https://www.maxmind.com/en/geoip2-databases) databases,
-[geoip2](https://github.com/oschwald/geoip2-golang) provides a higher-level
-API for doing so.
+read [GeoLite2](https://dev.maxmind.com/geoip/geolite2-free-geolocation-data)
+and [GeoIP2](https://www.maxmind.com/en/geoip2-databases) databases,
+[geoip2](https://github.com/oschwald/geoip2-golang) provides a higher-level API
+for doing so.
 
 This is not an official MaxMind API.
 
-## Installation ##
+## Installation
 
-```
+```bash
 go get github.com/oschwald/maxminddb-golang/v2
 ```
 
-## Usage ##
+## Version 2.0 Features
+
+Version 2.0 includes significant improvements:
+
+- **Modern API**: Uses `netip.Addr` instead of `net.IP` for better performance
+- **Custom Unmarshaling**: Implement the `Unmarshaler` interface for
+  zero-allocation decoding
+- **Network Iteration**: Iterate over all networks in a database with
+  `Networks()` and `NetworksWithin()`
+- **Enhanced Performance**: Optimized data structures and decoding paths
+- **Go 1.23+ Support**: Takes advantage of modern Go features including
+  iterators
+- **Better Error Handling**: More detailed error types and improved debugging
+
+## Quick Start
+
+```go
+package main
+
+import (
+	"fmt"
+	"log"
+	"net/netip"
+
+	"github.com/oschwald/maxminddb-golang/v2"
+)
+
+func main() {
+	db, err := maxminddb.Open("GeoLite2-City.mmdb")
+	if err != nil {
+		log.Fatal(err)
+	}
+	defer db.Close()
+
+	ip, err := netip.ParseAddr("81.2.69.142")
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	var record struct {
+		Country struct {
+			ISOCode string            `maxminddb:"iso_code"`
+			Names   map[string]string `maxminddb:"names"`
+		} `maxminddb:"country"`
+		City struct {
+			Names map[string]string `maxminddb:"names"`
+		} `maxminddb:"city"`
+	}
+
+	err = db.Lookup(ip).Decode(&record)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	fmt.Printf("Country: %s (%s)\n", record.Country.Names["en"], record.Country.ISOCode)
+	fmt.Printf("City: %s\n", record.City.Names["en"])
+}
+```
+
+## Usage Patterns
+
+### Basic Lookup
+
+```go
+db, err := maxminddb.Open("GeoLite2-City.mmdb")
+if err != nil {
+	log.Fatal(err)
+}
+defer db.Close()
+
+var record any
+ip := netip.MustParseAddr("1.2.3.4")
+err = db.Lookup(ip).Decode(&record)
+```
+
+### Custom Struct Decoding
+
+```go
+type City struct {
+	Country struct {
+		ISOCode string `maxminddb:"iso_code"`
+		Names   struct {
+			English string `maxminddb:"en"`
+			German  string `maxminddb:"de"`
+		} `maxminddb:"names"`
+	} `maxminddb:"country"`
+}
+
+var city City
+err = db.Lookup(ip).Decode(&city)
+```
+
+### High-Performance Custom Unmarshaling
+
+```go
+type FastCity struct {
+	CountryISO string
+	CityName   string
+}
+
+func (c *FastCity) UnmarshalMaxMindDB(d *maxminddb.Decoder) error {
+	mapIter, _, err := d.ReadMap()
+	if err != nil {
+		return err
+	}
+	for key, err := range mapIter {
+		if err != nil {
+			return err
+		}
+		switch string(key) {
+		case "country":
+			countryIter, _, err := d.ReadMap()
+			if err != nil {
+				return err
+			}
+			for countryKey, countryErr := range countryIter {
+				if countryErr != nil {
+					return countryErr
+				}
+				if string(countryKey) == "iso_code" {
+					c.CountryISO, err = d.ReadString()
+					if err != nil {
+						return err
+					}
+				} else if err := d.SkipValue(); err != nil {
+					return err
+				}
+			}
+		case "city":
+			cityIter, _, err := d.ReadMap()
+			if err != nil {
+				return err
+			}
+			for cityKey, cityErr := range cityIter {
+				if cityErr != nil {
+					return cityErr
+				}
+				if string(cityKey) == "names" {
+					nameIter, _, err := d.ReadMap()
+					if err != nil {
+						return err
+					}
+					for nameKey, nameErr := range nameIter {
+						if nameErr != nil {
+							return nameErr
+						}
+						name, err := d.ReadString()
+						if err != nil {
+							return err
+						}
+						if string(nameKey) == "en" {
+							c.CityName = name
+						}
+					}
+				} else if err := d.SkipValue(); err != nil {
+					return err
+				}
+			}
+		default:
+			// Skip unknown fields for forward compatibility
+			if err := d.SkipValue(); err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+```
+
+### Network Iteration
+
+```go
+// Iterate over all networks in the database
+for result := range db.Networks() {
+	var record struct {
+		Country struct {
+			ISOCode string `maxminddb:"iso_code"`
+		} `maxminddb:"country"`
+	}
+	err := result.Decode(&record)
+	if err != nil {
+		log.Fatal(err)
+	}
+	fmt.Printf("%s: %s\n", result.Prefix(), record.Country.ISOCode)
+}
+
+// Iterate over networks within a specific prefix
+prefix := netip.MustParsePrefix("192.168.0.0/16")
+for result := range db.NetworksWithin(prefix) {
+	fmt.Println(result.Prefix()) // process each network within 192.168.0.0/16
+}
+```
+
+### Path-Based Decoding
+
+```go
+var countryCode string
+err = db.Lookup(ip).DecodePath(&countryCode, "country", "iso_code")
+
+var cityName string
+err = db.Lookup(ip).DecodePath(&cityName, "city", "names", "en")
+```
+
+## Supported Database Types
+
+This library supports **all MaxMind DB (.mmdb) format databases**, including:
+
+**MaxMind Official Databases:**
+
+- **GeoLite/GeoIP City**: Comprehensive location data including city, country,
+  subdivisions
+- **GeoLite/GeoIP Country**: Country-level geolocation data
+- **GeoLite ASN**: Autonomous System Number and organization data
+- **GeoIP Anonymous IP**: Anonymous network and proxy detection
+- **GeoIP Enterprise**: Enhanced City data with additional business fields
+- **GeoIP ISP**: Internet service provider information
+- **GeoIP Domain**: Second-level domain data
+- **GeoIP Connection Type**: Connection type identification
+
+**Third-Party Databases:**
+
+- **DB-IP databases**: Compatible with DB-IP's .mmdb format databases
+- **IPinfo databases**: Works with IPinfo's MaxMind DB format files
+- **Custom databases**: Any database following the MaxMind DB file format
+  specification
+
+The library is format-agnostic and will work with any valid .mmdb file
+regardless of the data provider.
+
+## Performance Tips
+
+1. **Reuse Reader instances**: The `Reader` is thread-safe and should be reused
+   across goroutines
+2. **Use specific structs**: Only decode the fields you need rather than using
+   `any`
+3. **Implement Unmarshaler**: For high-throughput applications, implement
+   custom unmarshaling
+4. **Consider caching**: Use `Result.Offset()` as a cache key for database
+   records, as shown in the sketch below
+
+## Getting Database Files
+
+### Free GeoLite2 Databases
+
+Download from
+[MaxMind's GeoLite page](https://dev.maxmind.com/geoip/geolite2-free-geolocation-data).
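+
+## Caching Decoded Records
+
+Many IP addresses in a database map to the same record, and `Result.Offset()`
+identifies that record's position in the data section. A minimal sketch of
+offset-keyed caching follows; the `lookupCity` helper and cache shape are
+illustrative rather than part of the library, it reuses the `City` struct from
+the decoding example above, and it assumes `Offset()` returns a `uintptr`:
+
+```go
+// lookupCity decodes the record for ip, reusing a previously decoded value
+// when another IP mapped to the same data-section record.
+func lookupCity(db *maxminddb.Reader, cache map[uintptr]*City, ip netip.Addr) (*City, error) {
+	result := db.Lookup(ip)
+	if city, ok := cache[result.Offset()]; ok {
+		return city, nil // cache hit: skip decoding entirely
+	}
+	city := new(City)
+	if err := result.Decode(city); err != nil {
+		return nil, err
+	}
+	cache[result.Offset()] = city
+	return city, nil
+}
+```
+
+A plain map is not safe for concurrent use; guard it with a mutex or use
+`sync.Map` if lookups happen across goroutines.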
+ +## Documentation -[See GoDoc](http://godoc.org/github.com/oschwald/maxminddb-golang) for -documentation and examples. +- [Go Reference](https://pkg.go.dev/github.com/oschwald/maxminddb-golang/v2) +- [MaxMind DB File Format Specification](https://maxmind.github.io/MaxMind-DB/) -## Examples ## +## Requirements -See [GoDoc](http://godoc.org/github.com/oschwald/maxminddb-golang) or -`example_test.go` for examples. +- Go 1.23 or later +- MaxMind DB file in .mmdb format -## Contributing ## +## Contributing -Contributions welcome! Please fork the repository and open a pull request -with your changes. +Contributions welcome! Please fork the repository and open a pull request with +your changes. -## License ## +## License This is free software, licensed under the ISC License. diff --git a/decoder.go b/decoder.go deleted file mode 100644 index 273f170..0000000 --- a/decoder.go +++ /dev/null @@ -1,982 +0,0 @@ -package maxminddb - -import ( - "encoding/binary" - "fmt" - "math" - "math/big" - "reflect" - "sync" -) - -type decoder struct { - buffer []byte -} - -type dataType int - -const ( - _Extended dataType = iota - _Pointer - _String - _Float64 - _Bytes - _Uint16 - _Uint32 - _Map - _Int32 - _Uint64 - _Uint128 - _Slice - // We don't use the next two. They are placeholders. See the spec - // for more details. - _Container //nolint:deadcode,varcheck // above - _Marker //nolint:deadcode,varcheck // above - _Bool - _Float32 -) - -const ( - // This is the value used in libmaxminddb. - maximumDataStructureDepth = 512 -) - -func (d *decoder) decode(offset uint, result reflect.Value, depth int) (uint, error) { - if depth > maximumDataStructureDepth { - return 0, newInvalidDatabaseError( - "exceeded maximum data structure depth; database is likely corrupt", - ) - } - typeNum, size, newOffset, err := d.decodeCtrlData(offset) - if err != nil { - return 0, err - } - - if typeNum != _Pointer && result.Kind() == reflect.Uintptr { - result.Set(reflect.ValueOf(uintptr(offset))) - return d.nextValueOffset(offset, 1) - } - return d.decodeFromType(typeNum, size, newOffset, result, depth+1) -} - -func (d *decoder) decodeToDeserializer( - offset uint, - dser deserializer, - depth int, - getNext bool, -) (uint, error) { - if depth > maximumDataStructureDepth { - return 0, newInvalidDatabaseError( - "exceeded maximum data structure depth; database is likely corrupt", - ) - } - skip, err := dser.ShouldSkip(uintptr(offset)) - if err != nil { - return 0, err - } - if skip { - if getNext { - return d.nextValueOffset(offset, 1) - } - return 0, nil - } - - typeNum, size, newOffset, err := d.decodeCtrlData(offset) - if err != nil { - return 0, err - } - - return d.decodeFromTypeToDeserializer(typeNum, size, newOffset, dser, depth+1) -} - -func (d *decoder) decodePath( - offset uint, - path []any, - result reflect.Value, -) error { -PATH: - for i, v := range path { - var ( - typeNum dataType - size uint - err error - ) - typeNum, size, offset, err = d.decodeCtrlData(offset) - if err != nil { - return err - } - - if typeNum == _Pointer { - pointer, _, err := d.decodePointer(size, offset) - if err != nil { - return err - } - - typeNum, size, offset, err = d.decodeCtrlData(pointer) - if err != nil { - return err - } - } - - switch v := v.(type) { - case string: - // We are expecting a map - if typeNum != _Map { - // XXX - use type names in errors. 
- return fmt.Errorf("expected a map for %s but found %d", v, typeNum) - } - for range size { - var key []byte - key, offset, err = d.decodeKey(offset) - if err != nil { - return err - } - if string(key) == v { - continue PATH - } - offset, err = d.nextValueOffset(offset, 1) - if err != nil { - return err - } - } - // Not found. Maybe return a boolean? - return nil - case int: - // We are expecting an array - if typeNum != _Slice { - // XXX - use type names in errors. - return fmt.Errorf("expected a slice for %d but found %d", v, typeNum) - } - var i uint - if v < 0 { - if size < uint(-v) { - // Slice is smaller than negative index, not found - return nil - } - i = size - uint(-v) - } else { - if size <= uint(v) { - // Slice is smaller than index, not found - return nil - } - i = uint(v) - } - offset, err = d.nextValueOffset(offset, i) - if err != nil { - return err - } - default: - return fmt.Errorf("unexpected type for %d value in path, %v: %T", i, v, v) - } - } - _, err := d.decode(offset, result, len(path)) - return err -} - -func (d *decoder) decodeCtrlData(offset uint) (dataType, uint, uint, error) { - newOffset := offset + 1 - if offset >= uint(len(d.buffer)) { - return 0, 0, 0, newOffsetError() - } - ctrlByte := d.buffer[offset] - - typeNum := dataType(ctrlByte >> 5) - if typeNum == _Extended { - if newOffset >= uint(len(d.buffer)) { - return 0, 0, 0, newOffsetError() - } - typeNum = dataType(d.buffer[newOffset] + 7) - newOffset++ - } - - var size uint - size, newOffset, err := d.sizeFromCtrlByte(ctrlByte, newOffset, typeNum) - return typeNum, size, newOffset, err -} - -func (d *decoder) sizeFromCtrlByte( - ctrlByte byte, - offset uint, - typeNum dataType, -) (uint, uint, error) { - size := uint(ctrlByte & 0x1f) - if typeNum == _Extended { - return size, offset, nil - } - - var bytesToRead uint - if size < 29 { - return size, offset, nil - } - - bytesToRead = size - 28 - newOffset := offset + bytesToRead - if newOffset > uint(len(d.buffer)) { - return 0, 0, newOffsetError() - } - if size == 29 { - return 29 + uint(d.buffer[offset]), offset + 1, nil - } - - sizeBytes := d.buffer[offset:newOffset] - - switch { - case size == 30: - size = 285 + uintFromBytes(0, sizeBytes) - case size > 30: - size = uintFromBytes(0, sizeBytes) + 65821 - } - return size, newOffset, nil -} - -func (d *decoder) decodeFromType( - dtype dataType, - size uint, - offset uint, - result reflect.Value, - depth int, -) (uint, error) { - result = indirect(result) - - // For these types, size has a special meaning - switch dtype { - case _Bool: - return unmarshalBool(size, offset, result) - case _Map: - return d.unmarshalMap(size, offset, result, depth) - case _Pointer: - return d.unmarshalPointer(size, offset, result, depth) - case _Slice: - return d.unmarshalSlice(size, offset, result, depth) - } - - // For the remaining types, size is the byte size - if offset+size > uint(len(d.buffer)) { - return 0, newOffsetError() - } - switch dtype { - case _Bytes: - return d.unmarshalBytes(size, offset, result) - case _Float32: - return d.unmarshalFloat32(size, offset, result) - case _Float64: - return d.unmarshalFloat64(size, offset, result) - case _Int32: - return d.unmarshalInt32(size, offset, result) - case _String: - return d.unmarshalString(size, offset, result) - case _Uint16: - return d.unmarshalUint(size, offset, result, 16) - case _Uint32: - return d.unmarshalUint(size, offset, result, 32) - case _Uint64: - return d.unmarshalUint(size, offset, result, 64) - case _Uint128: - return d.unmarshalUint128(size, offset, 
result) - default: - return 0, newInvalidDatabaseError("unknown type: %d", dtype) - } -} - -func (d *decoder) decodeFromTypeToDeserializer( - dtype dataType, - size uint, - offset uint, - dser deserializer, - depth int, -) (uint, error) { - // For these types, size has a special meaning - switch dtype { - case _Bool: - v, offset := decodeBool(size, offset) - return offset, dser.Bool(v) - case _Map: - return d.decodeMapToDeserializer(size, offset, dser, depth) - case _Pointer: - pointer, newOffset, err := d.decodePointer(size, offset) - if err != nil { - return 0, err - } - _, err = d.decodeToDeserializer(pointer, dser, depth, false) - return newOffset, err - case _Slice: - return d.decodeSliceToDeserializer(size, offset, dser, depth) - } - - // For the remaining types, size is the byte size - if offset+size > uint(len(d.buffer)) { - return 0, newOffsetError() - } - switch dtype { - case _Bytes: - v, offset := d.decodeBytes(size, offset) - return offset, dser.Bytes(v) - case _Float32: - v, offset := d.decodeFloat32(size, offset) - return offset, dser.Float32(v) - case _Float64: - v, offset := d.decodeFloat64(size, offset) - return offset, dser.Float64(v) - case _Int32: - v, offset := d.decodeInt(size, offset) - return offset, dser.Int32(int32(v)) - case _String: - v, offset := d.decodeString(size, offset) - return offset, dser.String(v) - case _Uint16: - v, offset := d.decodeUint(size, offset) - return offset, dser.Uint16(uint16(v)) - case _Uint32: - v, offset := d.decodeUint(size, offset) - return offset, dser.Uint32(uint32(v)) - case _Uint64: - v, offset := d.decodeUint(size, offset) - return offset, dser.Uint64(v) - case _Uint128: - v, offset := d.decodeUint128(size, offset) - return offset, dser.Uint128(v) - default: - return 0, newInvalidDatabaseError("unknown type: %d", dtype) - } -} - -func unmarshalBool(size, offset uint, result reflect.Value) (uint, error) { - if size > 1 { - return 0, newInvalidDatabaseError( - "the MaxMind DB file's data section contains bad data (bool size of %v)", - size, - ) - } - value, newOffset := decodeBool(size, offset) - - switch result.Kind() { - case reflect.Bool: - result.SetBool(value) - return newOffset, nil - case reflect.Interface: - if result.NumMethod() == 0 { - result.Set(reflect.ValueOf(value)) - return newOffset, nil - } - } - return newOffset, newUnmarshalTypeError(value, result.Type()) -} - -// indirect follows pointers and create values as necessary. This is -// heavily based on encoding/json as my original version had a subtle -// bug. This method should be considered to be licensed under -// https://golang.org/LICENSE -func indirect(result reflect.Value) reflect.Value { - for { - // Load value from interface, but only if the result will be - // usefully addressable. 
- if result.Kind() == reflect.Interface && !result.IsNil() { - e := result.Elem() - if e.Kind() == reflect.Ptr && !e.IsNil() { - result = e - continue - } - } - - if result.Kind() != reflect.Ptr { - break - } - - if result.IsNil() { - result.Set(reflect.New(result.Type().Elem())) - } - - result = result.Elem() - } - return result -} - -var sliceType = reflect.TypeOf([]byte{}) - -func (d *decoder) unmarshalBytes(size, offset uint, result reflect.Value) (uint, error) { - value, newOffset := d.decodeBytes(size, offset) - - switch result.Kind() { - case reflect.Slice: - if result.Type() == sliceType { - result.SetBytes(value) - return newOffset, nil - } - case reflect.Interface: - if result.NumMethod() == 0 { - result.Set(reflect.ValueOf(value)) - return newOffset, nil - } - } - return newOffset, newUnmarshalTypeError(value, result.Type()) -} - -func (d *decoder) unmarshalFloat32(size, offset uint, result reflect.Value) (uint, error) { - if size != 4 { - return 0, newInvalidDatabaseError( - "the MaxMind DB file's data section contains bad data (float32 size of %v)", - size, - ) - } - value, newOffset := d.decodeFloat32(size, offset) - - switch result.Kind() { - case reflect.Float32, reflect.Float64: - result.SetFloat(float64(value)) - return newOffset, nil - case reflect.Interface: - if result.NumMethod() == 0 { - result.Set(reflect.ValueOf(value)) - return newOffset, nil - } - } - return newOffset, newUnmarshalTypeError(value, result.Type()) -} - -func (d *decoder) unmarshalFloat64(size, offset uint, result reflect.Value) (uint, error) { - if size != 8 { - return 0, newInvalidDatabaseError( - "the MaxMind DB file's data section contains bad data (float 64 size of %v)", - size, - ) - } - value, newOffset := d.decodeFloat64(size, offset) - - switch result.Kind() { - case reflect.Float32, reflect.Float64: - if result.OverflowFloat(value) { - return 0, newUnmarshalTypeError(value, result.Type()) - } - result.SetFloat(value) - return newOffset, nil - case reflect.Interface: - if result.NumMethod() == 0 { - result.Set(reflect.ValueOf(value)) - return newOffset, nil - } - } - return newOffset, newUnmarshalTypeError(value, result.Type()) -} - -func (d *decoder) unmarshalInt32(size, offset uint, result reflect.Value) (uint, error) { - if size > 4 { - return 0, newInvalidDatabaseError( - "the MaxMind DB file's data section contains bad data (int32 size of %v)", - size, - ) - } - value, newOffset := d.decodeInt(size, offset) - - switch result.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - n := int64(value) - if !result.OverflowInt(n) { - result.SetInt(n) - return newOffset, nil - } - case reflect.Uint, - reflect.Uint8, - reflect.Uint16, - reflect.Uint32, - reflect.Uint64, - reflect.Uintptr: - n := uint64(value) - if !result.OverflowUint(n) { - result.SetUint(n) - return newOffset, nil - } - case reflect.Interface: - if result.NumMethod() == 0 { - result.Set(reflect.ValueOf(value)) - return newOffset, nil - } - } - return newOffset, newUnmarshalTypeError(value, result.Type()) -} - -func (d *decoder) unmarshalMap( - size uint, - offset uint, - result reflect.Value, - depth int, -) (uint, error) { - result = indirect(result) - switch result.Kind() { - default: - return 0, newUnmarshalTypeStrError("map", result.Type()) - case reflect.Struct: - return d.decodeStruct(size, offset, result, depth) - case reflect.Map: - return d.decodeMap(size, offset, result, depth) - case reflect.Interface: - if result.NumMethod() == 0 { - rv := reflect.ValueOf(make(map[string]any, 
size)) - newOffset, err := d.decodeMap(size, offset, rv, depth) - result.Set(rv) - return newOffset, err - } - return 0, newUnmarshalTypeStrError("map", result.Type()) - } -} - -func (d *decoder) unmarshalPointer( - size, offset uint, - result reflect.Value, - depth int, -) (uint, error) { - pointer, newOffset, err := d.decodePointer(size, offset) - if err != nil { - return 0, err - } - _, err = d.decode(pointer, result, depth) - return newOffset, err -} - -func (d *decoder) unmarshalSlice( - size uint, - offset uint, - result reflect.Value, - depth int, -) (uint, error) { - switch result.Kind() { - case reflect.Slice: - return d.decodeSlice(size, offset, result, depth) - case reflect.Interface: - if result.NumMethod() == 0 { - a := []any{} - rv := reflect.ValueOf(&a).Elem() - newOffset, err := d.decodeSlice(size, offset, rv, depth) - result.Set(rv) - return newOffset, err - } - } - return 0, newUnmarshalTypeStrError("array", result.Type()) -} - -func (d *decoder) unmarshalString(size, offset uint, result reflect.Value) (uint, error) { - value, newOffset := d.decodeString(size, offset) - - switch result.Kind() { - case reflect.String: - result.SetString(value) - return newOffset, nil - case reflect.Interface: - if result.NumMethod() == 0 { - result.Set(reflect.ValueOf(value)) - return newOffset, nil - } - } - return newOffset, newUnmarshalTypeError(value, result.Type()) -} - -func (d *decoder) unmarshalUint( - size, offset uint, - result reflect.Value, - uintType uint, -) (uint, error) { - if size > uintType/8 { - return 0, newInvalidDatabaseError( - "the MaxMind DB file's data section contains bad data (uint%v size of %v)", - uintType, - size, - ) - } - - value, newOffset := d.decodeUint(size, offset) - - switch result.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - n := int64(value) - if !result.OverflowInt(n) { - result.SetInt(n) - return newOffset, nil - } - case reflect.Uint, - reflect.Uint8, - reflect.Uint16, - reflect.Uint32, - reflect.Uint64, - reflect.Uintptr: - if !result.OverflowUint(value) { - result.SetUint(value) - return newOffset, nil - } - case reflect.Interface: - if result.NumMethod() == 0 { - result.Set(reflect.ValueOf(value)) - return newOffset, nil - } - } - return newOffset, newUnmarshalTypeError(value, result.Type()) -} - -var bigIntType = reflect.TypeOf(big.Int{}) - -func (d *decoder) unmarshalUint128(size, offset uint, result reflect.Value) (uint, error) { - if size > 16 { - return 0, newInvalidDatabaseError( - "the MaxMind DB file's data section contains bad data (uint128 size of %v)", - size, - ) - } - value, newOffset := d.decodeUint128(size, offset) - - switch result.Kind() { - case reflect.Struct: - if result.Type() == bigIntType { - result.Set(reflect.ValueOf(*value)) - return newOffset, nil - } - case reflect.Interface: - if result.NumMethod() == 0 { - result.Set(reflect.ValueOf(value)) - return newOffset, nil - } - } - return newOffset, newUnmarshalTypeError(value, result.Type()) -} - -func decodeBool(size, offset uint) (bool, uint) { - return size != 0, offset -} - -func (d *decoder) decodeBytes(size, offset uint) ([]byte, uint) { - newOffset := offset + size - bytes := make([]byte, size) - copy(bytes, d.buffer[offset:newOffset]) - return bytes, newOffset -} - -func (d *decoder) decodeFloat64(size, offset uint) (float64, uint) { - newOffset := offset + size - bits := binary.BigEndian.Uint64(d.buffer[offset:newOffset]) - return math.Float64frombits(bits), newOffset -} - -func (d *decoder) decodeFloat32(size, offset 
uint) (float32, uint) { - newOffset := offset + size - bits := binary.BigEndian.Uint32(d.buffer[offset:newOffset]) - return math.Float32frombits(bits), newOffset -} - -func (d *decoder) decodeInt(size, offset uint) (int, uint) { - newOffset := offset + size - var val int32 - for _, b := range d.buffer[offset:newOffset] { - val = (val << 8) | int32(b) - } - return int(val), newOffset -} - -func (d *decoder) decodeMap( - size uint, - offset uint, - result reflect.Value, - depth int, -) (uint, error) { - if result.IsNil() { - result.Set(reflect.MakeMapWithSize(result.Type(), int(size))) - } - - mapType := result.Type() - keyValue := reflect.New(mapType.Key()).Elem() - elemType := mapType.Elem() - var elemValue reflect.Value - for range size { - var key []byte - var err error - key, offset, err = d.decodeKey(offset) - if err != nil { - return 0, err - } - - if elemValue.IsValid() { - elemValue.SetZero() - } else { - elemValue = reflect.New(elemType).Elem() - } - - offset, err = d.decode(offset, elemValue, depth) - if err != nil { - return 0, fmt.Errorf("decoding value for %s: %w", key, err) - } - - keyValue.SetString(string(key)) - result.SetMapIndex(keyValue, elemValue) - } - return offset, nil -} - -func (d *decoder) decodeMapToDeserializer( - size uint, - offset uint, - dser deserializer, - depth int, -) (uint, error) { - err := dser.StartMap(size) - if err != nil { - return 0, err - } - for range size { - // TODO - implement key/value skipping? - offset, err = d.decodeToDeserializer(offset, dser, depth, true) - if err != nil { - return 0, err - } - - offset, err = d.decodeToDeserializer(offset, dser, depth, true) - if err != nil { - return 0, err - } - } - err = dser.End() - if err != nil { - return 0, err - } - return offset, nil -} - -func (d *decoder) decodePointer( - size uint, - offset uint, -) (uint, uint, error) { - pointerSize := ((size >> 3) & 0x3) + 1 - newOffset := offset + pointerSize - if newOffset > uint(len(d.buffer)) { - return 0, 0, newOffsetError() - } - pointerBytes := d.buffer[offset:newOffset] - var prefix uint - if pointerSize == 4 { - prefix = 0 - } else { - prefix = size & 0x7 - } - unpacked := uintFromBytes(prefix, pointerBytes) - - var pointerValueOffset uint - switch pointerSize { - case 1: - pointerValueOffset = 0 - case 2: - pointerValueOffset = 2048 - case 3: - pointerValueOffset = 526336 - case 4: - pointerValueOffset = 0 - } - - pointer := unpacked + pointerValueOffset - - return pointer, newOffset, nil -} - -func (d *decoder) decodeSlice( - size uint, - offset uint, - result reflect.Value, - depth int, -) (uint, error) { - result.Set(reflect.MakeSlice(result.Type(), int(size), int(size))) - for i := range size { - var err error - offset, err = d.decode(offset, result.Index(int(i)), depth) - if err != nil { - return 0, err - } - } - return offset, nil -} - -func (d *decoder) decodeSliceToDeserializer( - size uint, - offset uint, - dser deserializer, - depth int, -) (uint, error) { - err := dser.StartSlice(size) - if err != nil { - return 0, err - } - for range size { - offset, err = d.decodeToDeserializer(offset, dser, depth, true) - if err != nil { - return 0, err - } - } - err = dser.End() - if err != nil { - return 0, err - } - return offset, nil -} - -func (d *decoder) decodeString(size, offset uint) (string, uint) { - newOffset := offset + size - return string(d.buffer[offset:newOffset]), newOffset -} - -func (d *decoder) decodeStruct( - size uint, - offset uint, - result reflect.Value, - depth int, -) (uint, error) { - fields := cachedFields(result) - 
- // This fills in embedded structs - for _, i := range fields.anonymousFields { - _, err := d.unmarshalMap(size, offset, result.Field(i), depth) - if err != nil { - return 0, err - } - } - - // This handles named fields - for range size { - var ( - err error - key []byte - ) - key, offset, err = d.decodeKey(offset) - if err != nil { - return 0, err - } - // The string() does not create a copy due to this compiler - // optimization: https://github.com/golang/go/issues/3512 - j, ok := fields.namedFields[string(key)] - if !ok { - offset, err = d.nextValueOffset(offset, 1) - if err != nil { - return 0, err - } - continue - } - - offset, err = d.decode(offset, result.Field(j), depth) - if err != nil { - return 0, fmt.Errorf("decoding value for %s: %w", key, err) - } - } - return offset, nil -} - -type fieldsType struct { - namedFields map[string]int - anonymousFields []int -} - -var fieldsMap sync.Map - -func cachedFields(result reflect.Value) *fieldsType { - resultType := result.Type() - - if fields, ok := fieldsMap.Load(resultType); ok { - return fields.(*fieldsType) - } - numFields := resultType.NumField() - namedFields := make(map[string]int, numFields) - var anonymous []int - for i := range numFields { - field := resultType.Field(i) - - fieldName := field.Name - if tag := field.Tag.Get("maxminddb"); tag != "" { - if tag == "-" { - continue - } - fieldName = tag - } - if field.Anonymous { - anonymous = append(anonymous, i) - continue - } - namedFields[fieldName] = i - } - fields := &fieldsType{namedFields, anonymous} - fieldsMap.Store(resultType, fields) - - return fields -} - -func (d *decoder) decodeUint(size, offset uint) (uint64, uint) { - newOffset := offset + size - bytes := d.buffer[offset:newOffset] - - var val uint64 - for _, b := range bytes { - val = (val << 8) | uint64(b) - } - return val, newOffset -} - -func (d *decoder) decodeUint128(size, offset uint) (*big.Int, uint) { - newOffset := offset + size - val := new(big.Int) - val.SetBytes(d.buffer[offset:newOffset]) - - return val, newOffset -} - -func uintFromBytes(prefix uint, uintBytes []byte) uint { - val := prefix - for _, b := range uintBytes { - val = (val << 8) | uint(b) - } - return val -} - -// decodeKey decodes a map key into []byte slice. We use a []byte so that we -// can take advantage of https://github.com/golang/go/issues/3512 to avoid -// copying the bytes when decoding a struct. Previously, we achieved this by -// using unsafe. -func (d *decoder) decodeKey(offset uint) ([]byte, uint, error) { - typeNum, size, dataOffset, err := d.decodeCtrlData(offset) - if err != nil { - return nil, 0, err - } - if typeNum == _Pointer { - pointer, ptrOffset, err := d.decodePointer(size, dataOffset) - if err != nil { - return nil, 0, err - } - key, _, err := d.decodeKey(pointer) - return key, ptrOffset, err - } - if typeNum != _String { - return nil, 0, newInvalidDatabaseError("unexpected type when decoding string: %v", typeNum) - } - newOffset := dataOffset + size - if newOffset > uint(len(d.buffer)) { - return nil, 0, newOffsetError() - } - return d.buffer[dataOffset:newOffset], newOffset, nil -} - -// This function is used to skip ahead to the next value without decoding -// the one at the offset passed in. The size bits have different meanings for -// different data types. 
-func (d *decoder) nextValueOffset(offset, numberToSkip uint) (uint, error) { - if numberToSkip == 0 { - return offset, nil - } - typeNum, size, offset, err := d.decodeCtrlData(offset) - if err != nil { - return 0, err - } - switch typeNum { - case _Pointer: - _, offset, err = d.decodePointer(size, offset) - if err != nil { - return 0, err - } - case _Map: - numberToSkip += 2 * size - case _Slice: - numberToSkip += size - case _Bool: - default: - offset += size - } - return d.nextValueOffset(offset, numberToSkip-1) -} diff --git a/deserializer.go b/deserializer.go deleted file mode 100644 index c6dd68d..0000000 --- a/deserializer.go +++ /dev/null @@ -1,31 +0,0 @@ -package maxminddb - -import "math/big" - -// deserializer is an interface for a type that deserializes an MaxMind DB -// data record to some other type. This exists as an alternative to the -// standard reflection API. -// -// This is fundamentally different than the Unmarshaler interface that -// several packages provide. A Deserializer will generally create the -// final struct or value rather than unmarshaling to itself. -// -// This interface and the associated unmarshaling code is EXPERIMENTAL! -// It is not currently covered by any Semantic Versioning guarantees. -// Use at your own risk. -type deserializer interface { - ShouldSkip(offset uintptr) (bool, error) - StartSlice(size uint) error - StartMap(size uint) error - End() error - String(string) error - Float64(float64) error - Bytes([]byte) error - Uint16(uint16) error - Uint32(uint32) error - Int32(int32) error - Uint64(uint64) error - Uint128(*big.Int) error - Bool(bool) error - Float32(float32) error -} diff --git a/deserializer_test.go b/deserializer_test.go deleted file mode 100644 index a6e3b70..0000000 --- a/deserializer_test.go +++ /dev/null @@ -1,120 +0,0 @@ -package maxminddb - -import ( - "math/big" - "net/netip" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestDecodingToDeserializer(t *testing.T) { - reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) - require.NoError(t, err, "unexpected error while opening database: %v", err) - - dser := testDeserializer{} - err = reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&dser) - require.NoError(t, err, "unexpected error while doing lookup: %v", err) - - checkDecodingToInterface(t, dser.rv) -} - -type stackValue struct { - value any - curNum int -} - -type testDeserializer struct { - stack []*stackValue - rv any - key *string -} - -func (*testDeserializer) ShouldSkip(_ uintptr) (bool, error) { - return false, nil -} - -func (d *testDeserializer) StartSlice(size uint) error { - return d.add(make([]any, size)) -} - -func (d *testDeserializer) StartMap(_ uint) error { - return d.add(map[string]any{}) -} - -//nolint:unparam // This is to meet the requirements of the interface. 
-func (d *testDeserializer) End() error { - d.stack = d.stack[:len(d.stack)-1] - return nil -} - -func (d *testDeserializer) String(v string) error { - return d.add(v) -} - -func (d *testDeserializer) Float64(v float64) error { - return d.add(v) -} - -func (d *testDeserializer) Bytes(v []byte) error { - return d.add(v) -} - -func (d *testDeserializer) Uint16(v uint16) error { - return d.add(uint64(v)) -} - -func (d *testDeserializer) Uint32(v uint32) error { - return d.add(uint64(v)) -} - -func (d *testDeserializer) Int32(v int32) error { - return d.add(int(v)) -} - -func (d *testDeserializer) Uint64(v uint64) error { - return d.add(v) -} - -func (d *testDeserializer) Uint128(v *big.Int) error { - return d.add(v) -} - -func (d *testDeserializer) Bool(v bool) error { - return d.add(v) -} - -func (d *testDeserializer) Float32(v float32) error { - return d.add(v) -} - -func (d *testDeserializer) add(v any) error { - if len(d.stack) == 0 { - d.rv = v - } else { - top := d.stack[len(d.stack)-1] - switch parent := top.value.(type) { - case map[string]any: - if d.key == nil { - key := v.(string) - d.key = &key - } else { - parent[*d.key] = v - d.key = nil - } - - case []any: - parent[top.curNum] = v - top.curNum++ - default: - } - } - - switch v := v.(type) { - case map[string]any, []any: - d.stack = append(d.stack, &stackValue{value: v}) - default: - } - - return nil -} diff --git a/errors.go b/errors.go index f141f61..bffa4ca 100644 --- a/errors.go +++ b/errors.go @@ -1,46 +1,13 @@ package maxminddb -import ( - "fmt" - "reflect" -) - -// InvalidDatabaseError is returned when the database contains invalid data -// and cannot be parsed. -type InvalidDatabaseError struct { - message string -} - -func newOffsetError() InvalidDatabaseError { - return InvalidDatabaseError{"unexpected end of database"} -} - -func newInvalidDatabaseError(format string, args ...any) InvalidDatabaseError { - return InvalidDatabaseError{fmt.Sprintf(format, args...)} -} +import "github.com/oschwald/maxminddb-golang/v2/internal/mmdberrors" -func (e InvalidDatabaseError) Error() string { - return e.message -} +type ( + // InvalidDatabaseError is returned when the database contains invalid data + // and cannot be parsed. + InvalidDatabaseError = mmdberrors.InvalidDatabaseError -// UnmarshalTypeError is returned when the value in the database cannot be -// assigned to the specified data type. -type UnmarshalTypeError struct { - Type reflect.Type - Value string -} - -func newUnmarshalTypeStrError(value string, rType reflect.Type) UnmarshalTypeError { - return UnmarshalTypeError{ - Type: rType, - Value: value, - } -} - -func newUnmarshalTypeError(value any, rType reflect.Type) UnmarshalTypeError { - return newUnmarshalTypeStrError(fmt.Sprintf("%v (%T)", value, value), rType) -} - -func (e UnmarshalTypeError) Error() string { - return fmt.Sprintf("maxminddb: cannot unmarshal %s into type %s", e.Value, e.Type) -} + // UnmarshalTypeError is returned when the value in the database cannot be + // assigned to the specified data type. + UnmarshalTypeError = mmdberrors.UnmarshalTypeError +) diff --git a/example_test.go b/example_test.go index 0f15c1e..9e6413e 100644 --- a/example_test.go +++ b/example_test.go @@ -6,6 +6,7 @@ import ( "net/netip" "github.com/oschwald/maxminddb-golang/v2" + "github.com/oschwald/maxminddb-golang/v2/mmdbdata" ) // This example shows how to decode to a struct. 
@@ -14,7 +15,7 @@ func ExampleReader_Lookup_struct() { if err != nil { log.Fatal(err) } - defer db.Close() + defer db.Close() //nolint:errcheck // error doesn't matter addr := netip.MustParseAddr("81.2.69.142") @@ -39,7 +40,7 @@ func ExampleReader_Lookup_interface() { if err != nil { log.Fatal(err) } - defer db.Close() + defer db.Close() //nolint:errcheck // error doesn't matter addr := netip.MustParseAddr("81.2.69.142") @@ -61,18 +62,18 @@ func ExampleReader_Networks() { if err != nil { log.Fatal(err) } - defer db.Close() + defer db.Close() //nolint:errcheck // error doesn't matter for result := range db.Networks() { record := struct { - Domain string `maxminddb:"connection_type"` + ConnectionType string `maxminddb:"connection_type"` }{} err := result.Decode(&record) if err != nil { log.Panic(err) } - fmt.Printf("%s: %s\n", result.Prefix(), record.Domain) + fmt.Printf("%s: %s\n", result.Prefix(), record.ConnectionType) } // Output: // 1.0.0.0/24: Cable/DSL @@ -101,6 +102,39 @@ func ExampleReader_Networks() { // 2003::/24: Cable/DSL } +// This example demonstrates how to validate a MaxMind DB file and access metadata. +func ExampleReader_Verify() { + db, err := maxminddb.Open("test-data/test-data/GeoIP2-City-Test.mmdb") + if err != nil { + log.Fatal(err) + } + defer db.Close() //nolint:errcheck // error doesn't matter + + // Verify database integrity + if err := db.Verify(); err != nil { + log.Printf("Database validation failed: %v", err) + return + } + + // Access metadata information + metadata := db.Metadata + fmt.Printf("Database type: %s\n", metadata.DatabaseType) + fmt.Printf("Build time: %s\n", metadata.BuildTime().UTC().Format("2006-01-02 15:04:05")) + fmt.Printf("IP version: IPv%d\n", metadata.IPVersion) + fmt.Printf("Languages: %v\n", metadata.Languages) + + if desc, ok := metadata.Description["en"]; ok { + fmt.Printf("Description: %s\n", desc) + } + + // Output: + // Database type: GeoIP2-City + // Build time: 2022-07-26 14:53:10 + // IP version: IPv6 + // Languages: [en zh] + // Description: GeoIP2 City Test Database (fake GeoIP2 data, for example purposes only) +} + // This example demonstrates how to iterate over all networks in the // database which are contained within an arbitrary network. func ExampleReader_NetworksWithin() { @@ -108,7 +142,7 @@ func ExampleReader_NetworksWithin() { if err != nil { log.Fatal(err) } - defer db.Close() + defer db.Close() //nolint:errcheck // error doesn't matter prefix, err := netip.ParsePrefix("1.0.0.0/8") if err != nil { @@ -117,13 +151,13 @@ func ExampleReader_NetworksWithin() { for result := range db.NetworksWithin(prefix) { record := struct { - Domain string `maxminddb:"connection_type"` + ConnectionType string `maxminddb:"connection_type"` }{} err := result.Decode(&record) if err != nil { log.Panic(err) } - fmt.Printf("%s: %s\n", result.Prefix(), record.Domain) + fmt.Printf("%s: %s\n", result.Prefix(), record.ConnectionType) } // Output: @@ -137,3 +171,105 @@ func ExampleReader_NetworksWithin() { // 1.0.64.0/18: Cable/DSL // 1.0.128.0/17: Cable/DSL } + +// CustomCity represents a simplified city record with custom unmarshaling. +// This demonstrates the Unmarshaler interface for custom decoding. +type CustomCity struct { + Names map[string]string + GeoNameID uint +} + +// UnmarshalMaxMindDB implements the mmdbdata.Unmarshaler interface. +// This provides custom decoding logic, similar to how json.Unmarshaler works +// with encoding/json, allowing fine-grained control over data processing. 
+func (c *CustomCity) UnmarshalMaxMindDB(d *mmdbdata.Decoder) error { + mapIter, _, err := d.ReadMap() + if err != nil { + return err + } + for key, err := range mapIter { + if err != nil { + return err + } + + switch string(key) { + case "city": + // Decode nested city structure + cityMapIter, _, err := d.ReadMap() + if err != nil { + return err + } + for cityKey, cityErr := range cityMapIter { + if cityErr != nil { + return cityErr + } + switch string(cityKey) { + case "names": + // Decode nested map[string]string for localized names + names := make(map[string]string) + nameMapIter, _, err := d.ReadMap() + if err != nil { + return err + } + for nameKey, nameErr := range nameMapIter { + if nameErr != nil { + return nameErr + } + value, valueErr := d.ReadString() + if valueErr != nil { + return valueErr + } + names[string(nameKey)] = value + } + c.Names = names + case "geoname_id": + geoID, err := d.ReadUInt32() + if err != nil { + return err + } + c.GeoNameID = uint(geoID) + default: + if err := d.SkipValue(); err != nil { + return err + } + } + } + default: + // Skip unknown fields to ensure forward compatibility + if err := d.SkipValue(); err != nil { + return err + } + } + } + return nil +} + +// This example demonstrates how to use the Unmarshaler interface for custom decoding. +// Types implementing Unmarshaler automatically use custom decoding logic instead of +// reflection, similar to how json.Unmarshaler works with encoding/json. +func ExampleUnmarshaler() { + db, err := maxminddb.Open("test-data/test-data/GeoIP2-City-Test.mmdb") + if err != nil { + log.Fatal(err) + } + defer db.Close() //nolint:errcheck // error doesn't matter + + addr := netip.MustParseAddr("81.2.69.142") + + // CustomCity implements Unmarshaler, so it will automatically use + // the custom UnmarshalMaxMindDB method instead of reflection + var city CustomCity + err = db.Lookup(addr).Decode(&city) + if err != nil { + log.Panic(err) + } + + fmt.Printf("City ID: %d\n", city.GeoNameID) + fmt.Printf("English name: %s\n", city.Names["en"]) + fmt.Printf("German name: %s\n", city.Names["de"]) + + // Output: + // City ID: 2643743 + // English name: London + // German name: London +} diff --git a/internal/decoder/data_decoder.go b/internal/decoder/data_decoder.go new file mode 100644 index 0000000..bee68cb --- /dev/null +++ b/internal/decoder/data_decoder.go @@ -0,0 +1,477 @@ +// Package decoder decodes values in the data section. +package decoder + +import ( + "encoding/binary" + "fmt" + "math" + + "github.com/oschwald/maxminddb-golang/v2/internal/mmdberrors" +) + +// Kind constants for the different MMDB data kinds. +type Kind int + +// MMDB data kind constants. +const ( + // KindExtended indicates an extended kind. + KindExtended Kind = iota + // KindPointer is a pointer to another location in the data section. + KindPointer + // KindString is a UTF-8 string. + KindString + // KindFloat64 is a 64-bit floating point number. + KindFloat64 + // KindBytes is a byte slice. + KindBytes + // KindUint16 is a 16-bit unsigned integer. + KindUint16 + // KindUint32 is a 32-bit unsigned integer. + KindUint32 + // KindMap is a map from strings to other data types. + KindMap + // KindInt32 is a 32-bit signed integer. + KindInt32 + // KindUint64 is a 64-bit unsigned integer. + KindUint64 + // KindUint128 is a 128-bit unsigned integer. + KindUint128 + // KindSlice is an array of values. + KindSlice + // KindContainer is a data cache container. + KindContainer + // KindEndMarker marks the end of the data section. 
+	KindEndMarker
+	// KindBool is a boolean value.
+	KindBool
+	// KindFloat32 is a 32-bit floating point number.
+	KindFloat32
+)
+
+// String returns a human-readable name for the Kind.
+func (k Kind) String() string {
+	switch k {
+	case KindExtended:
+		return "Extended"
+	case KindPointer:
+		return "Pointer"
+	case KindString:
+		return "String"
+	case KindFloat64:
+		return "Float64"
+	case KindBytes:
+		return "Bytes"
+	case KindUint16:
+		return "Uint16"
+	case KindUint32:
+		return "Uint32"
+	case KindMap:
+		return "Map"
+	case KindInt32:
+		return "Int32"
+	case KindUint64:
+		return "Uint64"
+	case KindUint128:
+		return "Uint128"
+	case KindSlice:
+		return "Slice"
+	case KindContainer:
+		return "Container"
+	case KindEndMarker:
+		return "EndMarker"
+	case KindBool:
+		return "Bool"
+	case KindFloat32:
+		return "Float32"
+	default:
+		return fmt.Sprintf("Unknown(%d)", int(k))
+	}
+}
+
+// IsContainer returns true if the Kind represents a container type (Map or Slice).
+func (k Kind) IsContainer() bool {
+	return k == KindMap || k == KindSlice
+}
+
+// IsScalar returns true if the Kind represents a scalar value type.
+func (k Kind) IsScalar() bool {
+	switch k {
+	case KindString, KindFloat64, KindBytes, KindUint16, KindUint32,
+		KindInt32, KindUint64, KindUint128, KindBool, KindFloat32:
+		return true
+	default:
+		return false
+	}
+}
+
+// DataDecoder is a decoder for the MMDB data section.
+// This is exported so mmdbdata package can use it, but still internal.
+type DataDecoder struct {
+	stringCache *stringCache
+	buffer      []byte
+}
+
+const (
+	// This is the value used in libmaxminddb.
+	maximumDataStructureDepth = 512
+)
+
+// NewDataDecoder creates a [DataDecoder].
+func NewDataDecoder(buffer []byte) DataDecoder {
+	return DataDecoder{
+		buffer:      buffer,
+		stringCache: newStringCache(),
+	}
+}
+
+// getBuffer returns the underlying buffer for direct access.
+func (d *DataDecoder) getBuffer() []byte {
+	return d.buffer
+}
+
+// decodeCtrlData decodes the control byte and data info at the given offset.
+func (d *DataDecoder) decodeCtrlData(offset uint) (Kind, uint, uint, error) {
+	newOffset := offset + 1
+	if offset >= uint(len(d.buffer)) {
+		return 0, 0, 0, mmdberrors.NewOffsetError()
+	}
+	ctrlByte := d.buffer[offset]
+
+	kindNum := Kind(ctrlByte >> 5)
+	if kindNum == KindExtended {
+		if newOffset >= uint(len(d.buffer)) {
+			return 0, 0, 0, mmdberrors.NewOffsetError()
+		}
+		kindNum = Kind(d.buffer[newOffset] + 7)
+		newOffset++
+	}
+
+	var size uint
+	size, newOffset, err := d.sizeFromCtrlByte(ctrlByte, newOffset, kindNum)
+	return kindNum, size, newOffset, err
+}
+
+// decodeBytes decodes a byte slice from the given offset with the given size.
+func (d *DataDecoder) decodeBytes(size, offset uint) ([]byte, uint, error) {
+	if offset+size > uint(len(d.buffer)) {
+		return nil, 0, mmdberrors.NewOffsetError()
+	}
+
+	newOffset := offset + size
+	bytes := make([]byte, size)
+	copy(bytes, d.buffer[offset:newOffset])
+	return bytes, newOffset, nil
+}
+
+// decodeFloat64 decodes a 64-bit float from the given offset.
+func (d *DataDecoder) decodeFloat64(size, offset uint) (float64, uint, error) {
+	if size != 8 {
+		return 0, 0, mmdberrors.NewInvalidDatabaseError(
+			"the MaxMind DB file's data section contains bad data (float64 size of %v)",
+			size,
+		)
+	}
+	if offset+size > uint(len(d.buffer)) {
+		return 0, 0, mmdberrors.NewOffsetError()
+	}
+
+	newOffset := offset + size
+	bits := binary.BigEndian.Uint64(d.buffer[offset:newOffset])
+	return math.Float64frombits(bits), newOffset, nil
+}
+
+// decodeFloat32 decodes a 32-bit float from the given offset.
+func (d *DataDecoder) decodeFloat32(size, offset uint) (float32, uint, error) {
+	if size != 4 {
+		return 0, 0, mmdberrors.NewInvalidDatabaseError(
+			"the MaxMind DB file's data section contains bad data (float32 size of %v)",
+			size,
+		)
+	}
+	if offset+size > uint(len(d.buffer)) {
+		return 0, 0, mmdberrors.NewOffsetError()
+	}
+
+	newOffset := offset + size
+	bits := binary.BigEndian.Uint32(d.buffer[offset:newOffset])
+	return math.Float32frombits(bits), newOffset, nil
+}
+
+// decodeInt32 decodes a 32-bit signed integer from the given offset.
+func (d *DataDecoder) decodeInt32(size, offset uint) (int32, uint, error) {
+	if size > 4 {
+		return 0, 0, mmdberrors.NewInvalidDatabaseError(
+			"the MaxMind DB file's data section contains bad data (int32 size of %v)",
+			size,
+		)
+	}
+	if offset+size > uint(len(d.buffer)) {
+		return 0, 0, mmdberrors.NewOffsetError()
+	}
+
+	newOffset := offset + size
+	var val int32
+	for _, b := range d.buffer[offset:newOffset] {
+		val = (val << 8) | int32(b)
+	}
+	return val, newOffset, nil
+}
+
+// decodePointer decodes a pointer from the given offset.
+func (d *DataDecoder) decodePointer(
+	size uint,
+	offset uint,
+) (uint, uint, error) {
+	pointerSize := ((size >> 3) & 0x3) + 1
+	newOffset := offset + pointerSize
+	if newOffset > uint(len(d.buffer)) {
+		return 0, 0, mmdberrors.NewOffsetError()
+	}
+	pointerBytes := d.buffer[offset:newOffset]
+	var prefix uint
+	if pointerSize == 4 {
+		prefix = 0
+	} else {
+		prefix = size & 0x7
+	}
+	unpacked := uintFromBytes(prefix, pointerBytes)
+
+	var pointerValueOffset uint
+	switch pointerSize {
+	case 1:
+		pointerValueOffset = 0
+	case 2:
+		pointerValueOffset = 2048
+	case 3:
+		pointerValueOffset = 526336
+	case 4:
+		pointerValueOffset = 0
+	}
+
+	pointer := unpacked + pointerValueOffset
+
+	return pointer, newOffset, nil
+}
+
+// decodeBool decodes a boolean from the given offset.
+func (*DataDecoder) decodeBool(size, offset uint) (bool, uint, error) {
+	if size > 1 {
+		return false, 0, mmdberrors.NewInvalidDatabaseError(
+			"the MaxMind DB file's data section contains bad data (bool size of %v)",
+			size,
+		)
+	}
+	value, newOffset := decodeBool(size, offset)
+	return value, newOffset, nil
+}
+
+// decodeString decodes a string from the given offset.
+func (d *DataDecoder) decodeString(size, offset uint) (string, uint, error) {
+	if offset+size > uint(len(d.buffer)) {
+		return "", 0, mmdberrors.NewOffsetError()
+	}
+
+	newOffset := offset + size
+	value := d.stringCache.internAt(offset, size, d.buffer)
+	return value, newOffset, nil
+}
+
+// decodeUint16 decodes a 16-bit unsigned integer from the given offset.
+func (d *DataDecoder) decodeUint16(size, offset uint) (uint16, uint, error) {
+	if size > 2 {
+		return 0, 0, mmdberrors.NewInvalidDatabaseError(
+			"the MaxMind DB file's data section contains bad data (uint16 size of %v)",
+			size,
+		)
+	}
+	if offset+size > uint(len(d.buffer)) {
+		return 0, 0, mmdberrors.NewOffsetError()
+	}
+
+	newOffset := offset + size
+	bytes := d.buffer[offset:newOffset]
+
+	var val uint16
+	for _, b := range bytes {
+		val = (val << 8) | uint16(b)
+	}
+	return val, newOffset, nil
+}
+
+// decodeUint32 decodes a 32-bit unsigned integer from the given offset.
+func (d *DataDecoder) decodeUint32(size, offset uint) (uint32, uint, error) {
+	if size > 4 {
+		return 0, 0, mmdberrors.NewInvalidDatabaseError(
+			"the MaxMind DB file's data section contains bad data (uint32 size of %v)",
+			size,
+		)
+	}
+	if offset+size > uint(len(d.buffer)) {
+		return 0, 0, mmdberrors.NewOffsetError()
+	}
+
+	newOffset := offset + size
+	bytes := d.buffer[offset:newOffset]
+
+	var val uint32
+	for _, b := range bytes {
+		val = (val << 8) | uint32(b)
+	}
+	return val, newOffset, nil
+}
+
+// decodeUint64 decodes a 64-bit unsigned integer from the given offset.
+func (d *DataDecoder) decodeUint64(size, offset uint) (uint64, uint, error) {
+	if size > 8 {
+		return 0, 0, mmdberrors.NewInvalidDatabaseError(
+			"the MaxMind DB file's data section contains bad data (uint64 size of %v)",
+			size,
+		)
+	}
+	if offset+size > uint(len(d.buffer)) {
+		return 0, 0, mmdberrors.NewOffsetError()
+	}
+
+	newOffset := offset + size
+	bytes := d.buffer[offset:newOffset]
+
+	var val uint64
+	for _, b := range bytes {
+		val = (val << 8) | uint64(b)
+	}
+	return val, newOffset, nil
+}
+
+// decodeUint128 decodes a 128-bit unsigned integer from the given offset.
+// Returns the value as high and low 64-bit unsigned integers.
+func (d *DataDecoder) decodeUint128(size, offset uint) (hi, lo uint64, newOffset uint, err error) {
+	if size > 16 {
+		return 0, 0, 0, mmdberrors.NewInvalidDatabaseError(
+			"the MaxMind DB file's data section contains bad data (uint128 size of %v)",
+			size,
+		)
+	}
+	if offset+size > uint(len(d.buffer)) {
+		return 0, 0, 0, mmdberrors.NewOffsetError()
+	}
+
+	newOffset = offset + size
+
+	// Process bytes from most significant to least significant.
+	for _, b := range d.buffer[offset:newOffset] {
+		var carry byte
+		lo, carry = append64(lo, b)
+		hi, _ = append64(hi, carry)
+	}
+
+	return hi, lo, newOffset, nil
+}
+
+func append64(val uint64, b byte) (uint64, byte) {
+	return (val << 8) | uint64(b), byte(val >> 56)
+}
+
+// decodeKey decodes a map key into a []byte slice. We use a []byte so that we
+// can take advantage of https://github.com/golang/go/issues/3512 to avoid
+// copying the bytes when decoding a struct. Previously, we achieved this by
+// using unsafe.
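+//
+// An illustrative sketch (the fields map here is a hypothetical caller-side
+// lookup table, not part of this change): indexing a map directly with the
+// conversion allows the compiler to avoid allocating the string.
+//
+//	key, next, err := d.decodeKey(offset)
+//	if err != nil {
+//		return err
+//	}
+//	info := fields[string(key)] // the conversion does not allocate here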
+func (d *DataDecoder) decodeKey(offset uint) ([]byte, uint, error) {
+	kindNum, size, dataOffset, err := d.decodeCtrlData(offset)
+	if err != nil {
+		return nil, 0, err
+	}
+	if kindNum == KindPointer {
+		pointer, ptrOffset, err := d.decodePointer(size, dataOffset)
+		if err != nil {
+			return nil, 0, err
+		}
+		key, _, err := d.decodeKey(pointer)
+		return key, ptrOffset, err
+	}
+	if kindNum != KindString {
+		return nil, 0, mmdberrors.NewInvalidDatabaseError(
+			"unexpected type when decoding string: %v",
+			kindNum,
+		)
+	}
+	newOffset := dataOffset + size
+	if newOffset > uint(len(d.buffer)) {
+		return nil, 0, mmdberrors.NewOffsetError()
+	}
+	return d.buffer[dataOffset:newOffset], newOffset, nil
+}
+
+// nextValueOffset skips ahead to the next value without decoding
+// the one at the offset passed in. The size bits have different meanings for
+// different data types.
+func (d *DataDecoder) nextValueOffset(offset, numberToSkip uint) (uint, error) {
+	if numberToSkip == 0 {
+		return offset, nil
+	}
+	kindNum, size, offset, err := d.decodeCtrlData(offset)
+	if err != nil {
+		return 0, err
+	}
+	switch kindNum {
+	case KindPointer:
+		_, offset, err = d.decodePointer(size, offset)
+		if err != nil {
+			return 0, err
+		}
+	case KindMap:
+		numberToSkip += 2 * size
+	case KindSlice:
+		numberToSkip += size
+	case KindBool:
+	default:
+		offset += size
+	}
+	return d.nextValueOffset(offset, numberToSkip-1)
+}
+
+func (d *DataDecoder) sizeFromCtrlByte(
+	ctrlByte byte,
+	offset uint,
+	kindNum Kind,
+) (uint, uint, error) {
+	size := uint(ctrlByte & 0x1f)
+	if kindNum == KindExtended {
+		return size, offset, nil
+	}
+
+	if size < 29 {
+		return size, offset, nil
+	}
+
+	bytesToRead := size - 28
+	newOffset := offset + bytesToRead
+	if newOffset > uint(len(d.buffer)) {
+		return 0, 0, mmdberrors.NewOffsetError()
+	}
+	if size == 29 {
+		return 29 + uint(d.buffer[offset]), offset + 1, nil
+	}
+
+	sizeBytes := d.buffer[offset:newOffset]
+
+	switch {
+	case size == 30:
+		size = 285 + uintFromBytes(0, sizeBytes)
+	case size > 30:
+		size = uintFromBytes(0, sizeBytes) + 65821
+	}
+	return size, newOffset, nil
+}
+
+func decodeBool(size, offset uint) (bool, uint) {
+	return size != 0, offset
+}
+
+func uintFromBytes(prefix uint, uintBytes []byte) uint {
+	val := prefix
+	for _, b := range uintBytes {
+		val = (val << 8) | uint(b)
+	}
+	return val
+}
diff --git a/internal/decoder/decoder.go b/internal/decoder/decoder.go
new file mode 100644
index 0000000..123fb5d
--- /dev/null
+++ b/internal/decoder/decoder.go
@@ -0,0 +1,422 @@
+package decoder
+
+import (
+	"errors"
+	"fmt"
+	"iter"
+
+	"github.com/oschwald/maxminddb-golang/v2/internal/mmdberrors"
+)
+
+// Decoder allows decoding of a single value stored at a specific offset
+// in the database.
+type Decoder struct {
+	d             DataDecoder
+	offset        uint
+	nextOffset    uint
+	opts          decoderOptions
+	hasNextOffset bool
+}
+
+type decoderOptions struct {
+	// Reserved for future options.
+}
+
+// DecoderOption configures a Decoder.
+//
+//nolint:revive // name follows existing library pattern (ReaderOption, NetworksOption)
+type DecoderOption func(*decoderOptions)
+
+// NewDecoder creates a new Decoder with the given DataDecoder, offset, and options.
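+//
+// A minimal usage sketch (illustrative only; buf is assumed to hold a valid
+// MMDB data section):
+//
+//	d := NewDecoder(NewDataDecoder(buf), 0)
+//	kind, err := d.PeekKind()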
+func NewDecoder(d DataDecoder, offset uint, options ...DecoderOption) *Decoder {
+	opts := decoderOptions{}
+	for _, option := range options {
+		option(&opts)
+	}
+
+	decoder := &Decoder{
+		d:      d,
+		offset: offset,
+		opts:   opts,
+	}
+
+	return decoder
+}
+
+// ReadBool reads the value pointed to by the decoder as a bool.
+//
+// Returns an error if the database is malformed or if the value is not a bool.
+func (d *Decoder) ReadBool() (bool, error) {
+	size, offset, err := d.decodeCtrlDataAndFollow(KindBool)
+	if err != nil {
+		return false, d.wrapError(err)
+	}
+
+	value, newOffset, err := d.d.decodeBool(size, offset)
+	if err != nil {
+		return false, d.wrapError(err)
+	}
+	d.setNextOffset(newOffset)
+	return value, nil
+}
+
+// ReadString reads the value pointed to by the decoder as a string.
+//
+// Returns an error if the database is malformed or if the value is not a string.
+func (d *Decoder) ReadString() (string, error) {
+	size, offset, err := d.decodeCtrlDataAndFollow(KindString)
+	if err != nil {
+		return "", d.wrapError(err)
+	}
+
+	value, newOffset, err := d.d.decodeString(size, offset)
+	if err != nil {
+		return "", d.wrapError(err)
+	}
+	d.setNextOffset(newOffset)
+	return value, nil
+}
+
+// ReadBytes reads the value pointed to by the decoder as bytes.
+//
+// Returns an error if the database is malformed or if the value is not bytes.
+func (d *Decoder) ReadBytes() ([]byte, error) {
+	val, err := d.readBytes(KindBytes)
+	if err != nil {
+		return nil, d.wrapError(err)
+	}
+	return val, nil
+}
+
+// ReadFloat32 reads the value pointed to by the decoder as a float32.
+//
+// Returns an error if the database is malformed or if the value is not a float.
+func (d *Decoder) ReadFloat32() (float32, error) {
+	size, offset, err := d.decodeCtrlDataAndFollow(KindFloat32)
+	if err != nil {
+		return 0, d.wrapError(err)
+	}
+
+	value, nextOffset, err := d.d.decodeFloat32(size, offset)
+	if err != nil {
+		return 0, d.wrapError(err)
+	}
+
+	d.setNextOffset(nextOffset)
+	return value, nil
+}
+
+// ReadFloat64 reads the value pointed to by the decoder as a float64.
+//
+// Returns an error if the database is malformed or if the value is not a double.
+func (d *Decoder) ReadFloat64() (float64, error) {
+	size, offset, err := d.decodeCtrlDataAndFollow(KindFloat64)
+	if err != nil {
+		return 0, d.wrapError(err)
+	}
+
+	value, nextOffset, err := d.d.decodeFloat64(size, offset)
+	if err != nil {
+		return 0, d.wrapError(err)
+	}
+
+	d.setNextOffset(nextOffset)
+	return value, nil
+}
+
+// ReadInt32 reads the value pointed to by the decoder as an int32.
+//
+// Returns an error if the database is malformed or if the value is not an int32.
+func (d *Decoder) ReadInt32() (int32, error) {
+	size, offset, err := d.decodeCtrlDataAndFollow(KindInt32)
+	if err != nil {
+		return 0, d.wrapError(err)
+	}
+
+	value, nextOffset, err := d.d.decodeInt32(size, offset)
+	if err != nil {
+		return 0, d.wrapError(err)
+	}
+
+	d.setNextOffset(nextOffset)
+	return value, nil
+}
+
+// ReadUInt16 reads the value pointed to by the decoder as a uint16.
+//
+// Returns an error if the database is malformed or if the value is not a uint16.
+func (d *Decoder) ReadUInt16() (uint16, error) {
+	size, offset, err := d.decodeCtrlDataAndFollow(KindUint16)
+	if err != nil {
+		return 0, d.wrapError(err)
+	}
+
+	value, nextOffset, err := d.d.decodeUint16(size, offset)
+	if err != nil {
+		return 0, d.wrapError(err)
+	}
+
+	d.setNextOffset(nextOffset)
+	return value, nil
+}
+
+// ReadUInt32 reads the value pointed to by the decoder as a uint32.
+//
+// Returns an error if the database is malformed or if the value is not a uint32.
+func (d *Decoder) ReadUInt32() (uint32, error) {
+	size, offset, err := d.decodeCtrlDataAndFollow(KindUint32)
+	if err != nil {
+		return 0, d.wrapError(err)
+	}
+
+	value, nextOffset, err := d.d.decodeUint32(size, offset)
+	if err != nil {
+		return 0, d.wrapError(err)
+	}
+
+	d.setNextOffset(nextOffset)
+	return value, nil
+}
+
+// ReadUInt64 reads the value pointed to by the decoder as a uint64.
+//
+// Returns an error if the database is malformed or if the value is not a uint64.
+func (d *Decoder) ReadUInt64() (uint64, error) {
+	size, offset, err := d.decodeCtrlDataAndFollow(KindUint64)
+	if err != nil {
+		return 0, d.wrapError(err)
+	}
+
+	value, nextOffset, err := d.d.decodeUint64(size, offset)
+	if err != nil {
+		return 0, d.wrapError(err)
+	}
+
+	d.setNextOffset(nextOffset)
+	return value, nil
+}
+
+// ReadUInt128 reads the value pointed to by the decoder as a uint128.
+//
+// Returns an error if the database is malformed or if the value is not a uint128.
+func (d *Decoder) ReadUInt128() (hi, lo uint64, err error) {
+	size, offset, err := d.decodeCtrlDataAndFollow(KindUint128)
+	if err != nil {
+		return 0, 0, d.wrapError(err)
+	}
+
+	hi, lo, nextOffset, err := d.d.decodeUint128(size, offset)
+	if err != nil {
+		return 0, 0, d.wrapError(err)
+	}
+
+	d.setNextOffset(nextOffset)
+	return hi, lo, nil
+}
+
+// ReadMap returns an iterator to read the map along with the map size. The
+// size can be used to pre-allocate a map with the correct capacity for better
+// performance. The first value from the iterator is the key. Note that this
+// byte slice is only valid during the iteration; this avoids an unnecessary
+// allocation. You must make a copy of it if you are storing it for later use.
+// The second value is an error indicating that the database is malformed or
+// that the value is not a map.
+func (d *Decoder) ReadMap() (iter.Seq2[[]byte, error], uint, error) {
+	size, offset, err := d.decodeCtrlDataAndFollow(KindMap)
+	if err != nil {
+		return nil, 0, d.wrapError(err)
+	}
+
+	iterator := func(yield func([]byte, error) bool) {
+		currentOffset := offset
+
+		for range size {
+			key, keyEndOffset, err := d.d.decodeKey(currentOffset)
+			if err != nil {
+				yield(nil, d.wrapErrorAtOffset(err, currentOffset))
+				return
+			}
+
+			// Position the decoder to read the value after yielding the key.
+			d.reset(keyEndOffset)
+
+			ok := yield(key, nil)
+			if !ok {
+				return
+			}
+
+			// Skip the value to get to the next key-value pair.
+			valueEndOffset, err := d.d.nextValueOffset(keyEndOffset, 1)
+			if err != nil {
+				yield(nil, d.wrapError(err))
+				return
+			}
+			currentOffset = valueEndOffset
+		}
+
+		// Set the final offset after map iteration.
+		d.reset(currentOffset)
+	}
+
+	return iterator, size, nil
+}
+
+// ReadSlice returns an iterator over the values of the slice along with the
+// slice size. The size can be used to pre-allocate a slice with the correct
+// capacity for better performance. The iterator returns an error if the
+// database is malformed or if the value is not a slice.
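+//
+// A sketch of the intended iteration pattern (illustrative only; it assumes
+// the elements are strings):
+//
+//	sliceIter, size, err := d.ReadSlice()
+//	if err != nil {
+//		return err
+//	}
+//	out := make([]string, 0, size)
+//	for iterErr := range sliceIter {
+//		if iterErr != nil {
+//			return iterErr
+//		}
+//		s, err := d.ReadString()
+//		if err != nil {
+//			return err
+//		}
+//		out = append(out, s)
+//	}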
+func (d *Decoder) ReadSlice() (iter.Seq[error], uint, error) { + size, offset, err := d.decodeCtrlDataAndFollow(KindSlice) + if err != nil { + return nil, 0, d.wrapError(err) + } + + iterator := func(yield func(error) bool) { + currentOffset := offset + + for i := range size { + // Position decoder to read current element + d.reset(currentOffset) + + ok := yield(nil) + if !ok { + // Skip the unvisited elements + remaining := size - i - 1 + if remaining > 0 { + endOffset, err := d.d.nextValueOffset(currentOffset, remaining) + if err == nil { + d.reset(endOffset) + } + } + return + } + + // Advance to next element + nextOffset, err := d.d.nextValueOffset(currentOffset, 1) + if err != nil { + yield(d.wrapError(err)) + return + } + currentOffset = nextOffset + } + + // Set final offset after slice iteration + d.reset(currentOffset) + } + + return iterator, size, nil +} + +// SkipValue skips over the current value without decoding it. +// This is useful in custom decoders when encountering unknown fields. +// The decoder will be positioned after the skipped value. +func (d *Decoder) SkipValue() error { + // We can reuse the existing nextValueOffset logic by jumping to the next value + nextOffset, err := d.d.nextValueOffset(d.offset, 1) + if err != nil { + return d.wrapError(err) + } + d.reset(nextOffset) + return nil +} + +// PeekKind returns the kind of the current value without consuming it. +// This allows for look-ahead parsing similar to jsontext.Decoder.PeekKind(). +func (d *Decoder) PeekKind() (Kind, error) { + kindNum, _, _, err := d.d.decodeCtrlData(d.offset) + if err != nil { + return 0, d.wrapError(err) + } + + // Follow pointers to get the actual kind + if kindNum == KindPointer { + // We need to follow the pointer to get the real kind + dataOffset := d.offset + for { + var size uint + kindNum, size, dataOffset, err = d.d.decodeCtrlData(dataOffset) + if err != nil { + return 0, d.wrapError(err) + } + if kindNum != KindPointer { + break + } + dataOffset, _, err = d.d.decodePointer(size, dataOffset) + if err != nil { + return 0, d.wrapError(err) + } + } + } + + return kindNum, nil +} + +func (d *Decoder) reset(offset uint) { + d.offset = offset + d.hasNextOffset = false + d.nextOffset = 0 +} + +func (d *Decoder) setNextOffset(offset uint) { + if !d.hasNextOffset { + d.hasNextOffset = true + d.nextOffset = offset + } +} + +func (d *Decoder) getNextOffset() (uint, error) { + if !d.hasNextOffset { + return 0, errors.New("no next offset available") + } + return d.nextOffset, nil +} + +func unexpectedKindErr(expectedKind, actualKind Kind) error { + return fmt.Errorf("unexpected kind %d, expected %d", actualKind, expectedKind) +} + +func (d *Decoder) decodeCtrlDataAndFollow(expectedKind Kind) (uint, uint, error) { + dataOffset := d.offset + for { + var kindNum Kind + var size uint + var err error + kindNum, size, dataOffset, err = d.d.decodeCtrlData(dataOffset) + if err != nil { + return 0, 0, err // Don't wrap here, let caller wrap + } + + if kindNum == KindPointer { + var nextOffset uint + dataOffset, nextOffset, err = d.d.decodePointer(size, dataOffset) + if err != nil { + return 0, 0, err // Don't wrap here, let caller wrap + } + d.setNextOffset(nextOffset) + continue + } + + if kindNum != expectedKind { + return 0, 0, unexpectedKindErr(expectedKind, kindNum) + } + + return size, dataOffset, nil + } +} + +func (d *Decoder) readBytes(kind Kind) ([]byte, error) { + size, offset, err := d.decodeCtrlDataAndFollow(kind) + if err != nil { + return nil, err // Return unwrapped - caller will 
wrap
+	}
+
+	if offset+size > uint(len(d.d.getBuffer())) {
+		return nil, mmdberrors.NewInvalidDatabaseError(
+			"the MaxMind DB file's data section contains bad data (offset+size %d exceeds buffer length %d)",
+			offset+size,
+			len(d.d.getBuffer()),
+		)
+	}
+	d.setNextOffset(offset + size)
+	return d.d.getBuffer()[offset : offset+size], nil
+}
diff --git a/internal/decoder/decoder_test.go b/internal/decoder/decoder_test.go
new file mode 100644
index 0000000..177ea8f
--- /dev/null
+++ b/internal/decoder/decoder_test.go
@@ -0,0 +1,567 @@
+package decoder
+
+import (
+	"encoding/hex"
+	"fmt"
+	"math/big"
+	"os"
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+// Helper function to create a Decoder for a given hex string.
+func newDecoderFromHex(t *testing.T, hexStr string) *Decoder {
+	t.Helper()
+	inputBytes, err := hex.DecodeString(hexStr)
+	require.NoError(t, err, "Failed to decode hex string: %s", hexStr)
+	dd := NewDataDecoder(inputBytes)
+	return NewDecoder(dd, 0)
+}
+
+// Helper function to create reasonable test names from potentially long hex strings.
+func makeTestName(hexStr string) string {
+	if len(hexStr) <= 20 {
+		return hexStr
+	}
+	return hexStr[:16] + "..." + hexStr[len(hexStr)-4:]
+}
+
+func TestDecodeBool(t *testing.T) {
+	tests := map[string]bool{
+		"0007": false,
+		"0107": true,
+	}
+
+	for hexStr, expected := range tests {
+		t.Run(hexStr, func(t *testing.T) {
+			decoder := newDecoderFromHex(t, hexStr)
+			result, err := decoder.ReadBool()
+			require.NoError(t, err)
+			require.Equal(t, expected, result)
+			// Check if the offset was advanced correctly (simple check).
+			require.True(t, decoder.hasNextOffset || decoder.offset > 0, "Offset was not advanced")
+		})
+	}
+}
+
+func TestDecodeDouble(t *testing.T) {
+	tests := map[string]float64{
+		"680000000000000000": 0.0,
+		"683FE0000000000000": 0.5,
+		"68400921FB54442EEA": 3.14159265359,
+		"68405EC00000000000": 123.0,
+		"6841D000000007F8F4": 1073741824.12457,
+		"68BFE0000000000000": -0.5,
+		"68C00921FB54442EEA": -3.14159265359,
+		"68C1D000000007F8F4": -1073741824.12457,
+	}
+
+	for hexStr, expected := range tests {
+		t.Run(hexStr, func(t *testing.T) {
+			decoder := newDecoderFromHex(t, hexStr)
+			result, err := decoder.ReadFloat64()
+			require.NoError(t, err)
+			if expected == 0 {
+				require.InDelta(t, expected, result, 0)
+			} else {
+				require.InEpsilon(t, expected, result, 1e-15)
+			}
+			require.True(t, decoder.hasNextOffset || decoder.offset > 0, "Offset was not advanced")
+		})
+	}
+}
+
+func TestDecodeFloat(t *testing.T) {
+	tests := map[string]float32{
+		"040800000000": float32(0.0),
+		"04083F800000": float32(1.0),
+		"04083F8CCCCD": float32(1.1),
+		"04084048F5C3": float32(3.14),
+		"0408461C3FF6": float32(9999.99),
+		"0408BF800000": float32(-1.0),
+		"0408BF8CCCCD": float32(-1.1),
+		"0408C048F5C3": -float32(3.14),
+		"0408C61C3FF6": float32(-9999.99),
+	}
+
+	for hexStr, expected := range tests {
+		t.Run(hexStr, func(t *testing.T) {
+			decoder := newDecoderFromHex(t, hexStr)
+			result, err := decoder.ReadFloat32()
+			require.NoError(t, err)
+			if expected == 0 {
+				require.InDelta(t, expected, result, 0)
+			} else {
+				require.InEpsilon(t, expected, result, 1e-6)
+			}
+			require.True(t, decoder.hasNextOffset || decoder.offset > 0, "Offset was not advanced")
+		})
+	}
+}
+
+func TestDecodeInt32(t *testing.T) {
+	tests := map[string]int32{
+		"0001":         int32(0),
+		"0401ffffffff": int32(-1),
+		"0101ff":       int32(255),
"0401ffffff01": int32(-255), + "020101f4": int32(500), + "0401fffffe0c": int32(-500), + "0201ffff": int32(65535), + "0401ffff0001": int32(-65535), + "0301ffffff": int32(16777215), + "0401ff000001": int32(-16777215), + "04017fffffff": int32(2147483647), // [cite: 86] + "040180000001": int32(-2147483647), + } + + for hexStr, expected := range tests { + t.Run(hexStr, func(t *testing.T) { + decoder := newDecoderFromHex(t, hexStr) + result, err := decoder.ReadInt32() // [cite: 40] + require.NoError(t, err) + require.Equal(t, expected, result) + require.True(t, decoder.hasNextOffset || decoder.offset > 0, "Offset was not advanced") + }) + } +} + +func TestDecodeMap(t *testing.T) { + tests := map[string]map[string]any{ + "e0": {}, // [cite: 50] + "e142656e43466f6f": {"en": "Foo"}, + "e242656e43466f6f427a6843e4baba": {"en": "Foo", "zh": "人"}, + // Nested map test needs separate handling or more complex validation logic + // "e1446e616d65e242656e43466f6f427a6843e4baba": map[string]any{ + // "name": map[string]any{"en": "Foo", "zh": "人"}, + // }, + // Map containing slice needs separate handling + // "e1496c616e677561676573020442656e427a68": map[string]any{ + // "languages": []any{"en", "zh"}, + // }, + } + + for hexStr, expected := range tests { + t.Run(hexStr, func(t *testing.T) { + decoder := newDecoderFromHex(t, hexStr) + mapIter, size, err := decoder.ReadMap() // [cite: 53] + require.NoError(t, err, "ReadMap failed") + resultMap := make(map[string]any, size) // Pre-allocate with correct capacity + + // Iterate through the map [cite: 54] + for keyBytes, err := range mapIter { // [cite: 50] + require.NoError(t, err, "Iterator returned error for key") + key := string(keyBytes) // [cite: 51] - Need to copy if stored + + // Now decode the value corresponding to the key + // For simplicity, we'll read as string here. Needs adjustment for mixed types. + value, err := decoder.ReadString() // [cite: 32] + require.NoError(t, err, "Failed to decode value for key %s", key) + resultMap[key] = value + } + + // Final check on the accumulated map + require.Equal(t, expected, resultMap) + require.True(t, decoder.hasNextOffset || decoder.offset > 0, "Offset was not advanced") + }) + } +} + +func TestDecodeSlice(t *testing.T) { + tests := map[string][]any{ + "0004": {}, // [cite: 55] + "010443466f6f": {"Foo"}, + "020443466f6f43e4baba": {"Foo", "人"}, + } + + for hexStr, expected := range tests { + t.Run(hexStr, func(t *testing.T) { + decoder := newDecoderFromHex(t, hexStr) + sliceIter, size, err := decoder.ReadSlice() // [cite: 56] + require.NoError(t, err, "ReadSlice failed") + results := make([]any, 0, size) // Pre-allocate with correct capacity + + // Iterate through the slice [cite: 57] + for err := range sliceIter { + require.NoError(t, err, "Iterator returned error") + + // Read the current element + // For simplicity, reading as string. Needs adjustment for mixed types. 
+				elem, err := decoder.ReadString()
+				require.NoError(t, err, "Failed to decode slice element")
+				results = append(results, elem)
+			}
+
+			require.Equal(t, expected, results)
+			require.True(t, decoder.hasNextOffset || decoder.offset > 0, "Offset was not advanced")
+		})
+	}
+}
+
+func TestDecodeString(t *testing.T) {
+	for hexStr, expected := range testStrings {
+		t.Run(makeTestName(hexStr), func(t *testing.T) {
+			decoder := newDecoderFromHex(t, hexStr)
+			result, err := decoder.ReadString()
+			require.NoError(t, err)
+			require.Equal(t, expected, result)
+			require.True(t, decoder.hasNextOffset || decoder.offset > 0, "Offset was not advanced")
+		})
+	}
+}
+
+func TestDecodeByte(t *testing.T) {
+	byteTests := make(map[string][]byte)
+	for key, val := range testStrings {
+		oldCtrl, err := hex.DecodeString(key[0:2])
+		require.NoError(t, err)
+		// Convert the control byte from KindString (2) to KindBytes (4),
+		// keeping the lower five size bits.
+		newCtrlByte := (oldCtrl[0] & 0x1f) | (byte(KindBytes) << 5)
+		newCtrl := []byte{newCtrlByte}
+
+		newKey := hex.EncodeToString(newCtrl) + key[2:]
+		byteTests[newKey] = []byte(val.(string))
+	}
+
+	for hexStr, expected := range byteTests {
+		t.Run(makeTestName(hexStr), func(t *testing.T) {
+			decoder := newDecoderFromHex(t, hexStr)
+			result, err := decoder.ReadBytes()
+			require.NoError(t, err)
+			require.Equal(t, expected, result)
+			require.True(t, decoder.hasNextOffset || decoder.offset > 0, "Offset was not advanced")
+		})
+	}
+}
+
+func TestDecodeUint16(t *testing.T) {
+	tests := map[string]uint16{
+		"a0":     uint16(0),
+		"a1ff":   uint16(255),
+		"a201f4": uint16(500),
+		"a22a78": uint16(10872),
+		"a2ffff": uint16(65535), // Note: the reflection test uses a uint64 expected value.
+	}
+
+	for hexStr, expected := range tests {
+		t.Run(hexStr, func(t *testing.T) {
+			decoder := newDecoderFromHex(t, hexStr)
+			result, err := decoder.ReadUInt16()
+			require.NoError(t, err)
+			require.Equal(t, expected, result)
+			require.True(t, decoder.hasNextOffset || decoder.offset > 0, "Offset was not advanced")
+		})
+	}
+}
+
+func TestDecodeUint32(t *testing.T) {
+	tests := map[string]uint32{
+		"c0":         uint32(0),
+		"c1ff":       uint32(255),
+		"c201f4":     uint32(500),
+		"c22a78":     uint32(10872),
+		"c2ffff":     uint32(65535),
+		"c3ffffff":   uint32(16777215),
+		"c4ffffffff": uint32(4294967295),
+	}
+
+	for hexStr, expected := range tests {
+		t.Run(hexStr, func(t *testing.T) {
+			decoder := newDecoderFromHex(t, hexStr)
+			result, err := decoder.ReadUInt32()
+			require.NoError(t, err)
+			require.Equal(t, expected, result)
+			require.True(t, decoder.hasNextOffset || decoder.offset > 0, "Offset was not advanced")
+		})
+	}
+}
+
+func TestDecodeUint64(t *testing.T) {
+	ctrlByte := "02" // Extended type: the second byte selects KindUint64 (9 - 7 = 2).
+
+	tests := map[string]uint64{
+		"00" + ctrlByte:          uint64(0),
+		"02" + ctrlByte + "01f4": uint64(500),
+		"02" + ctrlByte + "2a78": uint64(10872),
+		// Maximum value, mirroring the reflection tests.
+		"08" + ctrlByte + "ffffffffffffffff": uint64(18446744073709551615),
+	}
+
+	for hexStr, expected := range tests {
+		t.Run(hexStr, func(t *testing.T) {
+			decoder := newDecoderFromHex(t, hexStr)
+			result, err := decoder.ReadUInt64()
+			require.NoError(t, err)
+			require.Equal(t, expected, result)
+			require.True(t, decoder.hasNextOffset || decoder.offset > 0, "Offset was not advanced")
+		})
+	}
+}
+
+func TestDecodeUint128(t *testing.T) {
+	ctrlByte := "03" // Extended type: the second byte selects KindUint128 (10 - 7 = 3).
+	bits := uint(128)
+
+	tests := map[string]*big.Int{
+		"00" + ctrlByte:          big.NewInt(0),
+		"02" + ctrlByte + "01f4": big.NewInt(500),
+		"02" + ctrlByte + "2a78": big.NewInt(10872),
+		// Maximum value, mirroring the reflection tests.
+		"10" + ctrlByte + strings.Repeat("ff", 16): func() *big.Int { // 16 bytes = 128 bits
+			expected := powBigInt(big.NewInt(2), bits)
+			return expected.Sub(expected, big.NewInt(1))
+		}(),
+	}
+
+	for hexStr, expected := range tests {
+		t.Run(hexStr, func(t *testing.T) {
+			decoder := newDecoderFromHex(t, hexStr)
+			hi, lo, err := decoder.ReadUInt128()
+			require.NoError(t, err)
+
+			// Reconstruct the big.Int from the hi and lo parts for comparison.
+			result := new(big.Int)
+			result.SetUint64(hi)
+			result.Lsh(result, 64)                        // Shift the high part left by 64 bits.
+			result.Or(result, new(big.Int).SetUint64(lo)) // OR in the low part.
+
+			require.Equal(t, 0, expected.Cmp(result),
+				"Expected %v, got %v", expected.String(), result.String())
+			require.True(t, decoder.hasNextOffset || decoder.offset > 0, "Offset was not advanced")
+		})
+	}
+}
+
+// TestPointersInDecoder requires a specific data file and structure.
+func TestPointersInDecoder(t *testing.T) {
+	// This test requires the 'maps-with-pointers.raw' file also used by the
+	// reflection tests. It demonstrates how to handle pointers using the
+	// basic Decoder.
+	bytes, err := os.ReadFile(testFile("maps-with-pointers.raw"))
+	require.NoError(t, err)
+	dd := NewDataDecoder(bytes)
+
+	expected := map[uint]map[string]string{
+		// Offsets and expected values match reflection_test.go.
+		0:  {"long_key": "long_value1"},
+		22: {"long_key": "long_value2"},
+		37: {"long_key2": "long_value1"},
+		50: {"long_key2": "long_value2"},
+		55: {"long_key": "long_value1"},
+		57: {"long_key2": "long_value2"},
+	}
+
+	for startOffset, expectedValue := range expected {
+		t.Run(fmt.Sprintf("Offset_%d", startOffset), func(t *testing.T) {
+			decoder := NewDecoder(dd, startOffset) // Start at the specific offset.
+			actualValue := make(map[string]string)
+
+			// Expecting a map at the target offset (it may be behind a pointer).
+			mapIter, size, err := decoder.ReadMap()
+			require.NoError(t, err, "ReadMap failed")
+			_ = size // Use size if needed for pre-allocation.
+			for keyBytes, errIter := range mapIter {
+				require.NoError(t, errIter)
+				key := string(keyBytes)
+				// The value is expected to be a string.
+				value, errDecode := decoder.ReadString()
+				require.NoError(t, errDecode)
+				actualValue[key] = value
+			}
+
+			require.Equal(t, expectedValue, actualValue)
+			// An offset check would be complex here due to pointer jumps.
+		})
+	}
+}
+
+// TestBoundsChecking verifies that buffer access is properly bounds-checked
+// to prevent panics on malformed databases.
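+//
+// The malformed inputs below are hand-crafted control bytes: the top three
+// bits select the kind (or KindExtended) and the low five bits declare a
+// payload size that is larger than the bytes remaining in the buffer.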
+func TestBoundsChecking(t *testing.T) {
+	// Create a very small buffer that would cause out-of-bounds access
+	// if bounds checking were not working.
+	smallBuffer := []byte{0x44, 0x41} // KindString (2<<5) with size 4, but only 2 bytes total
+	dd := NewDataDecoder(smallBuffer)
+	decoder := NewDecoder(dd, 0)
+
+	// This should fail gracefully with an error instead of panicking.
+	_, err := decoder.ReadString()
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "unexpected end of database")
+
+	// Test decodeBytes bounds checking with a separate buffer.
+	bytesBuffer := []byte{
+		0x84,
+		0x41,
+	} // KindBytes (4<<5 = 0x80) with size 4 (0x04), but only 2 bytes total
+	dd3 := NewDataDecoder(bytesBuffer)
+	decoder3 := NewDecoder(dd3, 0)
+
+	_, err = decoder3.ReadBytes()
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "exceeds buffer length")
+
+	// Test decodeUint128 bounds checking.
+	uint128Buffer := []byte{
+		0x0B,
+		0x03,
+	} // KindExtended with size 11; the next byte selects KindUint128 (10 - 7 = 3), but only 2 bytes total
+	dd2 := NewDataDecoder(uint128Buffer)
+	decoder2 := NewDecoder(dd2, 0)
+
+	_, _, err = decoder2.ReadUInt128()
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "unexpected end of database")
+}
+
+func TestPeekKind(t *testing.T) {
+	tests := []struct {
+		name     string
+		buffer   []byte
+		expected Kind
+	}{
+		{
+			name:     "string type",
+			buffer:   []byte{0x44, 't', 'e', 's', 't'}, // String "test" (KindString=2, (2<<5)|4)
+			expected: KindString,
+		},
+		{
+			name:     "map type",
+			buffer:   []byte{0xE0}, // Empty map (KindMap=7, (7<<5)|0)
+			expected: KindMap,
+		},
+		{
+			name: "slice type",
+			buffer: []byte{
+				0x00,
+				0x04,
+			}, // Empty slice (KindSlice=11, extended type: 0x00, KindSlice-7=4)
+			expected: KindSlice,
+		},
+		{
+			name: "bool type",
+			buffer: []byte{
+				0x01,
+				0x07,
+			}, // Bool true (KindBool=14, extended type: size 1, KindBool-7=7)
+			expected: KindBool,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			decoder := NewDecoder(NewDataDecoder(tt.buffer), 0)
+
+			actualType, err := decoder.PeekKind()
+			require.NoError(t, err, "PeekKind failed")
+
+			require.Equal(
+				t,
+				tt.expected,
+				actualType,
+				"Expected type %d, got %d",
+				tt.expected,
+				actualType,
+			)
+
+			// Verify that PeekKind doesn't consume the value.
+			actualType2, err := decoder.PeekKind()
+			require.NoError(t, err, "Second PeekKind failed")
+
+			require.Equal(
+				t,
+				tt.expected,
+				actualType2,
+				"Second PeekKind gave different result: expected %d, got %d",
+				tt.expected,
+				actualType2,
+			)
+		})
+	}
+}
+
+// TestPeekKindWithPointer tests that PeekKind correctly follows pointers
+// to get the actual kind of the pointed-to value.
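+//
+// MMDB pointers pack their size into the control byte: the top three bits
+// select KindPointer, bits 3-4 select a pointer width of 1-4 bytes, and the
+// low three bits plus the following bytes form the target offset (plus a
+// fixed bias for the larger widths; see decodePointer).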
+func TestPeekKindWithPointer(t *testing.T) {
+	// Create a buffer with a pointer that points to a string. This is
+	// simplified; real MMDB files use pointers into the data section.
+	buffer := []byte{
+		// Control byte 0x20 = KindPointer<<5 with a 1-byte pointer value.
+		0x20, 0x02, // Pointer to offset 2.
+		// Target string at offset 2.
+		0x44, 't', 'e', 's', 't', // String "test"
+	}
+
+	decoder := NewDecoder(NewDataDecoder(buffer), 0)
+
+	// PeekKind should follow the pointer and return the kind of the target,
+	// not KindPointer.
+	actualType, err := decoder.PeekKind()
+	require.NoError(t, err, "PeekKind with pointer failed")
+	require.Equal(t, KindString, actualType)
+}
+
+// ExampleDecoder_PeekKind demonstrates how to use PeekKind for
+// look-ahead parsing without consuming values.
+func ExampleDecoder_PeekKind() {
+	// Create test data with different types.
+	testCases := [][]byte{
+		{0x44, 't', 'e', 's', 't'}, // String
+		{0xE0},                     // Empty map
+		{0x00, 0x04},               // Empty slice (extended type)
+		{0x01, 0x07},               // Bool true (extended type)
+	}
+
+	typeNames := []string{"String", "Map", "Slice", "Bool"}
+
+	for i, buffer := range testCases {
+		decoder := NewDecoder(NewDataDecoder(buffer), 0)
+
+		// Peek at the kind without consuming it.
+		typ, err := decoder.PeekKind()
+		if err != nil {
+			panic(err)
+		}
+
+		fmt.Printf("Type %d: %s (value: %d)\n", i+1, typeNames[i], typ)
+
+		// PeekKind doesn't consume, so we can peek again.
+		typ2, err := decoder.PeekKind()
+		if err != nil {
+			panic(err)
+		}
+
+		if typ != typ2 {
+			fmt.Println("ERROR: PeekKind consumed the value!")
+		}
+	}
+
+	// Output:
+	// Type 1: String (value: 2)
+	// Type 2: Map (value: 7)
+	// Type 3: Slice (value: 11)
+	// Type 4: Bool (value: 14)
+}
+
+func TestDecoderOptions(t *testing.T) {
+	buffer := []byte{0x44, 't', 'e', 's', 't'} // String "test"
+	dd := NewDataDecoder(buffer)
+
+	// Test that the options infrastructure works (even with no current options).
+	decoder1 := NewDecoder(dd, 0)
+	require.NotNil(t, decoder1)
+
+	// Passing no options should behave the same.
+	decoder2 := NewDecoder(dd, 0)
+	require.NotNil(t, decoder2)
+}
diff --git a/internal/decoder/error_context.go b/internal/decoder/error_context.go
new file mode 100644
index 0000000..31bf199
--- /dev/null
+++ b/internal/decoder/error_context.go
@@ -0,0 +1,49 @@
+package decoder
+
+import (
+	"github.com/oschwald/maxminddb-golang/v2/internal/mmdberrors"
+)
+
+// errorContext provides zero-allocation error context tracking for Decoder.
+// It is only used when an error occurs, ensuring no performance impact
+// on the happy path.
+type errorContext struct {
+	path *mmdberrors.PathBuilder // Only allocated when needed.
+}
+
+// BuildPath implements mmdberrors.ErrorContextTracker.
+// This is only called when an error occurs, so allocation is acceptable.
+func (e *errorContext) BuildPath() string {
+	if e.path == nil {
+		return "" // No path tracking enabled.
+	}
+	return e.path.Build()
+}
+
+// wrapError wraps an error with context information when an error occurs.
+// Zero allocation on the happy path: it only allocates when err != nil.
+func (d *Decoder) wrapError(err error) error {
+	if err == nil {
+		return nil
+	}
+	// Only wrap with context when an error actually occurs.
+	return mmdberrors.WrapWithContext(err, d.offset, nil)
+}
+
+// wrapErrorAtOffset wraps an error with context at a specific offset.
+// Used when the error occurs at a different offset than the decoder's current position.
+func (*Decoder) wrapErrorAtOffset(err error, offset uint) error {
+	if err == nil {
+		return nil
+	}
+	return mmdberrors.WrapWithContext(err, offset, nil)
+}
+
+// Example of how to integrate into existing decoder methods:
+// Instead of:
+//	return mmdberrors.NewInvalidDatabaseError("message")
+// Use:
+//	return d.wrapError(mmdberrors.NewInvalidDatabaseError("message"))
+//
+// This adds zero overhead when no error occurs, but provides rich context
+// when errors do happen.
diff --git a/internal/decoder/error_context_test.go b/internal/decoder/error_context_test.go
new file mode 100644
index 0000000..ecce561
--- /dev/null
+++ b/internal/decoder/error_context_test.go
@@ -0,0 +1,288 @@
+package decoder
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+
+	"github.com/oschwald/maxminddb-golang/v2/internal/mmdberrors"
+)
+
+func TestWrapError_ZeroAllocationHappyPath(t *testing.T) {
+	buffer := []byte{0x44, 't', 'e', 's', 't'} // String "test"
+	dd := NewDataDecoder(buffer)
+	decoder := NewDecoder(dd, 0)
+
+	// Wrapping a nil error should return nil without allocating.
+	err := decoder.wrapError(nil)
+	require.NoError(t, err)
+
+	// The decoder should hold a usable DataDecoder.
+	require.NotNil(t, decoder.d)
+}
+
+func TestWrapError_ContextWhenError(t *testing.T) {
+	buffer := []byte{0x44, 't', 'e', 's', 't'} // String "test"
+	dd := NewDataDecoder(buffer)
+	decoder := NewDecoder(dd, 0)
+
+	// Simulate an error with context.
+	originalErr := mmdberrors.NewInvalidDatabaseError("test error")
+	wrappedErr := decoder.wrapError(originalErr)
+
+	require.Error(t, wrappedErr)
+
+	// Should be a ContextualError.
+	var contextErr mmdberrors.ContextualError
+	require.ErrorAs(t, wrappedErr, &contextErr)
+
+	// Should have offset information.
+	require.Equal(t, uint(0), contextErr.Offset)
+	require.Equal(t, originalErr, contextErr.Err)
+}
+
+func TestPathBuilder(t *testing.T) {
+	builder := mmdberrors.NewPathBuilder()
+
+	// Test basic path building.
+	require.Equal(t, "/", builder.Build())
+
+	builder.PushMap("city")
+	require.Equal(t, "/city", builder.Build())
+
+	builder.PushMap("names")
+	require.Equal(t, "/city/names", builder.Build())
+
+	builder.PushSlice(0)
+	require.Equal(t, "/city/names/0", builder.Build())
+
+	// Test pop.
+	builder.Pop()
+	require.Equal(t, "/city/names", builder.Build())
+
+	// Test reset.
+	builder.Reset()
+	require.Equal(t, "/", builder.Build())
+}
+
+// Benchmark to verify zero allocation on the happy path.
+func BenchmarkWrapError_HappyPath(b *testing.B) {
+	buffer := []byte{0x44, 't', 'e', 's', 't'} // String "test"
+	dd := NewDataDecoder(buffer)
+	decoder := NewDecoder(dd, 0)
+
+	b.ResetTimer()
+	b.ReportAllocs()
+
+	for range b.N {
+		err := decoder.wrapError(nil)
+		if err != nil {
+			b.Fatal("unexpected error")
+		}
+	}
+}
+
+// Benchmark to show that allocation only occurs on the error path.
+func BenchmarkWrapError_ErrorPath(b *testing.B) { + buffer := []byte{0x44, 't', 'e', 's', 't'} // String "test" + dd := NewDataDecoder(buffer) + decoder := NewDecoder(dd, 0) + + originalErr := mmdberrors.NewInvalidDatabaseError("test error") + + b.ResetTimer() + b.ReportAllocs() + + for range b.N { + err := decoder.wrapError(originalErr) + if err == nil { + b.Fatal("expected error") + } + } +} + +// Example showing the API in action. +func ExampleContextualError() { + // This would be internal to the decoder, shown for illustration + builder := mmdberrors.NewPathBuilder() + builder.PushMap("city") + builder.PushMap("names") + builder.PushMap("en") + + // Simulate an error with context + originalErr := mmdberrors.NewInvalidDatabaseError("string too long") + + contextTracker := &errorContext{path: builder} + wrappedErr := mmdberrors.WrapWithContext(originalErr, 1234, contextTracker) + + fmt.Println(wrappedErr.Error()) + // Output: at offset 1234, path /city/names/en: string too long +} + +func TestContextualErrorIntegration(t *testing.T) { + t.Run("InvalidStringLength", func(t *testing.T) { + // String claims size 4 but buffer only has 3 bytes total + buffer := []byte{0x44, 't', 'e', 's'} + + // Test ReflectionDecoder + rd := New(buffer) + var result string + err := rd.Decode(0, &result) + require.Error(t, err) + + var contextErr mmdberrors.ContextualError + require.ErrorAs(t, err, &contextErr) + require.Equal(t, uint(0), contextErr.Offset) + require.Contains(t, contextErr.Error(), "offset 0") + + // Test new Decoder API + dd := NewDataDecoder(buffer) + decoder := NewDecoder(dd, 0) + _, err = decoder.ReadString() + require.Error(t, err) + + require.ErrorAs(t, err, &contextErr) + require.Equal(t, uint(0), contextErr.Offset) + require.Contains(t, contextErr.Error(), "offset 0") + }) + + t.Run("NestedMapWithPath", func(t *testing.T) { + // Map with nested structure that has an error deep inside + // Map { "key": invalid_string } + buffer := []byte{ + 0xe1, // Map with 1 item + 0x43, 'k', 'e', 'y', // Key "key" (3 bytes) + 0x44, 't', 'e', // Invalid string (claims size 4, only has 2 bytes) + } + + // Test ReflectionDecoder with map decoding + rd := New(buffer) + var result map[string]string + err := rd.Decode(0, &result) + require.Error(t, err) + + // Should get a wrapped error with path information + var contextErr mmdberrors.ContextualError + require.ErrorAs(t, err, &contextErr) + require.Equal(t, "/key", contextErr.Path) + require.Contains(t, contextErr.Error(), "path /key") + + // Test new Decoder API - no automatic path tracking + dd := NewDataDecoder(buffer) + decoder := NewDecoder(dd, 0) + mapIter, _, err := decoder.ReadMap() + require.NoError(t, err, "ReadMap failed") + + var mapErr error + for _, iterErr := range mapIter { + if iterErr != nil { + mapErr = iterErr + break + } + + // Try to read the value (this should fail) + _, mapErr = decoder.ReadString() + if mapErr != nil { + break + } + } + + require.Error(t, mapErr) + require.ErrorAs(t, mapErr, &contextErr) + // New API should have offset but no path + require.Contains(t, contextErr.Error(), "offset") + require.Empty(t, contextErr.Path) + }) + + t.Run("SliceIndexInPath", func(t *testing.T) { + // Create nested map-slice-map structure: { "list": [{"name": invalid_string}] } + // This will test path like /list/0/name + buffer := []byte{ + 0xe1, // Map with 1 item + 0x44, 'l', 'i', 's', 't', // Key "list" (4 bytes) + 0x01, 0x04, // Array with 1 item (extended type: type=4 (slice), count=1) + 0xe1, // Map with 1 item (array element) + 
0x44, 'n', 'a', 'm', 'e', // Key "name" (4 bytes) + 0x44, 't', 'e', // Invalid string (claims size 4, only has 2 bytes) + } + + // Test ReflectionDecoder with slice index in path + rd := New(buffer) + var result map[string][]map[string]string + err := rd.Decode(0, &result) + require.Error(t, err) + + // Debug: print the actual error and path + t.Logf("Error: %v", err) + + // Should get a wrapped error with slice index in path + var contextErr mmdberrors.ContextualError + require.ErrorAs(t, err, &contextErr) + t.Logf("Path: %s", contextErr.Path) + + // Verify we get the exact path with correct order + require.Equal(t, "/list/0/name", contextErr.Path) + require.Contains(t, contextErr.Error(), "path /list/0/name") + require.Contains(t, contextErr.Error(), "offset") + + // Test new Decoder API - manual iteration, no automatic path tracking + dd := NewDataDecoder(buffer) + decoder := NewDecoder(dd, 0) + + // Navigate through the nested structure manually + mapIter, _, err := decoder.ReadMap() + require.NoError(t, err, "ReadMap failed") + var mapErr error + + for key, iterErr := range mapIter { + if iterErr != nil { + mapErr = iterErr + break + } + require.Equal(t, "list", string(key)) + + // Read the array + sliceIter, _, err := decoder.ReadSlice() + require.NoError(t, err, "ReadSlice failed") + sliceIndex := 0 + for sliceIterErr := range sliceIter { + if sliceIterErr != nil { + mapErr = sliceIterErr + break + } + require.Equal(t, 0, sliceIndex) // Should be first element + + // Read the nested map (array element) + innerMapIter, _, err := decoder.ReadMap() + require.NoError(t, err, "ReadMap failed") + for innerKey, innerIterErr := range innerMapIter { + if innerIterErr != nil { + mapErr = innerIterErr + break + } + require.Equal(t, "name", string(innerKey)) + + // Try to read the invalid string (this should fail) + _, mapErr = decoder.ReadString() + if mapErr != nil { + break + } + } + if mapErr != nil { + break + } + sliceIndex++ + } + if mapErr != nil { + break + } + } + + require.Error(t, mapErr) + require.ErrorAs(t, mapErr, &contextErr) + // New API should have offset but no path (since it's manual iteration) + require.Contains(t, contextErr.Error(), "offset") + require.Empty(t, contextErr.Path) + }) +} diff --git a/internal/decoder/example_kind_test.go b/internal/decoder/example_kind_test.go new file mode 100644 index 0000000..5c2dd1d --- /dev/null +++ b/internal/decoder/example_kind_test.go @@ -0,0 +1,59 @@ +package decoder + +import ( + "fmt" +) + +// ExampleKind_String demonstrates human-readable Kind names. +func ExampleKind_String() { + kinds := []Kind{KindString, KindMap, KindSlice, KindUint32, KindBool} + + for _, k := range kinds { + fmt.Printf("%s\n", k.String()) + } + + // Output: + // String + // Map + // Slice + // Uint32 + // Bool +} + +// ExampleKind_IsContainer demonstrates container type detection. +func ExampleKind_IsContainer() { + kinds := []Kind{KindString, KindMap, KindSlice, KindUint32} + + for _, k := range kinds { + if k.IsContainer() { + fmt.Printf("%s is a container type\n", k.String()) + } else { + fmt.Printf("%s is not a container type\n", k.String()) + } + } + + // Output: + // String is not a container type + // Map is a container type + // Slice is a container type + // Uint32 is not a container type +} + +// ExampleKind_IsScalar demonstrates scalar type detection. 
+func ExampleKind_IsScalar() {
+	kinds := []Kind{KindString, KindMap, KindUint32, KindPointer}
+
+	for _, k := range kinds {
+		if k.IsScalar() {
+			fmt.Printf("%s is a scalar value\n", k.String())
+		} else {
+			fmt.Printf("%s is not a scalar value\n", k.String())
+		}
+	}
+
+	// Output:
+	// String is a scalar value
+	// Map is not a scalar value
+	// Uint32 is a scalar value
+	// Pointer is not a scalar value
+}
diff --git a/internal/decoder/field_precedence_test.go b/internal/decoder/field_precedence_test.go
new file mode 100644
index 0000000..22519dd
--- /dev/null
+++ b/internal/decoder/field_precedence_test.go
@@ -0,0 +1,134 @@
+package decoder
+
+import (
+	"encoding/hex"
+	"reflect"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// TestFieldPrecedenceRules tests json/v2 style field precedence behavior.
+func TestFieldPrecedenceRules(t *testing.T) {
+	// Test data: {"en": "Foo"}
+	testData := "e142656e43466f6f"
+	testBytes, err := hex.DecodeString(testData)
+	require.NoError(t, err)
+
+	t.Run("DirectFieldWinsOverEmbedded", func(t *testing.T) {
+		type Embedded struct {
+			En string `maxminddb:"en"`
+		}
+		target := &struct {
+			Embedded
+
+			En string `maxminddb:"en"` // The direct field should win.
+		}{}
+
+		decoder := New(testBytes)
+		err := decoder.Decode(0, target)
+		require.NoError(t, err)
+
+		assert.Equal(t, "Foo", target.En, "Direct field should be set")
+		assert.Empty(t, target.Embedded.En, "Embedded field should not be set due to precedence")
+	})
+
+	t.Run("TaggedFieldWinsOverUntagged", func(t *testing.T) {
+		type Untagged struct {
+			En string // Untagged field.
+		}
+		target := &struct {
+			Untagged
+
+			En string `maxminddb:"en"` // The tagged field should win.
+		}{}
+
+		decoder := New(testBytes)
+		err := decoder.Decode(0, target)
+		require.NoError(t, err)
+
+		assert.Equal(t, "Foo", target.En, "Tagged field should be set")
+		assert.Empty(t, target.Untagged.En, "Untagged field should not be set")
+	})
+
+	t.Run("ShallowFieldWinsOverDeep", func(t *testing.T) {
+		type DeepNested struct {
+			En string `maxminddb:"en"` // Deeper field.
+		}
+		type MiddleNested struct {
+			DeepNested
+		}
+		target := &struct {
+			MiddleNested
+
+			En string `maxminddb:"en"` // The shallow field should win.
+		}{}
+
+		decoder := New(testBytes)
+		err := decoder.Decode(0, target)
+		require.NoError(t, err)
+
+		assert.Equal(t, "Foo", target.En, "Shallow field should be set")
+		assert.Empty(t, target.DeepNested.En, "Deep field should not be set due to precedence")
+	})
+}
+
+// TestEmbeddedPointerSupport tests support for embedded pointer types.
+func TestEmbeddedPointerSupport(t *testing.T) {
+	// Test data: {"data": "test"}
+	testData := "e144646174614474657374"
+	testBytes, err := hex.DecodeString(testData)
+	require.NoError(t, err)
+
+	type EmbeddedPointer struct {
+		Data string `maxminddb:"data"`
+	}
+
+	target := &struct {
+		*EmbeddedPointer
+
+		Other string `maxminddb:"other"`
+	}{}
+
+	decoder := New(testBytes)
+	err = decoder.Decode(0, target)
+	require.NoError(t, err)
+
+	// Test embedded pointer field access; this caused a nil pointer
+	// dereference before the fix.
+	require.NotNil(t, target.EmbeddedPointer, "Embedded pointer should be initialized")
+	assert.Equal(t, "test", target.Data)
+}
+
+// TestFieldCaching tests that the field caching mechanism works with the new
+// precedence rules.
+func TestFieldCaching(t *testing.T) { + type Embedded struct { + Field1 string `maxminddb:"field1"` + } + + type TestStruct struct { + Embedded + + Field2 int `maxminddb:"field2"` + Field3 bool `maxminddb:"field3"` + } + + // Test that multiple instances use cached fields + s1 := TestStruct{} + s2 := TestStruct{} + + fields1 := cachedFields(reflect.ValueOf(s1)) + fields2 := cachedFields(reflect.ValueOf(s2)) + + // Should be the same cached instance + assert.Same(t, fields1, fields2, "Same struct type should use cached fields") + + // Verify field mapping includes embedded fields + expectedFieldNames := []string{"field1", "field2", "field3"} + + assert.Len(t, fields1.namedFields, 3, "Should have 3 named fields") + for _, name := range expectedFieldNames { + assert.Contains(t, fields1.namedFields, name, "Should contain field: "+name) + assert.NotNil(t, fields1.namedFields[name], "Field info should not be nil: "+name) + } +} diff --git a/internal/decoder/kind_test.go b/internal/decoder/kind_test.go new file mode 100644 index 0000000..b8d2735 --- /dev/null +++ b/internal/decoder/kind_test.go @@ -0,0 +1,130 @@ +package decoder + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestKind_String(t *testing.T) { + tests := []struct { + kind Kind + expected string + }{ + {KindExtended, "Extended"}, + {KindPointer, "Pointer"}, + {KindString, "String"}, + {KindFloat64, "Float64"}, + {KindBytes, "Bytes"}, + {KindUint16, "Uint16"}, + {KindUint32, "Uint32"}, + {KindMap, "Map"}, + {KindInt32, "Int32"}, + {KindUint64, "Uint64"}, + {KindUint128, "Uint128"}, + {KindSlice, "Slice"}, + {KindContainer, "Container"}, + {KindEndMarker, "EndMarker"}, + {KindBool, "Bool"}, + {KindFloat32, "Float32"}, + {Kind(999), "Unknown(999)"}, // Test unknown kind + } + + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + result := tt.kind.String() + require.Equal(t, tt.expected, result) + }) + } +} + +func TestKind_IsContainer(t *testing.T) { + tests := []struct { + kind Kind + expected bool + name string + }{ + {KindMap, true, "Map is container"}, + {KindSlice, true, "Slice is container"}, + {KindString, false, "String is not container"}, + {KindUint32, false, "Uint32 is not container"}, + {KindBool, false, "Bool is not container"}, + {KindPointer, false, "Pointer is not container"}, + {KindExtended, false, "Extended is not container"}, + { + KindContainer, + false, + "Container is not container", + }, // Container is special, not a data container + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.kind.IsContainer() + require.Equal(t, tt.expected, result) + }) + } +} + +func TestKind_IsScalar(t *testing.T) { + tests := []struct { + kind Kind + expected bool + name string + }{ + {KindString, true, "String is scalar"}, + {KindFloat64, true, "Float64 is scalar"}, + {KindBytes, true, "Bytes is scalar"}, + {KindUint16, true, "Uint16 is scalar"}, + {KindUint32, true, "Uint32 is scalar"}, + {KindInt32, true, "Int32 is scalar"}, + {KindUint64, true, "Uint64 is scalar"}, + {KindUint128, true, "Uint128 is scalar"}, + {KindBool, true, "Bool is scalar"}, + {KindFloat32, true, "Float32 is scalar"}, + {KindMap, false, "Map is not scalar"}, + {KindSlice, false, "Slice is not scalar"}, + {KindPointer, false, "Pointer is not scalar"}, + {KindExtended, false, "Extended is not scalar"}, + {KindContainer, false, "Container is not scalar"}, + {KindEndMarker, false, "EndMarker is not scalar"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { 
+ result := tt.kind.IsScalar() + require.Equal(t, tt.expected, result) + }) + } +} + +func TestKind_Classification(t *testing.T) { + // Test that IsContainer and IsScalar are mutually exclusive for data types + for k := KindExtended; k <= KindFloat32; k++ { + isContainer := k.IsContainer() + isScalar := k.IsScalar() + + // For actual data types (not meta types), they should be either container or scalar + switch k { + case KindMap, KindSlice: + require.True(t, isContainer, "Kind %s should be container", k.String()) + require.False(t, isScalar, "Kind %s should not be scalar", k.String()) + case KindString, + KindFloat64, + KindBytes, + KindUint16, + KindUint32, + KindInt32, + KindUint64, + KindUint128, + KindBool, + KindFloat32: + require.True(t, isScalar, "Kind %s should be scalar", k.String()) + require.False(t, isContainer, "Kind %s should not be container", k.String()) + default: + // Meta types like Extended, Pointer, Container, EndMarker are neither + require.False(t, isContainer, "Meta kind %s should not be container", k.String()) + require.False(t, isScalar, "Meta kind %s should not be scalar", k.String()) + } + } +} diff --git a/internal/decoder/nested_unmarshaler_test.go b/internal/decoder/nested_unmarshaler_test.go new file mode 100644 index 0000000..a4892b9 --- /dev/null +++ b/internal/decoder/nested_unmarshaler_test.go @@ -0,0 +1,222 @@ +package decoder + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// Inner type with UnmarshalMaxMindDB. +type testInnerNested struct { + Value string + custom bool // track if custom unmarshaler was called +} + +func (i *testInnerNested) UnmarshalMaxMindDB(d *Decoder) error { + i.custom = true + str, err := d.ReadString() + if err != nil { + return err + } + i.Value = "custom:" + str + return nil +} + +// TestNestedUnmarshaler tests that UnmarshalMaxMindDB is called for nested struct fields. +func TestNestedUnmarshaler(t *testing.T) { + // Outer type without UnmarshalMaxMindDB + type Outer struct { + Field testInnerNested + Name string + } + + // Create test data: a map with "Field" -> "test" and "Name" -> "example" + data := []byte{ + // Map with 2 items + 0xe2, + // Key "Field" + 0x45, 'F', 'i', 'e', 'l', 'd', + // Value "test" (string) + 0x44, 't', 'e', 's', 't', + // Key "Name" + 0x44, 'N', 'a', 'm', 'e', + // Value "example" (string) + 0x47, 'e', 'x', 'a', 'm', 'p', 'l', 'e', + } + + t.Run("nested field with UnmarshalMaxMindDB", func(t *testing.T) { + d := New(data) + var result Outer + + err := d.Decode(0, &result) + require.NoError(t, err) + + // Check that custom unmarshaler WAS called for nested field + require.True( + t, + result.Field.custom, + "Custom unmarshaler should be called for nested fields", + ) + require.Equal(t, "custom:test", result.Field.Value) + require.Equal(t, "example", result.Name) + }) +} + +// testInnerPointer with UnmarshalMaxMindDB for pointer test. +type testInnerPointer struct { + Value string + custom bool +} + +func (i *testInnerPointer) UnmarshalMaxMindDB(d *Decoder) error { + i.custom = true + str, err := d.ReadString() + if err != nil { + return err + } + i.Value = "ptr:" + str + return nil +} + +// TestNestedUnmarshalerPointer tests UnmarshalMaxMindDB with pointer fields. 
+func TestNestedUnmarshalerPointer(t *testing.T) { + type Outer struct { + Field *testInnerPointer + Name string + } + + // Test data + data := []byte{ + // Map with 2 items + 0xe2, + // Key "Field" + 0x45, 'F', 'i', 'e', 'l', 'd', + // Value "test" + 0x44, 't', 'e', 's', 't', + // Key "Name" + 0x44, 'N', 'a', 'm', 'e', + // Value "example" + 0x47, 'e', 'x', 'a', 'm', 'p', 'l', 'e', + } + + t.Run("pointer field with UnmarshalMaxMindDB", func(t *testing.T) { + d := New(data) + var result Outer + err := d.Decode(0, &result) + require.NoError(t, err) + + // The pointer should be created and unmarshaled with custom unmarshaler + require.NotNil(t, result.Field) + require.True( + t, + result.Field.custom, + "Custom unmarshaler should be called for pointer fields", + ) + require.Equal(t, "ptr:test", result.Field.Value) + require.Equal(t, "example", result.Name) + }) +} + +// testItem with UnmarshalMaxMindDB for slice test. +type testItem struct { + ID int + custom bool +} + +func (item *testItem) UnmarshalMaxMindDB(d *Decoder) error { + item.custom = true + id, err := d.ReadUInt32() + if err != nil { + return err + } + item.ID = int(id) * 2 + return nil +} + +// TestNestedUnmarshalerInSlice tests UnmarshalMaxMindDB for slice elements. +func TestNestedUnmarshalerInSlice(t *testing.T) { + type Container struct { + Items []testItem + } + + // Test data: a map with "Items" -> [1, 2, 3] + data := []byte{ + // Map with 1 item (KindMap=7 << 5 | size=1) + 0xe1, + // Key "Items" (KindString=2 << 5 | size=5) + 0x45, 'I', 't', 'e', 'm', 's', + // Slice with 3 items - KindSlice=11, which is > 7, so we need extended type + // Extended type: ctrl_byte = (KindExtended << 5) | size = (0 << 5) | 3 = 0x03 + // Next byte: KindSlice - 7 = 11 - 7 = 4 + 0x03, 0x04, + // Value 1 (KindUint32=6 << 5 | size=1) + 0xc1, 0x01, + // Value 2 (KindUint32=6 << 5 | size=1) + 0xc1, 0x02, + // Value 3 (KindUint32=6 << 5 | size=1) + 0xc1, 0x03, + } + + t.Run("slice elements with UnmarshalMaxMindDB", func(t *testing.T) { + d := New(data) + var result Container + err := d.Decode(0, &result) + require.NoError(t, err) + + require.Len(t, result.Items, 3) + // With custom unmarshaler, values should be doubled + require.True( + t, + result.Items[0].custom, + "Custom unmarshaler should be called for slice elements", + ) + require.Equal(t, 2, result.Items[0].ID) // 1 * 2 + require.Equal(t, 4, result.Items[1].ID) // 2 * 2 + require.Equal(t, 6, result.Items[2].ID) // 3 * 2 + }) +} + +// testValue with UnmarshalMaxMindDB for map test. +type testValue struct { + Data string + custom bool +} + +func (v *testValue) UnmarshalMaxMindDB(d *Decoder) error { + v.custom = true + str, err := d.ReadString() + if err != nil { + return err + } + v.Data = "map:" + str + return nil +} + +// TestNestedUnmarshalerInMap tests UnmarshalMaxMindDB for map values. 
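+// Although testValue implements Unmarshaler on a pointer receiver, map
+// values are decoded into an addressable temporary before being stored
+// (see decodeMap in reflection.go), so the custom unmarshaler still runs.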
+func TestNestedUnmarshalerInMap(t *testing.T) { + // Test data: {"key1": "value1", "key2": "value2"} + data := []byte{ + // Map with 2 items + 0xe2, + // Key "key1" + 0x44, 'k', 'e', 'y', '1', + // Value "value1" + 0x46, 'v', 'a', 'l', 'u', 'e', '1', + // Key "key2" + 0x44, 'k', 'e', 'y', '2', + // Value "value2" + 0x46, 'v', 'a', 'l', 'u', 'e', '2', + } + + t.Run("map values with UnmarshalMaxMindDB", func(t *testing.T) { + d := New(data) + var result map[string]testValue + err := d.Decode(0, &result) + require.NoError(t, err) + + require.Len(t, result, 2) + require.True(t, result["key1"].custom, "Custom unmarshaler should be called for map values") + require.Equal(t, "map:value1", result["key1"].Data) + require.Equal(t, "map:value2", result["key2"].Data) + }) +} diff --git a/internal/decoder/performance_test.go b/internal/decoder/performance_test.go new file mode 100644 index 0000000..4285ec7 --- /dev/null +++ b/internal/decoder/performance_test.go @@ -0,0 +1,99 @@ +package decoder + +import ( + "encoding/hex" + "reflect" + "testing" +) + +const testDataHex = "e142656e43466f6f" // Map with: "en"->"Foo" + +// BenchmarkStructDecoding tests the performance of struct decoding +// with the new optimized field access patterns. +func BenchmarkStructDecoding(b *testing.B) { + // Create test data from field_precedence_test.go + mmdbHex := testDataHex + + testBytes, err := hex.DecodeString(mmdbHex) + if err != nil { + b.Fatalf("Failed to decode hex: %v", err) + } + decoder := New(testBytes) + + // Test struct that exercises field access patterns + type TestStruct struct { + En string `maxminddb:"en"` // Simple field + } + + b.ResetTimer() + + for range b.N { + var result TestStruct + err := decoder.Decode(0, &result) + if err != nil { + b.Fatalf("Decode failed: %v", err) + } + } +} + +// BenchmarkSimpleDecoding tests basic decoding performance. +func BenchmarkSimpleDecoding(b *testing.B) { + // Simple test data - same as struct decoding + mmdbHex := testDataHex + + testBytes, err := hex.DecodeString(mmdbHex) + if err != nil { + b.Fatalf("Failed to decode hex: %v", err) + } + decoder := New(testBytes) + + type TestStruct struct { + En string `maxminddb:"en"` + } + + b.ResetTimer() + + for range b.N { + var result TestStruct + err := decoder.Decode(0, &result) + if err != nil { + b.Fatalf("Decode failed: %v", err) + } + } +} + +// BenchmarkFieldLookup tests the performance of field lookup with +// the optimized field maps. 
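+// The benchmarks in this file can be run with the standard tooling,
+// for example (flags illustrative):
+//
+//	go test -bench=. -benchmem ./internal/decoder/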
+func BenchmarkFieldLookup(b *testing.B) {
+ // Create a struct with many fields to test map performance
+ type LargeStruct struct {
+ Field01 string `maxminddb:"f01"`
+ Field02 string `maxminddb:"f02"`
+ Field03 string `maxminddb:"f03"`
+ Field04 string `maxminddb:"f04"`
+ Field05 string `maxminddb:"f05"`
+ Field06 string `maxminddb:"f06"`
+ Field07 string `maxminddb:"f07"`
+ Field08 string `maxminddb:"f08"`
+ Field09 string `maxminddb:"f09"`
+ Field10 string `maxminddb:"f10"`
+ }
+
+ // Build the field cache
+ var testStruct LargeStruct
+ fields := cachedFields(reflect.ValueOf(testStruct))
+
+ fieldNames := []string{"f01", "f02", "f03", "f04", "f05", "f06", "f07", "f08", "f09", "f10"}
+
+ b.ResetTimer()
+
+ for range b.N {
+ // Test field lookup performance
+ for _, name := range fieldNames {
+ _, exists := fields.namedFields[name]
+ if !exists {
+ b.Fatalf("Field %s not found", name)
+ }
+ }
+ }
+} diff --git a/internal/decoder/reflection.go b/internal/decoder/reflection.go new file mode 100644 index 0000000..0be74f2 --- /dev/null +++ b/internal/decoder/reflection.go @@ -0,0 +1,1195 @@ +// Package decoder decodes values in the data section.
+package decoder
+
+import (
+ "errors"
+ "fmt"
+ "math/big"
+ "reflect"
+ "sync"
+ "unicode/utf8"
+
+ "github.com/oschwald/maxminddb-golang/v2/internal/mmdberrors"
+)
+
+// Unmarshaler is implemented by types that can unmarshal MaxMind DB data.
+// This is used internally for reflection-based decoding.
+type Unmarshaler interface {
+ UnmarshalMaxMindDB(d *Decoder) error
+}
+
+// ReflectionDecoder is a decoder for the MMDB data section.
+type ReflectionDecoder struct {
+ DataDecoder
+}
+
+// New creates a [ReflectionDecoder].
+func New(buffer []byte) ReflectionDecoder {
+ return ReflectionDecoder{
+ DataDecoder: NewDataDecoder(buffer),
+ }
+}
+
+// Decode decodes the data value at offset and stores it in the value
+// pointed at by v.
+func (d *ReflectionDecoder) Decode(offset uint, v any) error {
+ // Check if the type implements Unmarshaler interface without reflection
+ if unmarshaler, ok := v.(Unmarshaler); ok {
+ decoder := NewDecoder(d.DataDecoder, offset)
+ return unmarshaler.UnmarshalMaxMindDB(decoder)
+ }
+
+ rv := reflect.ValueOf(v)
+ if rv.Kind() != reflect.Ptr || rv.IsNil() {
+ return errors.New("result param must be a pointer")
+ }
+
+ _, err := d.decode(offset, rv, 0)
+ if err == nil {
+ return nil
+ }
+
+ // Check if error already has context (including path), if so just add offset if missing
+ var contextErr mmdberrors.ContextualError
+ if errors.As(err, &contextErr) {
+ // If the outermost error already has offset and path info, return as-is
+ if contextErr.Offset != 0 || contextErr.Path != "" {
+ return err
+ }
+ // Otherwise, just add offset to root
+ return mmdberrors.WrapWithContext(contextErr.Err, offset, nil)
+ }
+
+ // Plain error, add offset
+ return mmdberrors.WrapWithContext(err, offset, nil)
+}
+
+// DecodePath decodes the data value at offset and stores the value associated
+// with the path in the value pointed at by v.
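+// If the path does not exist in the data, v is left unchanged and a nil
+// error is returned. A minimal usage sketch (the offset would normally
+// come from a tree lookup; the keys are illustrative):
+//
+//	var code string
+//	err := d.DecodePath(offset, []any{"country", "iso_code"}, &code)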
+func (d *ReflectionDecoder) DecodePath( + offset uint, + path []any, + v any, +) error { + result := reflect.ValueOf(v) + if result.Kind() != reflect.Ptr || result.IsNil() { + return errors.New("result param must be a pointer") + } + +PATH: + for i, v := range path { + var ( + typeNum Kind + size uint + err error + ) + typeNum, size, offset, err = d.decodeCtrlData(offset) + if err != nil { + return err + } + + if typeNum == KindPointer { + pointer, _, err := d.decodePointer(size, offset) + if err != nil { + return err + } + + typeNum, size, offset, err = d.decodeCtrlData(pointer) + if err != nil { + return err + } + } + + switch v := v.(type) { + case string: + // We are expecting a map + if typeNum != KindMap { + return fmt.Errorf("expected a map for %s but found %s", v, typeNum.String()) + } + for range size { + var key []byte + key, offset, err = d.decodeKey(offset) + if err != nil { + return err + } + if string(key) == v { + continue PATH + } + offset, err = d.nextValueOffset(offset, 1) + if err != nil { + return err + } + } + // Not found. Maybe return a boolean? + return nil + case int: + // We are expecting an array + if typeNum != KindSlice { + return fmt.Errorf("expected a slice for %d but found %s", v, typeNum.String()) + } + var i uint + if v < 0 { + if size < uint(-v) { + // Slice is smaller than negative index, not found + return nil + } + i = size - uint(-v) + } else { + if size <= uint(v) { + // Slice is smaller than index, not found + return nil + } + i = uint(v) + } + offset, err = d.nextValueOffset(offset, i) + if err != nil { + return err + } + default: + return fmt.Errorf("unexpected type for %d value in path, %v: %T", i, v, v) + } + } + _, err := d.decode(offset, result, len(path)) + return d.wrapError(err, offset) +} + +// wrapError wraps an error with context information when an error occurs. +// Zero allocation on happy path - only allocates when error != nil. +func (*ReflectionDecoder) wrapError(err error, offset uint) error { + if err == nil { + return nil + } + // Only wrap with context when an error actually occurs + return mmdberrors.WrapWithContext(err, offset, nil) +} + +// wrapErrorWithMapKey wraps an error with map key context, building path retroactively. +// Zero allocation on happy path - only allocates when error != nil. 
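+// The wrapped error renders via mmdberrors.ContextualError, e.g.
+// "at offset 1234, path /city/names/en: ..." (values illustrative).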
+func (*ReflectionDecoder) wrapErrorWithMapKey(err error, key string) error { + if err == nil { + return nil + } + + // Build path context retroactively by checking if the error already has context + var pathBuilder *mmdberrors.PathBuilder + var contextErr mmdberrors.ContextualError + if errors.As(err, &contextErr) { + // Error already has context, extract existing path and extend it + pathBuilder = mmdberrors.NewPathBuilder() + if contextErr.Path != "" && contextErr.Path != "/" { + // Parse existing path and rebuild + pathBuilder.ParseAndExtend(contextErr.Path) + } + pathBuilder.PrependMap(key) + // Return unwrapped error with extended path, preserving original offset + return mmdberrors.WrapWithContext(contextErr.Err, contextErr.Offset, pathBuilder) + } + + // New error, start building path - extract offset if it's already a contextual error + pathBuilder = mmdberrors.NewPathBuilder() + pathBuilder.PrependMap(key) + + // Try to get existing offset from any wrapped contextual error + var existingOffset uint + var existingErr mmdberrors.ContextualError + if errors.As(err, &existingErr) { + existingOffset = existingErr.Offset + } + + return mmdberrors.WrapWithContext(err, existingOffset, pathBuilder) +} + +// wrapErrorWithSliceIndex wraps an error with slice index context, building path retroactively. +// Zero allocation on happy path - only allocates when error != nil. +func (*ReflectionDecoder) wrapErrorWithSliceIndex(err error, index int) error { + if err == nil { + return nil + } + + // Build path context retroactively by checking if the error already has context + var pathBuilder *mmdberrors.PathBuilder + var contextErr mmdberrors.ContextualError + if errors.As(err, &contextErr) { + // Error already has context, extract existing path and extend it + pathBuilder = mmdberrors.NewPathBuilder() + if contextErr.Path != "" && contextErr.Path != "/" { + // Parse existing path and rebuild + pathBuilder.ParseAndExtend(contextErr.Path) + } + pathBuilder.PrependSlice(index) + // Return unwrapped error with extended path, preserving original offset + return mmdberrors.WrapWithContext(contextErr.Err, contextErr.Offset, pathBuilder) + } + + // New error, start building path - extract offset if it's already a contextual error + pathBuilder = mmdberrors.NewPathBuilder() + pathBuilder.PrependSlice(index) + + // Try to get existing offset from any wrapped contextual error + var existingOffset uint + var existingErr mmdberrors.ContextualError + if errors.As(err, &existingErr) { + existingOffset = existingErr.Offset + } + + return mmdberrors.WrapWithContext(err, existingOffset, pathBuilder) +} + +func (d *ReflectionDecoder) decode(offset uint, result reflect.Value, depth int) (uint, error) { + // Convert to addressableValue and delegate to internal method + // Use fast path for already addressable values to avoid allocation + if result.CanAddr() { + av := addressableValue{Value: result, forcedAddr: false} + return d.decodeValue(offset, av, depth) + } + av := makeAddressable(result) + return d.decodeValue(offset, av, depth) +} + +// decodeValue is the internal decode method that works with addressableValue +// for consistent optimization throughout the decoder. 
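+// Addressable values whose pointer type implements Unmarshaler bypass
+// reflection entirely: the raw Decoder is handed to UnmarshalMaxMindDB
+// (see the CanAddr check below).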
+func (d *ReflectionDecoder) decodeValue( + offset uint, + result addressableValue, + depth int, +) (uint, error) { + if depth > maximumDataStructureDepth { + return 0, mmdberrors.NewInvalidDatabaseError( + "exceeded maximum data structure depth; database is likely corrupt", + ) + } + + // Apply the original indirect logic to handle pointers and interfaces properly + for { + // Load value from interface, but only if the result will be + // usefully addressable. + if result.Kind() == reflect.Interface && !result.IsNil() { + e := result.Elem() + if e.Kind() == reflect.Ptr && !e.IsNil() { + result = addressableValue{e, result.forcedAddr} + continue + } + } + + if result.Kind() != reflect.Ptr { + break + } + + if result.IsNil() { + result.Set(reflect.New(result.Type().Elem())) + } + + result = addressableValue{ + result.Elem(), + false, + } // dereferenced pointer is always addressable + } + + // Check if the value implements Unmarshaler interface using type assertion + if result.CanAddr() { + if unmarshaler, ok := result.Addr().Interface().(Unmarshaler); ok { + decoder := NewDecoder(d.DataDecoder, offset) + if err := unmarshaler.UnmarshalMaxMindDB(decoder); err != nil { + return 0, err + } + return decoder.getNextOffset() + } + } + + typeNum, size, newOffset, err := d.decodeCtrlData(offset) + if err != nil { + return 0, err + } + + if typeNum != KindPointer && result.Kind() == reflect.Uintptr { + result.Set(reflect.ValueOf(uintptr(offset))) + return d.nextValueOffset(offset, 1) + } + return d.decodeFromType(typeNum, size, newOffset, result, depth+1) +} + +func (d *ReflectionDecoder) decodeFromType( + dtype Kind, + size uint, + offset uint, + result addressableValue, + depth int, +) (uint, error) { + // For these types, size has a special meaning + switch dtype { + case KindBool: + return d.unmarshalBool(size, offset, result) + case KindMap: + return d.unmarshalMap(size, offset, result, depth) + case KindPointer: + return d.unmarshalPointer(size, offset, result, depth) + case KindSlice: + return d.unmarshalSlice(size, offset, result, depth) + case KindBytes: + return d.unmarshalBytes(size, offset, result) + case KindFloat32: + return d.unmarshalFloat32(size, offset, result) + case KindFloat64: + return d.unmarshalFloat64(size, offset, result) + case KindInt32: + return d.unmarshalInt32(size, offset, result) + case KindUint16: + return d.unmarshalUint(size, offset, result, 16) + case KindUint32: + return d.unmarshalUint(size, offset, result, 32) + case KindUint64: + return d.unmarshalUint(size, offset, result, 64) + case KindString: + return d.unmarshalString(size, offset, result) + case KindUint128: + return d.unmarshalUint128(size, offset, result) + default: + return 0, mmdberrors.NewInvalidDatabaseError("unknown type: %d", dtype) + } +} + +func (d *ReflectionDecoder) unmarshalBool( + size, offset uint, + result addressableValue, +) (uint, error) { + value, newOffset, err := d.decodeBool(size, offset) + if err != nil { + return 0, err + } + + switch result.Kind() { + case reflect.Bool: + result.SetBool(value) + return newOffset, nil + case reflect.Interface: + if result.NumMethod() == 0 { + result.Set(reflect.ValueOf(value)) + return newOffset, nil + } + } + return newOffset, mmdberrors.NewUnmarshalTypeError(value, result.Type()) +} + +var sliceType = reflect.TypeOf([]byte{}) + +func (d *ReflectionDecoder) unmarshalBytes( + size, offset uint, + result addressableValue, +) (uint, error) { + value, newOffset, err := d.decodeBytes(size, offset) + if err != nil { + return 0, err + } + + switch 
result.Kind() { + case reflect.Slice: + if result.Type() == sliceType { + result.SetBytes(value) + return newOffset, nil + } + case reflect.Interface: + if result.NumMethod() == 0 { + result.Set(reflect.ValueOf(value)) + return newOffset, nil + } + } + return newOffset, mmdberrors.NewUnmarshalTypeError(value, result.Type()) +} + +func (d *ReflectionDecoder) unmarshalFloat32( + size, offset uint, result addressableValue, +) (uint, error) { + value, newOffset, err := d.decodeFloat32(size, offset) + if err != nil { + return 0, err + } + + switch result.Kind() { + case reflect.Float32, reflect.Float64: + result.SetFloat(float64(value)) + return newOffset, nil + case reflect.Interface: + if result.NumMethod() == 0 { + result.Set(reflect.ValueOf(value)) + return newOffset, nil + } + } + return newOffset, mmdberrors.NewUnmarshalTypeError(value, result.Type()) +} + +func (d *ReflectionDecoder) unmarshalFloat64( + size, offset uint, result addressableValue, +) (uint, error) { + value, newOffset, err := d.decodeFloat64(size, offset) + if err != nil { + return 0, err + } + + switch result.Kind() { + case reflect.Float32, reflect.Float64: + if result.OverflowFloat(value) { + return 0, mmdberrors.NewUnmarshalTypeError(value, result.Type()) + } + result.SetFloat(value) + return newOffset, nil + case reflect.Interface: + if result.NumMethod() == 0 { + result.Set(reflect.ValueOf(value)) + return newOffset, nil + } + } + return newOffset, mmdberrors.NewUnmarshalTypeError(value, result.Type()) +} + +func (d *ReflectionDecoder) unmarshalInt32( + size, offset uint, + result addressableValue, +) (uint, error) { + value, newOffset, err := d.decodeInt32(size, offset) + if err != nil { + return 0, err + } + + switch result.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n := int64(value) + if !result.OverflowInt(n) { + result.SetInt(n) + return newOffset, nil + } + case reflect.Uint, + reflect.Uint8, + reflect.Uint16, + reflect.Uint32, + reflect.Uint64, + reflect.Uintptr: + n := uint64(value) + if !result.OverflowUint(n) { + result.SetUint(n) + return newOffset, nil + } + case reflect.Interface: + if result.NumMethod() == 0 { + result.Set(reflect.ValueOf(value)) + return newOffset, nil + } + } + return newOffset, mmdberrors.NewUnmarshalTypeError(value, result.Type()) +} + +func (d *ReflectionDecoder) unmarshalMap( + size uint, + offset uint, + result addressableValue, + depth int, +) (uint, error) { + switch result.Kind() { + default: + return 0, mmdberrors.NewUnmarshalTypeStrError("map", result.Type()) + case reflect.Struct: + return d.decodeStruct(size, offset, result, depth) + case reflect.Map: + return d.decodeMap(size, offset, result, depth) + case reflect.Interface: + if result.NumMethod() == 0 { + // Create map directly without makeAddressable wrapper + mapVal := reflect.ValueOf(make(map[string]any, size)) + rv := addressableValue{Value: mapVal, forcedAddr: false} + newOffset, err := d.decodeMap(size, offset, rv, depth) + result.Set(rv.Value) + return newOffset, err + } + return 0, mmdberrors.NewUnmarshalTypeStrError("map", result.Type()) + } +} + +func (d *ReflectionDecoder) unmarshalPointer( + size, offset uint, + result addressableValue, + depth int, +) (uint, error) { + pointer, newOffset, err := d.decodePointer(size, offset) + if err != nil { + return 0, err + } + _, err = d.decodeValue(pointer, result, depth) + return newOffset, err +} + +func (d *ReflectionDecoder) unmarshalSlice( + size uint, + offset uint, + result addressableValue, + depth int, +) (uint, 
error) { + switch result.Kind() { + case reflect.Slice: + return d.decodeSlice(size, offset, result, depth) + case reflect.Interface: + if result.NumMethod() == 0 { + a := []any{} + // Create slice directly without makeAddressable wrapper + sliceVal := reflect.ValueOf(&a).Elem() + rv := addressableValue{Value: sliceVal, forcedAddr: false} + newOffset, err := d.decodeSlice(size, offset, rv, depth) + result.Set(rv.Value) + return newOffset, err + } + } + return 0, mmdberrors.NewUnmarshalTypeStrError("array", result.Type()) +} + +func (d *ReflectionDecoder) unmarshalString( + size, offset uint, + result addressableValue, +) (uint, error) { + value, newOffset, err := d.decodeString(size, offset) + if err != nil { + return 0, err + } + + switch result.Kind() { + case reflect.String: + result.SetString(value) + return newOffset, nil + case reflect.Interface: + if result.NumMethod() == 0 { + result.Set(reflect.ValueOf(value)) + return newOffset, nil + } + } + return newOffset, mmdberrors.NewUnmarshalTypeError(value, result.Type()) +} + +func (d *ReflectionDecoder) unmarshalUint( + size, offset uint, + result addressableValue, + uintType uint, +) (uint, error) { + // Use the appropriate DataDecoder method based on uint type + var value uint64 + var newOffset uint + var err error + + switch uintType { + case 16: + v16, off, e := d.decodeUint16(size, offset) + value, newOffset, err = uint64(v16), off, e + case 32: + v32, off, e := d.decodeUint32(size, offset) + value, newOffset, err = uint64(v32), off, e + case 64: + value, newOffset, err = d.decodeUint64(size, offset) + default: + return 0, mmdberrors.NewInvalidDatabaseError( + "unsupported uint type: %d", uintType) + } + + if err != nil { + return 0, err + } + + // Fast path for exact type matches (inspired by json/v2 fast paths) + switch result.Kind() { + case reflect.Uint32: + if uintType == 32 && value <= 0xFFFFFFFF { + result.SetUint(value) + return newOffset, nil + } + case reflect.Uint64: + if uintType == 64 { + result.SetUint(value) + return newOffset, nil + } + case reflect.Uint16: + if uintType == 16 && value <= 0xFFFF { + result.SetUint(value) + return newOffset, nil + } + case reflect.Uint8: + if uintType == 16 && value <= 0xFF { // uint8 often stored as uint16 in MMDB + result.SetUint(value) + return newOffset, nil + } + } + + switch result.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n := int64(value) + if !result.OverflowInt(n) { + result.SetInt(n) + return newOffset, nil + } + case reflect.Uint, + reflect.Uint8, + reflect.Uint16, + reflect.Uint32, + reflect.Uint64, + reflect.Uintptr: + if !result.OverflowUint(value) { + result.SetUint(value) + return newOffset, nil + } + case reflect.Interface: + if result.NumMethod() == 0 { + result.Set(reflect.ValueOf(value)) + return newOffset, nil + } + } + return newOffset, mmdberrors.NewUnmarshalTypeError(value, result.Type()) +} + +var bigIntType = reflect.TypeOf(big.Int{}) + +func (d *ReflectionDecoder) unmarshalUint128( + size, offset uint, result addressableValue, +) (uint, error) { + hi, lo, newOffset, err := d.decodeUint128(size, offset) + if err != nil { + return 0, err + } + + // Convert hi/lo representation to big.Int + value := new(big.Int) + if hi == 0 { + value.SetUint64(lo) + } else { + value.SetUint64(hi) + value.Lsh(value, 64) // Shift high part left by 64 bits + value.Or(value, new(big.Int).SetUint64(lo)) // OR with low part + } + + switch result.Kind() { + case reflect.Struct: + if result.Type() == bigIntType { + 
result.Set(reflect.ValueOf(*value)) + return newOffset, nil + } + case reflect.Interface: + if result.NumMethod() == 0 { + result.Set(reflect.ValueOf(value)) + return newOffset, nil + } + } + return newOffset, mmdberrors.NewUnmarshalTypeError(value, result.Type()) +} + +func (d *ReflectionDecoder) decodeMap( + size uint, + offset uint, + result addressableValue, + depth int, +) (uint, error) { + if result.IsNil() { + result.Set(reflect.MakeMapWithSize(result.Type(), int(size))) + } + + mapType := result.Type() + + // Pre-allocated values for efficient reuse + keyVal := reflect.New(mapType.Key()).Elem() + keyValue := addressableValue{Value: keyVal, forcedAddr: false} + elemType := mapType.Elem() + var elemValue addressableValue + // Pre-allocate element value to reduce allocations + elemVal := reflect.New(elemType).Elem() + elemValue = addressableValue{Value: elemVal, forcedAddr: false} + for range size { + var err error + + // Reuse keyValue by zeroing it + keyValue.SetZero() + offset, err = d.decodeValue(offset, keyValue, depth) + if err != nil { + return 0, err + } + + // Reuse elemValue by zeroing it + elemValue.SetZero() + + offset, err = d.decodeValue(offset, elemValue, depth) + if err != nil { + return 0, d.wrapErrorWithMapKey(err, keyValue.String()) + } + + result.SetMapIndex(keyValue.Value, elemValue.Value) + } + return offset, nil +} + +func (d *ReflectionDecoder) decodeSlice( + size uint, + offset uint, + result addressableValue, + depth int, +) (uint, error) { + result.Set(reflect.MakeSlice(result.Type(), int(size), int(size))) + for i := range size { + var err error + // Use slice element directly to avoid allocation + elemVal := result.Index(int(i)) + elemValue := addressableValue{Value: elemVal, forcedAddr: false} + offset, err = d.decodeValue(offset, elemValue, depth) + if err != nil { + return 0, d.wrapErrorWithSliceIndex(err, int(i)) + } + } + return offset, nil +} + +func (d *ReflectionDecoder) decodeStruct( + size uint, + offset uint, + result addressableValue, + depth int, +) (uint, error) { + fields := cachedFields(result.Value) + + // Single-phase processing: decode only the dominant fields + for range size { + var ( + err error + key []byte + ) + key, offset, err = d.decodeKey(offset) + if err != nil { + return 0, err + } + // The string() does not create a copy due to this compiler + // optimization: https://github.com/golang/go/issues/3512 + fieldInfo, ok := fields.namedFields[string(key)] + if !ok { + offset, err = d.nextValueOffset(offset, 1) + if err != nil { + return 0, err + } + continue + } + + // Use optimized field access with addressable value wrapper + fieldValue := result.fieldByIndex(fieldInfo.index0, fieldInfo.index, true) + if !fieldValue.IsValid() { + // Field access failed, skip this field + offset, err = d.nextValueOffset(offset, 1) + if err != nil { + return 0, err + } + continue + } + + // Fast path for common simple field types + if len(fieldInfo.index) == 0 && fieldInfo.isFastType { + // Try fast decode path for pre-identified simple types + if fastOffset, ok := d.tryFastDecodeTyped(offset, fieldValue, fieldInfo.fieldType); ok { + offset = fastOffset + continue + } + } + + offset, err = d.decodeValue(offset, fieldValue, depth) + if err != nil { + return 0, d.wrapErrorWithMapKey(err, string(key)) + } + } + return offset, nil +} + +type fieldInfo struct { + fieldType reflect.Type + name string + index []int + index0 int + depth int + hasTag bool + isFastType bool +} + +type fieldsType struct { + namedFields map[string]*fieldInfo // Map from field 
name to field info +} + +type queueEntry struct { + typ reflect.Type + index []int // Field index path + depth int // Embedding depth +} + +// getEmbeddedStructType returns the struct type for embedded fields. +// Returns nil if the field is not an embeddable struct type. +func getEmbeddedStructType(fieldType reflect.Type) reflect.Type { + if fieldType.Kind() == reflect.Struct { + return fieldType + } + if fieldType.Kind() == reflect.Ptr && fieldType.Elem().Kind() == reflect.Struct { + return fieldType.Elem() + } + return nil +} + +// handleEmbeddedField processes an embedded struct field and returns true if the field should be skipped. +func handleEmbeddedField( + field reflect.StructField, + hasTag bool, + queue *[]queueEntry, + seen *map[reflect.Type]bool, + fieldIndex []int, + depth int, +) bool { + embeddedType := getEmbeddedStructType(field.Type) + if embeddedType == nil { + return false + } + + // For embedded structs (and pointer to structs), add to queue for further traversal + if !(*seen)[embeddedType] { + *queue = append(*queue, queueEntry{embeddedType, fieldIndex, depth + 1}) + (*seen)[embeddedType] = true + } + + // If embedded struct has no explicit tag, don't add it as a named field + return !hasTag +} + +// validateTag performs basic validation of maxminddb struct tags. +func validateTag(field reflect.StructField, tag string) error { + if tag == "" || tag == "-" { + return nil + } + + // Check for invalid UTF-8 + if !utf8.ValidString(tag) { + return fmt.Errorf("field %s has tag with invalid UTF-8: %q", field.Name, tag) + } + + // Only flag very obvious mistakes - don't be too restrictive + return nil +} + +var fieldsMap sync.Map + +func cachedFields(result reflect.Value) *fieldsType { + resultType := result.Type() + + if fields, ok := fieldsMap.Load(resultType); ok { + return fields.(*fieldsType) + } + + fields := makeStructFields(resultType) + fieldsMap.Store(resultType, fields) + + return fields +} + +// makeStructFields implements json/v2 style field precedence rules. 
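+// For example (illustrative), given
+//
+//	type Embedded struct {
+//		Name string `maxminddb:"name"`
+//	}
+//	type Outer struct {
+//		Embedded
+//		Name string `maxminddb:"name"`
+//	}
+//
+// Outer.Name sits at depth 0 and Embedded.Name at depth 1, so the "name"
+// key resolves to Outer.Name (shallowest depth wins).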
+func makeStructFields(rootType reflect.Type) *fieldsType { + // Breadth-first traversal to collect all fields with depth information + + queue := []queueEntry{{rootType, nil, 0}} + var allFields []fieldInfo + seen := make(map[reflect.Type]bool) + seen[rootType] = true + + // Collect all reachable fields using breadth-first search + for len(queue) > 0 { + entry := queue[0] + queue = queue[1:] + + for i := range entry.typ.NumField() { + field := entry.typ.Field(i) + + // Skip unexported fields (except embedded structs) + if !field.IsExported() && (!field.Anonymous || field.Type.Kind() != reflect.Struct) { + continue + } + + // Build field index path + fieldIndex := make([]int, len(entry.index)+1) + copy(fieldIndex, entry.index) + fieldIndex[len(entry.index)] = i + + // Parse maxminddb tag + fieldName := field.Name + hasTag := false + if tag := field.Tag.Get("maxminddb"); tag != "" { + // Validate tag syntax + if err := validateTag(field, tag); err != nil { + // Log warning but continue processing + // In a real implementation, you might want to use a proper logger + _ = err // For now, just ignore validation errors + } + + if tag == "-" { + continue // Skip ignored fields + } + fieldName = tag + hasTag = true + } + + // Handle embedded structs and embedded pointers to structs + if field.Anonymous && handleEmbeddedField( + field, hasTag, &queue, &seen, fieldIndex, entry.depth, + ) { + continue + } + + // Add field to collection with optimization hints + fieldType := field.Type + isFast := isFastDecodeType(fieldType) + allFields = append(allFields, fieldInfo{ + index: fieldIndex, // Will be reindexed later for optimization + name: fieldName, + hasTag: hasTag, + depth: entry.depth, + fieldType: fieldType, + isFastType: isFast, + }) + } + } + + // Apply precedence rules to resolve field conflicts + // Pre-size the map based on field count for better memory efficiency + namedFields := make(map[string]*fieldInfo, len(allFields)) + fieldsByName := make(map[string][]fieldInfo, len(allFields)) + + // Group fields by name + for _, field := range allFields { + fieldsByName[field.name] = append(fieldsByName[field.name], field) + } + + // Apply precedence rules for each field name + // Store results in a flattened slice to allow pointer references + flatFields := make([]fieldInfo, 0, len(fieldsByName)) + + for name, fields := range fieldsByName { + if len(fields) == 1 { + // No conflict, use the field + flatFields = append(flatFields, fields[0]) + namedFields[name] = &flatFields[len(flatFields)-1] + continue + } + + // Find the dominant field using json/v2 precedence rules: + // 1. Shallowest depth wins + // 2. Among same depth, explicitly tagged field wins + // 3. 
Among same depth with same tag status, first declared wins + + dominant := fields[0] + for i := 1; i < len(fields); i++ { + candidate := fields[i] + + // Shallowest depth wins + if candidate.depth < dominant.depth { + dominant = candidate + continue + } + if candidate.depth > dominant.depth { + continue + } + + // Same depth: explicitly tagged field wins + if candidate.hasTag && !dominant.hasTag { + dominant = candidate + continue + } + if !candidate.hasTag && dominant.hasTag { + continue + } + + // Same depth and tag status: first declared wins (keep current dominant) + } + + flatFields = append(flatFields, dominant) + namedFields[name] = &flatFields[len(flatFields)-1] + } + + fields := &fieldsType{ + namedFields: namedFields, + } + + // Reindex all fields for optimized access + fields.reindex() + + return fields +} + +// reindex optimizes field indices to avoid bounds checks during runtime. +// This follows the json/v2 pattern of splitting the first index from the remainder. +func (fs *fieldsType) reindex() { + for _, field := range fs.namedFields { + if len(field.index) > 0 { + field.index0 = field.index[0] + field.index = field.index[1:] + if len(field.index) == 0 { + field.index = nil // avoid pinning the backing slice + } + } + } +} + +// addressableValue wraps a reflect.Value to optimize field access and +// embedded pointer handling. Based on encoding/json/v2 patterns. +type addressableValue struct { + reflect.Value + + forcedAddr bool +} + +// newAddressableValue creates an addressable value wrapper. +// If the value is not addressable, it wraps it to make it addressable. +func newAddressableValue(v reflect.Value) addressableValue { + if v.CanAddr() { + return addressableValue{Value: v, forcedAddr: false} + } + // Make non-addressable values addressable by boxing them + addressable := reflect.New(v.Type()).Elem() + addressable.Set(v) + return addressableValue{Value: addressable, forcedAddr: true} +} + +// makeAddressable efficiently converts a reflect.Value to addressableValue +// with minimal allocations when possible. +func makeAddressable(v reflect.Value) addressableValue { + // Fast path for already addressable values + if v.CanAddr() { + return addressableValue{Value: v, forcedAddr: false} + } + return newAddressableValue(v) +} + +// isFastDecodeType determines if a field type can use optimized decode paths. +func isFastDecodeType(t reflect.Type) bool { + switch t.Kind() { + case reflect.String, + reflect.Bool, + reflect.Uint16, + reflect.Uint32, + reflect.Uint64, + reflect.Float64: + return true + case reflect.Ptr: + // Pointer to fast types are also fast + return isFastDecodeType(t.Elem()) + default: + return false + } +} + +// fieldByIndex efficiently accesses a field by its index path, +// initializing embedded pointers as needed. +func (av addressableValue) fieldByIndex( + index0 int, + remainingIndex []int, + mayAlloc bool, +) addressableValue { + // First field access (optimized with no bounds check) + av = addressableValue{av.Field(index0), av.forcedAddr} + + // Handle remaining indices if any + if len(remainingIndex) > 0 { + for _, i := range remainingIndex { + av = av.indirect(mayAlloc) + if !av.IsValid() { + return av + } + av = addressableValue{av.Field(i), av.forcedAddr} + } + } + + return av +} + +// indirect handles pointer dereferencing and initialization. 
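+// For example, a nil *Embedded field is allocated here when mayAlloc is
+// true; otherwise a nil pointer yields an invalid value and the caller
+// skips the field.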
+func (av addressableValue) indirect(mayAlloc bool) addressableValue { + if av.Kind() == reflect.Ptr { + if av.IsNil() { + if !mayAlloc || !av.CanSet() { + return addressableValue{} // Return invalid value + } + av.Set(reflect.New(av.Type().Elem())) + } + av = addressableValue{av.Elem(), false} + } + return av +} + +// tryFastDecodeTyped attempts to decode using pre-computed type information. +func (d *ReflectionDecoder) tryFastDecodeTyped( + offset uint, + result addressableValue, + expectedType reflect.Type, +) (uint, bool) { + typeNum, size, newOffset, err := d.decodeCtrlData(offset) + if err != nil { + return 0, false + } + + // Use pre-computed type information for faster matching + switch expectedType.Kind() { + case reflect.String: + if typeNum == KindString { + value, finalOffset, err := d.decodeString(size, newOffset) + if err != nil { + return 0, false + } + result.SetString(value) + return finalOffset, true + } + case reflect.Uint32: + if typeNum == KindUint32 { + value, finalOffset, err := d.decodeUint32(size, newOffset) + if err != nil { + return 0, false + } + result.SetUint(uint64(value)) + return finalOffset, true + } + case reflect.Uint16: + if typeNum == KindUint16 { + value, finalOffset, err := d.decodeUint16(size, newOffset) + if err != nil { + return 0, false + } + result.SetUint(uint64(value)) + return finalOffset, true + } + case reflect.Uint64: + if typeNum == KindUint64 { + value, finalOffset, err := d.decodeUint64(size, newOffset) + if err != nil { + return 0, false + } + result.SetUint(value) + return finalOffset, true + } + case reflect.Bool: + if typeNum == KindBool { + value, finalOffset, err := d.decodeBool(size, newOffset) + if err != nil { + return 0, false + } + result.SetBool(value) + return finalOffset, true + } + case reflect.Float64: + if typeNum == KindFloat64 { + value, finalOffset, err := d.decodeFloat64(size, newOffset) + if err != nil { + return 0, false + } + result.SetFloat(value) + return finalOffset, true + } + case reflect.Ptr: + // Handle pointer to fast types + if result.IsNil() { + result.Set(reflect.New(expectedType.Elem())) + } + return d.tryFastDecodeTyped( + offset, + addressableValue{result.Elem(), false}, + expectedType.Elem(), + ) + } + + return 0, false +} diff --git a/decoder_test.go b/internal/decoder/reflection_test.go similarity index 91% rename from decoder_test.go rename to internal/decoder/reflection_test.go index a7aba68..5244e14 100644 --- a/decoder_test.go +++ b/internal/decoder/reflection_test.go @@ -1,9 +1,10 @@ -package maxminddb +package decoder import ( "encoding/hex" "math/big" "os" + "path/filepath" "reflect" "strings" "testing" @@ -51,18 +52,18 @@ func TestFloat(t *testing.T) { func TestInt32(t *testing.T) { int32s := map[string]any{ - "0001": 0, - "0401ffffffff": -1, - "0101ff": 255, - "0401ffffff01": -255, - "020101f4": 500, - "0401fffffe0c": -500, - "0201ffff": 65535, - "0401ffff0001": -65535, - "0301ffffff": 16777215, - "0401ff000001": -16777215, - "04017fffffff": 2147483647, - "040180000001": -2147483647, + "0001": int32(0), + "0401ffffffff": int32(-1), + "0101ff": int32(255), + "0401ffffff01": int32(-255), + "020101f4": int32(500), + "0401fffffe0c": int32(-500), + "0201ffff": int32(65535), + "0401ffff0001": int32(-65535), + "0301ffffff": int32(16777215), + "0401ff000001": int32(-16777215), + "04017fffffff": int32(2147483647), + "040180000001": int32(-2147483647), } validateDecoding(t, int32s) } @@ -207,7 +208,7 @@ func validateDecoding(t *testing.T, tests map[string]any) { for inputStr, expected := range 
tests { inputBytes, err := hex.DecodeString(inputStr) require.NoError(t, err) - d := decoder{buffer: inputBytes} + d := New(inputBytes) var result any _, err = d.decode(0, reflect.ValueOf(&result), 0) @@ -223,7 +224,7 @@ func validateDecoding(t *testing.T, tests map[string]any) { func TestPointers(t *testing.T) { bytes, err := os.ReadFile(testFile("maps-with-pointers.raw")) require.NoError(t, err) - d := decoder{buffer: bytes} + d := New(bytes) expected := map[uint]map[string]string{ 0: {"long_key": "long_value1"}, @@ -243,3 +244,7 @@ func TestPointers(t *testing.T) { } } } + +func testFile(file string) string { + return filepath.Join("..", "..", "test-data", "test-data", file) +} diff --git a/internal/decoder/string_cache.go b/internal/decoder/string_cache.go new file mode 100644 index 0000000..be331f4 --- /dev/null +++ b/internal/decoder/string_cache.go @@ -0,0 +1,62 @@ +// Package decoder decodes values in the data section. +package decoder + +import ( + "sync" +) + +// cacheEntry represents a cached string with its offset and dedicated mutex. +type cacheEntry struct { + str string + offset uint + mu sync.RWMutex +} + +// stringCache provides bounded string interning with per-entry mutexes for minimal contention. +// This achieves thread safety while avoiding the global lock bottleneck. +type stringCache struct { + entries [512]cacheEntry +} + +// newStringCache creates a new per-entry mutex-based string cache. +func newStringCache() *stringCache { + return &stringCache{} +} + +// internAt returns a canonical string for the data at the given offset and size. +// Uses per-entry RWMutex for fine-grained thread safety with minimal contention. +func (sc *stringCache) internAt(offset, size uint, data []byte) string { + const ( + minCachedLen = 2 // single byte strings not worth caching + maxCachedLen = 100 // reasonable upper bound for geographic strings + ) + + // Skip caching for very short or very long strings + if size < minCachedLen || size > maxCachedLen { + return string(data[offset : offset+size]) + } + + // Use same cache index calculation as original: offset % cacheSize + i := offset % uint(len(sc.entries)) + entry := &sc.entries[i] + + // Fast path: read lock and check + entry.mu.RLock() + if entry.offset == offset && entry.str != "" { + str := entry.str + entry.mu.RUnlock() + return str + } + entry.mu.RUnlock() + + // Cache miss - create new string + str := string(data[offset : offset+size]) + + // Store with write lock on this specific entry + entry.mu.Lock() + entry.offset = offset + entry.str = str + entry.mu.Unlock() + + return str +} diff --git a/internal/decoder/string_cache_test.go b/internal/decoder/string_cache_test.go new file mode 100644 index 0000000..25be1be --- /dev/null +++ b/internal/decoder/string_cache_test.go @@ -0,0 +1,51 @@ +package decoder + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestStringCacheOffsetZero(t *testing.T) { + cache := newStringCache() + data := []byte("hello world, this is test data") + + // Test string at offset 0 + str1 := cache.internAt(0, 5, data) + require.Equal(t, "hello", str1) + + // Second call should hit cache and return same interned string + str2 := cache.internAt(0, 5, data) + require.Equal(t, "hello", str2) + + // Note: Both strings should be identical (cache hit) + // We can't easily test if they're the same object without unsafe, + // but correctness is verified by the equal values +} + +func TestStringCacheVariousOffsets(t *testing.T) { + cache := newStringCache() + data := 
[]byte("abcdefghijklmnopqrstuvwxyz") + + testCases := []struct { + offset uint + size uint + expected string + }{ + {0, 3, "abc"}, + {5, 3, "fgh"}, + {10, 5, "klmno"}, + {23, 3, "xyz"}, + } + + for _, tc := range testCases { + // First call + str1 := cache.internAt(tc.offset, tc.size, data) + require.Equal(t, tc.expected, str1) + + // Second call should hit cache + str2 := cache.internAt(tc.offset, tc.size, data) + require.Equal(t, tc.expected, str2) + // Verify cache hit returns correct value (interning tested via behavior) + } +} diff --git a/internal/decoder/tag_validation_test.go b/internal/decoder/tag_validation_test.go new file mode 100644 index 0000000..c9feaad --- /dev/null +++ b/internal/decoder/tag_validation_test.go @@ -0,0 +1,97 @@ +package decoder + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestValidateTag(t *testing.T) { + tests := []struct { + name string + fieldName string + tag string + expectError bool + description string + }{ + { + name: "ValidTag", + fieldName: "TestField", + tag: "valid_field", + expectError: false, + description: "Valid tag should not error", + }, + { + name: "IgnoreTag", + fieldName: "TestField", + tag: "-", + expectError: false, + description: "Ignore tag should not error", + }, + { + name: "EmptyTag", + fieldName: "TestField", + tag: "", + expectError: false, + description: "Empty tag should not error", + }, + { + name: "ComplexValidTag", + fieldName: "TestField", + tag: "some_complex_field_name_123", + expectError: false, + description: "Complex valid tag should not error", + }, + { + name: "InvalidUTF8", + fieldName: "TestField", + tag: "field\xff\xfe", + expectError: true, + description: "Invalid UTF-8 should error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a mock struct field + field := reflect.StructField{ + Name: tt.fieldName, + Type: reflect.TypeOf(""), + } + + err := validateTag(field, tt.tag) + + if tt.expectError { + require.Error(t, err, tt.description) + assert.Contains(t, err.Error(), tt.fieldName, "Error should mention field name") + } else { + assert.NoError(t, err, tt.description) + } + }) + } +} + +// TestTagValidationIntegration tests that tag validation works during field processing. +func TestTagValidationIntegration(t *testing.T) { + // Test that makeStructFields processes tags without panicking + // even when there are validation errors + + type TestStruct struct { + ValidField string `maxminddb:"valid"` + IgnoredField string `maxminddb:"-"` + NoTagField string + } + + // This should not panic even with invalid tags + structType := reflect.TypeOf(TestStruct{}) + fields := makeStructFields(structType) + + // Verify that valid fields are still processed + assert.Contains(t, fields.namedFields, "valid", "Valid field should be processed") + assert.Contains(t, fields.namedFields, "NoTagField", "Field without tag should use field name") + + // The important thing is that it doesn't crash + assert.NotNil(t, fields.namedFields, "Fields map should be created") +} diff --git a/internal/decoder/verifier.go b/internal/decoder/verifier.go new file mode 100644 index 0000000..2366de8 --- /dev/null +++ b/internal/decoder/verifier.go @@ -0,0 +1,65 @@ +package decoder + +import ( + "reflect" + + "github.com/oschwald/maxminddb-golang/v2/internal/mmdberrors" +) + +// VerifyDataSection verifies the data section against the provided +// offsets from the tree. 
+func (d *ReflectionDecoder) VerifyDataSection(offsets map[uint]bool) error { + pointerCount := len(offsets) + + var offset uint + bufferLen := uint(len(d.buffer)) + for offset < bufferLen { + var data any + rv := reflect.ValueOf(&data) + newOffset, err := d.decode(offset, rv, 0) + if err != nil { + return mmdberrors.NewInvalidDatabaseError( + "received decoding error (%v) at offset of %v", + err, + offset, + ) + } + if newOffset <= offset { + return mmdberrors.NewInvalidDatabaseError( + "data section offset unexpectedly went from %v to %v", + offset, + newOffset, + ) + } + + pointer := offset + + if _, ok := offsets[pointer]; !ok { + return mmdberrors.NewInvalidDatabaseError( + "found data (%v) at %v that the search tree does not point to", + data, + pointer, + ) + } + delete(offsets, pointer) + + offset = newOffset + } + + if offset != bufferLen { + return mmdberrors.NewInvalidDatabaseError( + "unexpected data at the end of the data section (last offset: %v, end: %v)", + offset, + bufferLen, + ) + } + + if len(offsets) != 0 { + return mmdberrors.NewInvalidDatabaseError( + "found %v pointers (of %v) in the search tree that we did not see in the data section", + len(offsets), + pointerCount, + ) + } + return nil +} diff --git a/internal/mmdberrors/context.go b/internal/mmdberrors/context.go new file mode 100644 index 0000000..d24e1d5 --- /dev/null +++ b/internal/mmdberrors/context.go @@ -0,0 +1,136 @@ +package mmdberrors + +import ( + "fmt" + "strconv" + "strings" +) + +// ContextualError provides detailed error context with offset and path information. +// This is only allocated when an error actually occurs, ensuring zero allocation +// on the happy path. +type ContextualError struct { + Err error + Path string + Offset uint +} + +func (e ContextualError) Error() string { + if e.Path != "" { + return fmt.Sprintf("at offset %d, path %s: %v", e.Offset, e.Path, e.Err) + } + return fmt.Sprintf("at offset %d: %v", e.Offset, e.Err) +} + +func (e ContextualError) Unwrap() error { + return e.Err +} + +// ErrorContextTracker is an optional interface that can be used to track +// path context for better error messages. Only used when explicitly enabled +// and only allocates when an error occurs. +type ErrorContextTracker interface { + // BuildPath constructs a path string for the current decoder state. + // This is only called when an error occurs, so allocation is acceptable. + BuildPath() string +} + +// WrapWithContext wraps an error with offset and optional path context. +// This function is designed to have zero allocation on the happy path - +// it only allocates when an error actually occurs. +func WrapWithContext(err error, offset uint, tracker ErrorContextTracker) error { + if err == nil { + return nil // Zero allocation - no error to wrap + } + + // Only allocate when we actually have an error + ctxErr := ContextualError{ + Offset: offset, + Err: err, + } + + // Only build path if tracker is provided (opt-in behavior) + if tracker != nil { + ctxErr.Path = tracker.BuildPath() + } + + return ctxErr +} + +// PathBuilder helps build JSON-pointer-like paths efficiently. +// Only used when an error occurs, so allocations are acceptable here. +type PathBuilder struct { + segments []string +} + +// NewPathBuilder creates a new path builder. +func NewPathBuilder() *PathBuilder { + return &PathBuilder{ + segments: make([]string, 0, 8), // Pre-allocate for common depth + } +} + +// BuildPath implements ErrorContextTracker interface. 
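+// A small usage sketch for the builder (keys illustrative):
+//
+//	b := NewPathBuilder()
+//	b.PushMap("city")
+//	b.PushMap("names")
+//	b.PushSlice(0)
+//	_ = b.Build() // "/city/names/0"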
+func (p *PathBuilder) BuildPath() string { + return p.Build() +} + +// PushMap adds a map key to the path. +func (p *PathBuilder) PushMap(key string) { + p.segments = append(p.segments, key) +} + +// PushSlice adds a slice index to the path. +func (p *PathBuilder) PushSlice(index int) { + p.segments = append(p.segments, strconv.Itoa(index)) +} + +// PrependMap adds a map key to the beginning of the path (for retroactive building). +func (p *PathBuilder) PrependMap(key string) { + p.segments = append([]string{key}, p.segments...) +} + +// PrependSlice adds a slice index to the beginning of the path (for retroactive building). +func (p *PathBuilder) PrependSlice(index int) { + p.segments = append([]string{strconv.Itoa(index)}, p.segments...) +} + +// Pop removes the last segment from the path. +func (p *PathBuilder) Pop() { + if len(p.segments) > 0 { + p.segments = p.segments[:len(p.segments)-1] + } +} + +// Build constructs the full path string. +func (p *PathBuilder) Build() string { + if len(p.segments) == 0 { + return "/" + } + return "/" + strings.Join(p.segments, "/") +} + +// Reset clears all segments for reuse. +func (p *PathBuilder) Reset() { + p.segments = p.segments[:0] +} + +// ParseAndExtend parses an existing path and extends this builder with those segments. +// This is used for retroactive path building during error unwinding. +func (p *PathBuilder) ParseAndExtend(path string) { + if path == "" || path == "/" { + return + } + + // Remove leading slash and split + if path[0] == '/' { + path = path[1:] + } + + segments := strings.Split(path, "/") + for _, segment := range segments { + if segment != "" { + p.segments = append(p.segments, segment) + } + } +} diff --git a/internal/mmdberrors/errors.go b/internal/mmdberrors/errors.go new file mode 100644 index 0000000..d574681 --- /dev/null +++ b/internal/mmdberrors/errors.go @@ -0,0 +1,56 @@ +// Package mmdberrors is an internal package for the errors used in +// this module. +package mmdberrors + +import ( + "fmt" + "reflect" +) + +// InvalidDatabaseError is returned when the database contains invalid data +// and cannot be parsed. +type InvalidDatabaseError struct { + message string +} + +// NewOffsetError creates an [InvalidDatabaseError] indicating that an offset +// pointed beyond the end of the database. +func NewOffsetError() InvalidDatabaseError { + return InvalidDatabaseError{"unexpected end of database"} +} + +// NewInvalidDatabaseError creates an [InvalidDatabaseError] using the +// provided format and format arguments. +func NewInvalidDatabaseError(format string, args ...any) InvalidDatabaseError { + return InvalidDatabaseError{fmt.Sprintf(format, args...)} +} + +func (e InvalidDatabaseError) Error() string { + return e.message +} + +// UnmarshalTypeError is returned when the value in the database cannot be +// assigned to the specified data type. +type UnmarshalTypeError struct { + Type reflect.Type + Value string +} + +// NewUnmarshalTypeStrError creates an [UnmarshalTypeError] when the string +// value cannot be assigned to a value of rType. +func NewUnmarshalTypeStrError(value string, rType reflect.Type) UnmarshalTypeError { + return UnmarshalTypeError{ + Type: rType, + Value: value, + } +} + +// NewUnmarshalTypeError creates an [UnmarshalTypeError] when the value +// cannot be assigned to a value of rType. 
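+// The resulting message reads, for example (illustrative),
+// "maxminddb: cannot unmarshal hello (string) into type int".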
+func NewUnmarshalTypeError(value any, rType reflect.Type) UnmarshalTypeError { + return NewUnmarshalTypeStrError(fmt.Sprintf("%v (%T)", value, value), rType) +} + +func (e UnmarshalTypeError) Error() string { + return fmt.Sprintf("maxminddb: cannot unmarshal %s into type %s", e.Value, e.Type) +} diff --git a/mmdbdata/doc.go b/mmdbdata/doc.go new file mode 100644 index 0000000..16e9a92 --- /dev/null +++ b/mmdbdata/doc.go @@ -0,0 +1,57 @@ +// Package mmdbdata provides low-level types and interfaces for custom MaxMind DB decoding. +// +// This package allows custom decoding logic for applications that need fine-grained +// control over how MaxMind DB data is processed. For most use cases, the high-level +// maxminddb.Reader API is recommended instead. +// +// # Manual Decoding Example +// +// Custom types can implement the Unmarshaler interface for custom decoding: +// +// type City struct { +// Names map[string]string `maxminddb:"names"` +// GeoNameID uint `maxminddb:"geoname_id"` +// } +// +// func (c *City) UnmarshalMaxMindDB(d *mmdbdata.Decoder) error { +// mapIter, _, err := d.ReadMap() +// if err != nil { return err } +// for key, err := range mapIter { +// if err != nil { return err } +// switch string(key) { +// case "names": +// nameIter, size, err := d.ReadMap() +// if err != nil { return err } +// names := make(map[string]string, size) // Pre-allocate with size +// for nameKey, nameErr := range nameIter { +// if nameErr != nil { return nameErr } +// value, valueErr := d.ReadString() +// if valueErr != nil { return valueErr } +// names[string(nameKey)] = value +// } +// c.Names = names +// case "geoname_id": +// geoID, err := d.ReadUInt32() +// if err != nil { return err } +// c.GeoNameID = uint(geoID) +// default: +// if err := d.SkipValue(); err != nil { return err } +// } +// } +// return nil +// } +// +// Types implementing Unmarshaler will automatically use custom decoding logic +// instead of reflection when used with maxminddb.Reader.Lookup, similar to how +// json.Unmarshaler works with encoding/json. +// +// # Direct Decoder Usage +// +// For even more control, you can use the Decoder directly: +// +// decoder := mmdbdata.NewDecoder(buffer, offset) +// value, err := decoder.ReadString() +// if err != nil { +// return err +// } +package mmdbdata diff --git a/mmdbdata/interface.go b/mmdbdata/interface.go new file mode 100644 index 0000000..63ed3c4 --- /dev/null +++ b/mmdbdata/interface.go @@ -0,0 +1,7 @@ +package mmdbdata + +// Unmarshaler is implemented by types that can unmarshal MaxMind DB data. +// This follows the same pattern as json.Unmarshaler and other Go standard library interfaces. +type Unmarshaler interface { + UnmarshalMaxMindDB(d *Decoder) error +} diff --git a/mmdbdata/type.go b/mmdbdata/type.go new file mode 100644 index 0000000..97edad4 --- /dev/null +++ b/mmdbdata/type.go @@ -0,0 +1,41 @@ +// Package mmdbdata provides types and interfaces for working with MaxMind DB data. +package mmdbdata + +import "github.com/oschwald/maxminddb-golang/v2/internal/decoder" + +// Kind represents MMDB data kinds. +type Kind = decoder.Kind + +// Decoder provides methods for decoding MMDB data. +type Decoder = decoder.Decoder + +// DecoderOption configures a Decoder. +type DecoderOption = decoder.DecoderOption + +// NewDecoder creates a new Decoder with the given buffer, offset, and options. +// Error messages automatically include contextual information like offset and +// path (e.g., "/city/names/en") with zero impact on successful operations. 
+func NewDecoder(buffer []byte, offset uint, options ...DecoderOption) *Decoder { + d := decoder.NewDataDecoder(buffer) + return decoder.NewDecoder(d, offset, options...) +} + +// Kind constants for MMDB data. +const ( + KindExtended = decoder.KindExtended + KindPointer = decoder.KindPointer + KindString = decoder.KindString + KindFloat64 = decoder.KindFloat64 + KindBytes = decoder.KindBytes + KindUint16 = decoder.KindUint16 + KindUint32 = decoder.KindUint32 + KindMap = decoder.KindMap + KindInt32 = decoder.KindInt32 + KindUint64 = decoder.KindUint64 + KindUint128 = decoder.KindUint128 + KindSlice = decoder.KindSlice + KindContainer = decoder.KindContainer + KindEndMarker = decoder.KindEndMarker + KindBool = decoder.KindBool + KindFloat32 = decoder.KindFloat32 +) diff --git a/node.go b/node.go deleted file mode 100644 index 16e8b5f..0000000 --- a/node.go +++ /dev/null @@ -1,58 +0,0 @@ -package maxminddb - -type nodeReader interface { - readLeft(uint) uint - readRight(uint) uint -} - -type nodeReader24 struct { - buffer []byte -} - -func (n nodeReader24) readLeft(nodeNumber uint) uint { - return (uint(n.buffer[nodeNumber]) << 16) | - (uint(n.buffer[nodeNumber+1]) << 8) | - uint(n.buffer[nodeNumber+2]) -} - -func (n nodeReader24) readRight(nodeNumber uint) uint { - return (uint(n.buffer[nodeNumber+3]) << 16) | - (uint(n.buffer[nodeNumber+4]) << 8) | - uint(n.buffer[nodeNumber+5]) -} - -type nodeReader28 struct { - buffer []byte -} - -func (n nodeReader28) readLeft(nodeNumber uint) uint { - return ((uint(n.buffer[nodeNumber+3]) & 0xF0) << 20) | - (uint(n.buffer[nodeNumber]) << 16) | - (uint(n.buffer[nodeNumber+1]) << 8) | - uint(n.buffer[nodeNumber+2]) -} - -func (n nodeReader28) readRight(nodeNumber uint) uint { - return ((uint(n.buffer[nodeNumber+3]) & 0x0F) << 24) | - (uint(n.buffer[nodeNumber+4]) << 16) | - (uint(n.buffer[nodeNumber+5]) << 8) | - uint(n.buffer[nodeNumber+6]) -} - -type nodeReader32 struct { - buffer []byte -} - -func (n nodeReader32) readLeft(nodeNumber uint) uint { - return (uint(n.buffer[nodeNumber]) << 24) | - (uint(n.buffer[nodeNumber+1]) << 16) | - (uint(n.buffer[nodeNumber+2]) << 8) | - uint(n.buffer[nodeNumber+3]) -} - -func (n nodeReader32) readRight(nodeNumber uint) uint { - return (uint(n.buffer[nodeNumber+4]) << 24) | - (uint(n.buffer[nodeNumber+5]) << 16) | - (uint(n.buffer[nodeNumber+6]) << 8) | - uint(n.buffer[nodeNumber+7]) -} diff --git a/reader.go b/reader.go index ae18794..639e8bc 100644 --- a/reader.go +++ b/reader.go @@ -1,4 +1,107 @@ // Package maxminddb provides a reader for the MaxMind DB file format. +// +// This package provides an API for reading MaxMind GeoIP2 and GeoLite2 +// databases in the MaxMind DB file format (.mmdb files). The API is designed +// to be simple to use while providing high performance for IP geolocation +// lookups and related data. 
+// +// # Basic Usage +// +// The most common use case is looking up geolocation data for an IP address: +// +// db, err := maxminddb.Open("GeoLite2-City.mmdb") +// if err != nil { +// log.Fatal(err) +// } +// defer db.Close() +// +// ip, err := netip.ParseAddr("81.2.69.142") +// if err != nil { +// log.Fatal(err) +// } +// +// var record struct { +// Country struct { +// ISOCode string `maxminddb:"iso_code"` +// Names map[string]string `maxminddb:"names"` +// } `maxminddb:"country"` +// City struct { +// Names map[string]string `maxminddb:"names"` +// } `maxminddb:"city"` +// } +// +// err = db.Lookup(ip).Decode(&record) +// if err != nil { +// log.Fatal(err) +// } +// +// fmt.Printf("Country: %s\n", record.Country.Names["en"]) +// fmt.Printf("City: %s\n", record.City.Names["en"]) +// +// # Database Types +// +// This library supports all MaxMind database types: +// - GeoLite2/GeoIP2 City: Comprehensive location data including city, country, subdivisions +// - GeoLite2/GeoIP2 Country: Country-level geolocation data +// - GeoLite2 ASN: Autonomous System Number and organization data +// - GeoIP2 Anonymous IP: Anonymous network and proxy detection +// - GeoIP2 Enterprise: Enhanced City data with additional business fields +// - GeoIP2 ISP: Internet service provider information +// - GeoIP2 Domain: Second-level domain data +// - GeoIP2 Connection Type: Connection type identification +// +// # Performance +// +// For maximum performance in high-throughput applications, consider: +// +// 1. Using custom struct types that only include the fields you need +// 2. Implementing the Unmarshaler interface for custom decoding +// 3. Reusing the Reader instance across multiple goroutines (it's thread-safe) +// +// # Custom Unmarshaling +// +// For custom decoding logic, you can implement the mmdbdata.Unmarshaler interface, +// similar to how encoding/json's json.Unmarshaler works. Types implementing this +// interface will automatically use custom decoding logic when used with Reader.Lookup: +// +// type FastCity struct { +// CountryISO string +// CityName string +// } +// +// func (c *FastCity) UnmarshalMaxMindDB(d *mmdbdata.Decoder) error { +// // Custom decoding logic using d.ReadMap(), d.ReadString(), etc. +// // Allows fine-grained control over how MaxMind DB data is decoded +// // See mmdbdata package documentation and ExampleUnmarshaler for complete examples +// } +// +// # Network Iteration +// +// You can iterate over all networks in a database: +// +// for result := range db.Networks() { +// var record struct { +// Country struct { +// ISOCode string `maxminddb:"iso_code"` +// } `maxminddb:"country"` +// } +// err := result.Decode(&record) +// if err != nil { +// log.Fatal(err) +// } +// fmt.Printf("%s: %s\n", result.Prefix(), record.Country.ISOCode) +// } +// +// # Database Files +// +// MaxMind provides both free (GeoLite2) and commercial (GeoIP2) databases: +// - Free: https://dev.maxmind.com/geoip/geolite2-free-geolocation-data +// - Commercial: https://www.maxmind.com/en/geoip2-databases +// +// # Thread Safety +// +// All Reader methods are thread-safe. The Reader can be safely shared across +// multiple goroutines. 
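+//
+// For example, a single opened Reader can serve lookups from many goroutines
+// at once (a sketch; error handling elided):
+//
+//	var wg sync.WaitGroup
+//	for range 4 {
+//		wg.Add(1)
+//		go func() {
+//			defer wg.Done()
+//			var rec any
+//			_ = db.Lookup(ip).Decode(&rec)
+//		}()
+//	}
+//	wg.Wait()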
package maxminddb import ( @@ -8,8 +111,11 @@ import ( "io" "net/netip" "os" - "reflect" "runtime" + "time" + + "github.com/oschwald/maxminddb-golang/v2/internal/decoder" + "github.com/oschwald/maxminddb-golang/v2/internal/mmdberrors" ) const dataSectionSeparatorSize = 16 @@ -22,9 +128,8 @@ var metadataStartMarker = []byte("\xAB\xCD\xEFMaxMind.com") // All of the methods on Reader are thread-safe. The struct may be safely // shared across goroutines. type Reader struct { - nodeReader nodeReader buffer []byte - decoder decoder + decoder decoder.ReflectionDecoder Metadata Metadata ipv4Start uint ipv4StartBitDepth int @@ -32,34 +137,84 @@ type Reader struct { hasMappedFile bool } -// Metadata holds the metadata decoded from the MaxMind DB file. In particular -// it has the format version, the build time as Unix epoch time, the database -// type and description, the IP version supported, and a slice of the natural -// languages included. +// Metadata holds the metadata decoded from the MaxMind DB file. +// +// Key fields include: +// - DatabaseType: indicates the structure of data records (e.g., "GeoIP2-City") +// - Description: localized descriptions in various languages +// - Languages: locale codes for which the database may contain localized data +// - BuildEpoch: database build timestamp as Unix epoch seconds +// - IPVersion: supported IP version (4 for IPv4-only, 6 for IPv4/IPv6) +// - NodeCount: number of nodes in the search tree +// - RecordSize: size in bits of each record in the search tree (24, 28, or 32) +// +// For detailed field descriptions, see the MaxMind DB specification: +// https://maxmind.github.io/MaxMind-DB/ type Metadata struct { - Description map[string]string `maxminddb:"description"` - DatabaseType string `maxminddb:"database_type"` - Languages []string `maxminddb:"languages"` - BinaryFormatMajorVersion uint `maxminddb:"binary_format_major_version"` - BinaryFormatMinorVersion uint `maxminddb:"binary_format_minor_version"` - BuildEpoch uint `maxminddb:"build_epoch"` - IPVersion uint `maxminddb:"ip_version"` - NodeCount uint `maxminddb:"node_count"` - RecordSize uint `maxminddb:"record_size"` + // Description contains localized database descriptions. + // Keys are language codes (e.g., "en", "zh-CN"), values are UTF-8 descriptions. + Description map[string]string `maxminddb:"description"` + + // DatabaseType indicates the structure of data records associated with IP addresses. + // Names starting with "GeoIP" are reserved for MaxMind databases. + DatabaseType string `maxminddb:"database_type"` + + // Languages lists locale codes for which this database may contain localized data. + // Records should not contain localized data for locales not in this array. + Languages []string `maxminddb:"languages"` + + // BinaryFormatMajorVersion is the major version of the MaxMind DB binary format. + // Current supported version is 2. + BinaryFormatMajorVersion uint `maxminddb:"binary_format_major_version"` + + // BinaryFormatMinorVersion is the minor version of the MaxMind DB binary format. + // Current supported version is 0. + BinaryFormatMinorVersion uint `maxminddb:"binary_format_minor_version"` + + // BuildEpoch contains the database build timestamp as Unix epoch seconds. + // Use BuildTime() method for a time.Time representation. 
+	BuildEpoch uint `maxminddb:"build_epoch"`
+
+	// IPVersion indicates the IP version support:
+	//   4: IPv4 addresses only
+	//   6: Both IPv4 and IPv6 addresses
+	IPVersion uint `maxminddb:"ip_version"`
+
+	// NodeCount is the number of nodes in the search tree.
+	NodeCount uint `maxminddb:"node_count"`
+
+	// RecordSize is the size in bits of each record in the search tree.
+	// Valid values are 24, 28, or 32.
+	RecordSize uint `maxminddb:"record_size"`
+}
+
+// BuildTime returns the database build time as a time.Time.
+// This is a convenience method that converts the BuildEpoch field
+// from Unix epoch seconds to a time.Time value.
+func (m Metadata) BuildTime() time.Time {
+	return time.Unix(int64(m.BuildEpoch), 0)
 }
 
-// Open takes a string path to a MaxMind DB file and returns a Reader
-// structure or an error. The database file is opened using a memory map
-// on supported platforms. On platforms without memory map support, such
+type readerOptions struct{}
+
+// ReaderOption is an option for [Open] and [FromBytes].
+//
+// This was added to allow for future options, e.g., for caching, without
+// causing a breaking API change.
+type ReaderOption func(*readerOptions)
+
+// Open takes a string path to a MaxMind DB file and any options. It returns a
+// Reader structure or an error. The database file is opened using a memory
+// map on supported platforms. On platforms without memory map support, such
 // as WebAssembly or Google App Engine, or if the memory map attempt fails
 // due to lack of support from the filesystem, the database is loaded into memory.
 // Use the Close method on the Reader object to return the resources to the system.
-func Open(file string) (*Reader, error) {
+func Open(file string, options ...ReaderOption) (*Reader, error) {
 	mapFile, err := os.Open(file)
 	if err != nil {
 		return nil, err
 	}
-	defer mapFile.Close()
+	defer mapFile.Close() //nolint:errcheck // error is generally not relevant
 
 	stats, err := mapFile.Stat()
 	if err != nil {
@@ -86,12 +241,12 @@ func Open(file string) (*Reader, error) {
 		if err != nil {
 			return nil, err
 		}
-		return FromBytes(data)
+		return FromBytes(data, options...)
 		}
 		return nil, err
 	}
 
-	reader, err := FromBytes(data)
+	reader, err := FromBytes(data, options...)
 	if err != nil {
 		_ = munmap(data)
 		return nil, err
@@ -120,22 +275,28 @@ func (r *Reader) Close() error {
 	return err
 }
 
-// FromBytes takes a byte slice corresponding to a MaxMind DB file and returns
-// a Reader structure or an error.
+// FromBytes takes a byte slice corresponding to a MaxMind DB file and any
+// options. It returns a Reader structure or an error.
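+//
+// A minimal sketch for loading a database you already have as bytes (the
+// file name here is illustrative):
+//
+//	data, err := os.ReadFile("GeoLite2-City.mmdb")
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	db, err := FromBytes(data)
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	defer db.Close()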
+func FromBytes(buffer []byte, options ...ReaderOption) (*Reader, error) { + opts := &readerOptions{} + for _, option := range options { + option(opts) + } + metadataStart := bytes.LastIndex(buffer, metadataStartMarker) if metadataStart == -1 { - return nil, newInvalidDatabaseError("error opening database: invalid MaxMind DB file") + return nil, mmdberrors.NewInvalidDatabaseError( + "error opening database: invalid MaxMind DB file", + ) } metadataStart += len(metadataStartMarker) - metadataDecoder := decoder{buffer: buffer[metadataStart:]} + metadataDecoder := decoder.New(buffer[metadataStart:]) var metadata Metadata - rvMetadata := reflect.ValueOf(&metadata) - _, err := metadataDecoder.decode(0, rvMetadata, 0) + err := metadataDecoder.Decode(0, &metadata) if err != nil { return nil, err } @@ -144,28 +305,14 @@ func FromBytes(buffer []byte) (*Reader, error) { dataSectionStart := searchTreeSize + dataSectionSeparatorSize dataSectionEnd := uint(metadataStart - len(metadataStartMarker)) if dataSectionStart > dataSectionEnd { - return nil, newInvalidDatabaseError("the MaxMind DB contains invalid metadata") - } - d := decoder{ - buffer: buffer[searchTreeSize+dataSectionSeparatorSize : metadataStart-len(metadataStartMarker)], - } - - nodeBuffer := buffer[:searchTreeSize] - var nodeReader nodeReader - switch metadata.RecordSize { - case 24: - nodeReader = nodeReader24{buffer: nodeBuffer} - case 28: - nodeReader = nodeReader28{buffer: nodeBuffer} - case 32: - nodeReader = nodeReader32{buffer: nodeBuffer} - default: - return nil, newInvalidDatabaseError("unknown record size: %d", metadata.RecordSize) + return nil, mmdberrors.NewInvalidDatabaseError("the MaxMind DB contains invalid metadata") } + d := decoder.New( + buffer[searchTreeSize+dataSectionSeparatorSize : metadataStart-len(metadataStartMarker)], + ) reader := &Reader{ buffer: buffer, - nodeReader: nodeReader, decoder: d, Metadata: metadata, ipv4Start: 0, @@ -177,25 +324,8 @@ func FromBytes(buffer []byte) (*Reader, error) { return reader, err } -func (r *Reader) setIPv4Start() { - if r.Metadata.IPVersion != 6 { - r.ipv4StartBitDepth = 96 - return - } - - nodeCount := r.Metadata.NodeCount - - node := uint(0) - i := 0 - for ; i < 96 && node < nodeCount; i++ { - node = r.nodeReader.readLeft(node * r.nodeOffsetMult) - } - r.ipv4Start = node - r.ipv4StartBitDepth = i -} - -// Lookup retrieves the database record for ip and returns Result, which can -// be used to decode the data.. +// Lookup retrieves the database record for ip and returns a Result, which can +// be used to decode the data. func (r *Reader) Lookup(ip netip.Addr) Result { if r.buffer == nil { return Result{err: errors.New("cannot call Lookup on a closed database")} @@ -229,12 +359,29 @@ func (r *Reader) Lookup(ip netip.Addr) Result { // netip.Prefix returned by Networks will be invalid when using LookupOffset. 
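+//
+// A sketch of the round trip (offsets are only meaningful against the same
+// open database):
+//
+//	res := db.Lookup(ip)
+//	off := res.Offset()
+//	// ... later, against the same database ...
+//	var rec any
+//	err := db.LookupOffset(off).Decode(&rec)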
func (r *Reader) LookupOffset(offset uintptr) Result { if r.buffer == nil { - return Result{err: errors.New("cannot call Decode on a closed database")} + return Result{err: errors.New("cannot call LookupOffset on a closed database")} } return Result{decoder: r.decoder, offset: uint(offset)} } +func (r *Reader) setIPv4Start() { + if r.Metadata.IPVersion != 6 { + r.ipv4StartBitDepth = 96 + return + } + + nodeCount := r.Metadata.NodeCount + + node := uint(0) + i := 0 + for ; i < 96 && node < nodeCount; i++ { + node = readNodeBySize(r.buffer, node*r.nodeOffsetMult, 0, r.Metadata.RecordSize) + } + r.ipv4Start = node + r.ipv4StartBitDepth = i +} + var zeroIP = netip.MustParseAddr("::") func (r *Reader) lookupPointer(ip netip.Addr) (uint, int, error) { @@ -245,7 +392,10 @@ func (r *Reader) lookupPointer(ip netip.Addr) (uint, int, error) { ) } - node, prefixLength := r.traverseTree(ip, 0, 128) + node, prefixLength, err := r.traverseTree(ip, 0, 128) + if err != nil { + return 0, 0, err + } nodeCount := r.Metadata.NodeCount if node == nodeCount { @@ -255,28 +405,137 @@ func (r *Reader) lookupPointer(ip netip.Addr) (uint, int, error) { return node, prefixLength, nil } - return 0, prefixLength, newInvalidDatabaseError("invalid node in search tree") + return 0, prefixLength, mmdberrors.NewInvalidDatabaseError("invalid node in search tree") +} + +// readNodeBySize reads a node value from the buffer based on record size and bit. +func readNodeBySize(buffer []byte, offset, bit, recordSize uint) uint { + switch recordSize { + case 24: + offset += bit * 3 + return (uint(buffer[offset]) << 16) | + (uint(buffer[offset+1]) << 8) | + uint(buffer[offset+2]) + case 28: + if bit == 0 { + return ((uint(buffer[offset+3]) & 0xF0) << 20) | + (uint(buffer[offset]) << 16) | + (uint(buffer[offset+1]) << 8) | + uint(buffer[offset+2]) + } + return ((uint(buffer[offset+3]) & 0x0F) << 24) | + (uint(buffer[offset+4]) << 16) | + (uint(buffer[offset+5]) << 8) | + uint(buffer[offset+6]) + case 32: + offset += bit * 4 + return (uint(buffer[offset]) << 24) | + (uint(buffer[offset+1]) << 16) | + (uint(buffer[offset+2]) << 8) | + uint(buffer[offset+3]) + default: + return 0 + } +} + +func (r *Reader) traverseTree(ip netip.Addr, node uint, stopBit int) (uint, int, error) { + switch r.Metadata.RecordSize { + case 24: + n, i := r.traverseTree24(ip, node, stopBit) + return n, i, nil + case 28: + n, i := r.traverseTree28(ip, node, stopBit) + return n, i, nil + case 32: + n, i := r.traverseTree32(ip, node, stopBit) + return n, i, nil + default: + return 0, 0, mmdberrors.NewInvalidDatabaseError( + "unsupported record size: %d", + r.Metadata.RecordSize, + ) + } +} + +func (r *Reader) traverseTree24(ip netip.Addr, node uint, stopBit int) (uint, int) { + i := 0 + if ip.Is4() { + i = r.ipv4StartBitDepth + node = r.ipv4Start + } + nodeCount := r.Metadata.NodeCount + buffer := r.buffer + ip16 := ip.As16() + + for ; i < stopBit && node < nodeCount; i++ { + byteIdx := i >> 3 + bitPos := 7 - (i & 7) + bit := (uint(ip16[byteIdx]) >> bitPos) & 1 + + baseOffset := node * 6 + offset := baseOffset + bit*3 + + node = (uint(buffer[offset]) << 16) | + (uint(buffer[offset+1]) << 8) | + uint(buffer[offset+2]) + } + + return node, i } -func (r *Reader) traverseTree(ip netip.Addr, node uint, stopBit int) (uint, int) { +func (r *Reader) traverseTree28(ip netip.Addr, node uint, stopBit int) (uint, int) { i := 0 if ip.Is4() { i = r.ipv4StartBitDepth node = r.ipv4Start } nodeCount := r.Metadata.NodeCount + buffer := r.buffer + ip16 := ip.As16() + + for ; i < 
stopBit && node < nodeCount; i++ { + byteIdx := i >> 3 + bitPos := 7 - (i & 7) + bit := (uint(ip16[byteIdx]) >> bitPos) & 1 + + baseOffset := node * 7 + sharedByte := uint(buffer[baseOffset+3]) + mask := uint(0xF0 >> (bit * 4)) + shift := 20 + bit*4 + nibble := ((sharedByte & mask) << shift) + offset := baseOffset + bit*4 + + node = nibble | + (uint(buffer[offset]) << 16) | + (uint(buffer[offset+1]) << 8) | + uint(buffer[offset+2]) + } + return node, i +} + +func (r *Reader) traverseTree32(ip netip.Addr, node uint, stopBit int) (uint, int) { + i := 0 + if ip.Is4() { + i = r.ipv4StartBitDepth + node = r.ipv4Start + } + nodeCount := r.Metadata.NodeCount + buffer := r.buffer ip16 := ip.As16() for ; i < stopBit && node < nodeCount; i++ { - bit := uint(1) & (uint(ip16[i>>3]) >> (7 - (i % 8))) + byteIdx := i >> 3 + bitPos := 7 - (i & 7) + bit := (uint(ip16[byteIdx]) >> bitPos) & 1 - offset := node * r.nodeOffsetMult - if bit == 0 { - node = r.nodeReader.readLeft(offset) - } else { - node = r.nodeReader.readRight(offset) - } + baseOffset := node * 8 + offset := baseOffset + bit*4 + + node = (uint(buffer[offset]) << 24) | + (uint(buffer[offset+1]) << 16) | + (uint(buffer[offset+2]) << 8) | + uint(buffer[offset+3]) } return node, i @@ -286,7 +545,7 @@ func (r *Reader) resolveDataPointer(pointer uint) (uintptr, error) { resolved := uintptr(pointer - r.Metadata.NodeCount - dataSectionSeparatorSize) if resolved >= uintptr(len(r.buffer)) { - return 0, newInvalidDatabaseError("the MaxMind DB file's search tree is corrupt") + return 0, mmdberrors.NewInvalidDatabaseError("the MaxMind DB file's search tree is corrupt") } return resolved, nil } diff --git a/reader_test.go b/reader_test.go index d6c8cf8..8d81319 100644 --- a/reader_test.go +++ b/reader_test.go @@ -9,11 +9,15 @@ import ( "net/netip" "os" "path/filepath" + "sync" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/oschwald/maxminddb-golang/v2/internal/mmdberrors" + "github.com/oschwald/maxminddb-golang/v2/mmdbdata" ) func TestReader(t *testing.T) { @@ -81,7 +85,7 @@ func TestLookupNetwork(t *testing.T) { }, "double": 42.123456, "float": float32(1.1), - "int32": -268435456, + "int32": int32(-268435456), "map": map[string]any{ "mapX": map[string]any{ "arrayX": []any{ @@ -232,7 +236,7 @@ func checkDecodingToInterface(t *testing.T, recordInterface any) { assert.Equal(t, []byte{0x00, 0x00, 0x00, 0x2a}, record["bytes"]) assert.InEpsilon(t, 42.123456, record["double"], 1e-10) assert.InEpsilon(t, float32(1.1), record["float"], 1e-5) - assert.Equal(t, -268435456, record["int32"]) + assert.Equal(t, int32(-268435456), record["int32"]) assert.Equal(t, map[string]any{ "mapX": map[string]any{ @@ -393,7 +397,7 @@ func TestNonEmptyNilInterface(t *testing.T) { err = reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&result) assert.Equal( t, - "maxminddb: cannot unmarshal map into type maxminddb.TestInterface", + "at offset 115: maxminddb: cannot unmarshal map into type maxminddb.TestInterface", err.Error(), ) } @@ -452,9 +456,10 @@ type NestedPointerMapX struct { type PointerMap struct { MapX struct { - Ignored string NestedMapX *NestedPointerMapX + + Ignored string } `maxminddb:"mapX"` } @@ -647,7 +652,7 @@ func TestBrokenDoubleDatabase(t *testing.T) { var result any err = reader.Lookup(netip.MustParseAddr("2001:220::")).Decode(&result) - expected := newInvalidDatabaseError( + expected := mmdberrors.NewInvalidDatabaseError( "the MaxMind DB file's data section contains bad data (float 64 size of 2)", 
) require.ErrorAs(t, err, &expected) @@ -657,7 +662,7 @@ func TestBrokenDoubleDatabase(t *testing.T) { func TestInvalidNodeCountDatabase(t *testing.T) { _, err := Open(testFile("GeoIP2-City-Test-Invalid-Node-Count.mmdb")) - expected := newInvalidDatabaseError("the MaxMind DB contains invalid metadata") + expected := mmdberrors.NewInvalidDatabaseError("the MaxMind DB contains invalid metadata") assert.Equal(t, expected, err) } @@ -708,7 +713,7 @@ func TestUsingClosedDatabase(t *testing.T) { assert.Equal(t, "cannot call Lookup on a closed database", err.Error()) err = reader.LookupOffset(0).Decode(recordInterface) - assert.Equal(t, "cannot call Decode on a closed database", err.Error()) + assert.Equal(t, "cannot call LookupOffset on a closed database", err.Error()) } func checkMetadata(t *testing.T, reader *Reader, ipVersion, recordSize uint) { @@ -1003,6 +1008,61 @@ func BenchmarkDecodePathCountryCode(b *testing.B) { require.NoError(b, db.Close(), "error on close") } +// BenchmarkCityLookupConcurrent tests concurrent city lookups to demonstrate +// string cache performance under concurrent load. +func BenchmarkCityLookupConcurrent(b *testing.B) { + db, err := Open("GeoLite2-City.mmdb") + require.NoError(b, err) + defer func() { + require.NoError(b, db.Close(), "error on close") + }() + + // Test with different numbers of concurrent goroutines + goroutineCounts := []int{1, 4, 16, 64} + + for _, numGoroutines := range goroutineCounts { + b.Run(fmt.Sprintf("goroutines_%d", numGoroutines), func(b *testing.B) { + // Each goroutine performs 100 lookups + const lookupsPerGoroutine = 100 + b.ResetTimer() + + for range b.N { + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for range numGoroutines { + go func() { + defer wg.Done() + + //nolint:gosec // this is a test + r := rand.New(rand.NewSource(time.Now().UnixNano())) + s := make(net.IP, 4) + var result fullCity + + for range lookupsPerGoroutine { + ip := randomIPv4Address(r, s) + err := db.Lookup(ip).Decode(&result) + if err != nil { + b.Error(err) + return + } + // Access string fields to exercise the cache + _ = result.City.Names + _ = result.Country.Names + } + }() + } + + wg.Wait() + } + + // Report operations per second + totalOps := int64(b.N) * int64(numGoroutines) * int64(lookupsPerGoroutine) + b.ReportMetric(float64(totalOps)/b.Elapsed().Seconds(), "lookups/sec") + }) + } +} + func randomIPv4Address(r *rand.Rand, ip []byte) netip.Addr { num := r.Uint32() ip[0] = byte(num >> 24) @@ -1016,3 +1076,175 @@ func randomIPv4Address(r *rand.Rand, ip []byte) netip.Addr { func testFile(file string) string { return filepath.Join("test-data", "test-data", file) } + +// Test custom unmarshaling through Reader.Lookup. 
+func TestCustomUnmarshaler(t *testing.T) { + reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) + require.NoError(t, err) + defer func() { + if err := reader.Close(); err != nil { + t.Errorf("Error closing reader: %v", err) + } + }() + + // Test a type that implements Unmarshaler + var customDecoded TestCity + result := reader.Lookup(netip.MustParseAddr("1.1.1.1")) + err = result.Decode(&customDecoded) + require.NoError(t, err) + + // Test that the same data decoded with reflection gives the same result + var reflectionDecoded map[string]any + result2 := reader.Lookup(netip.MustParseAddr("1.1.1.1")) + err = result2.Decode(&reflectionDecoded) + require.NoError(t, err) + + // Verify the custom decoder worked correctly + // The exact assertions depend on the test data in MaxMind-DB-test-decoder.mmdb + t.Logf("Custom decoded: %+v", customDecoded) + t.Logf("Reflection decoded: %+v", reflectionDecoded) + + // Test that both methods produce consistent results for any matching data + if len(customDecoded.Names) > 0 || len(reflectionDecoded) > 0 { + t.Log("Custom unmarshaler integration test passed - both decoders worked") + } +} + +// TestCity represents a simplified city data structure for testing custom unmarshaling. +type TestCity struct { + Names map[string]string `maxminddb:"names"` + GeoNameID uint `maxminddb:"geoname_id"` +} + +// UnmarshalMaxMindDB implements the Unmarshaler interface for TestCity. +// This demonstrates custom decoding that avoids reflection for better performance. +func (c *TestCity) UnmarshalMaxMindDB(d *mmdbdata.Decoder) error { + mapIter, _, err := d.ReadMap() + if err != nil { + return err + } + for key, err := range mapIter { + if err != nil { + return err + } + + switch string(key) { + case "names": + // Decode nested map[string]string for localized names + nameMapIter, size, err := d.ReadMap() + if err != nil { + return err + } + names := make(map[string]string, size) // Pre-allocate with correct capacity + for nameKey, nameErr := range nameMapIter { + if nameErr != nil { + return nameErr + } + value, valueErr := d.ReadString() + if valueErr != nil { + return valueErr + } + names[string(nameKey)] = value + } + c.Names = names + case "geoname_id": + geoID, err := d.ReadUInt32() + if err != nil { + return err + } + c.GeoNameID = uint(geoID) + default: + // Skip unknown fields + if err := d.SkipValue(); err != nil { + return err + } + } + } + return nil +} + +// TestASN represents ASN data for testing custom unmarshaling. +type TestASN struct { + AutonomousSystemOrganization string `maxminddb:"autonomous_system_organization"` + AutonomousSystemNumber uint `maxminddb:"autonomous_system_number"` +} + +// UnmarshalMaxMindDB implements the Unmarshaler interface for TestASN. +func (a *TestASN) UnmarshalMaxMindDB(d *mmdbdata.Decoder) error { + mapIter, _, err := d.ReadMap() + if err != nil { + return err + } + for key, err := range mapIter { + if err != nil { + return err + } + + switch string(key) { + case "autonomous_system_organization": + org, err := d.ReadString() + if err != nil { + return err + } + a.AutonomousSystemOrganization = org + case "autonomous_system_number": + asn, err := d.ReadUInt32() + if err != nil { + return err + } + a.AutonomousSystemNumber = uint(asn) + default: + if err := d.SkipValue(); err != nil { + return err + } + } + } + return nil +} + +// TestFallbackToReflection verifies that types without UnmarshalMaxMindDB still work. 
+func TestFallbackToReflection(t *testing.T) { + reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) + require.NoError(t, err) + defer func() { + if err := reader.Close(); err != nil { + t.Errorf("Error closing reader: %v", err) + } + }() + + // Test with a regular struct that doesn't implement Unmarshaler + var regularStruct struct { + Names map[string]string `maxminddb:"names"` + } + + result := reader.Lookup(netip.MustParseAddr("1.1.1.1")) + err = result.Decode(®ularStruct) + require.NoError(t, err) + + // Log the result for verification + t.Logf("Reflection fallback result: %+v", regularStruct) +} + +func TestMetadataBuildTime(t *testing.T) { + reader, err := Open(testFile("GeoIP2-City-Test.mmdb")) + require.NoError(t, err) + defer func() { + if err := reader.Close(); err != nil { + t.Errorf("Error closing reader: %v", err) + } + }() + + metadata := reader.Metadata + + // Test that BuildTime() returns a valid time + buildTime := metadata.BuildTime() + assert.False(t, buildTime.IsZero(), "BuildTime should not be zero") + + // Test that BuildTime() matches BuildEpoch + expectedTime := time.Unix(int64(metadata.BuildEpoch), 0) + assert.Equal(t, expectedTime, buildTime, "BuildTime should match time.Unix(BuildEpoch, 0)") + + // Verify the build time is reasonable (after 2010, before 2030) + assert.True(t, buildTime.After(time.Date(2010, 1, 1, 0, 0, 0, 0, time.UTC))) + assert.True(t, buildTime.Before(time.Date(2030, 1, 1, 0, 0, 0, 0, time.UTC))) +} diff --git a/result.go b/result.go index 7562b2b..6bd7e42 100644 --- a/result.go +++ b/result.go @@ -1,18 +1,19 @@ package maxminddb import ( - "errors" "math" "net/netip" - "reflect" + + "github.com/oschwald/maxminddb-golang/v2/internal/decoder" ) const notFound uint = math.MaxUint +// Result holds the result of the database lookup. type Result struct { ip netip.Addr err error - decoder decoder + decoder decoder.ReflectionDecoder offset uint prefixLen uint8 } @@ -35,18 +36,8 @@ func (r Result) Decode(v any) error { if r.offset == notFound { return nil } - rv := reflect.ValueOf(v) - if rv.Kind() != reflect.Ptr || rv.IsNil() { - return errors.New("result param must be a pointer") - } - if dser, ok := v.(deserializer); ok { - _, err := r.decoder.decodeToDeserializer(r.offset, dser, 0, false) - return err - } - - _, err := r.decoder.decode(r.offset, rv, 0) - return err + return r.decoder.Decode(r.offset, v) } // DecodePath unmarshals a value from data section into v, following the @@ -89,11 +80,7 @@ func (r Result) DecodePath(v any, path ...any) error { if r.offset == notFound { return nil } - rv := reflect.ValueOf(v) - if rv.Kind() != reflect.Ptr || rv.IsNil() { - return errors.New("result param must be a pointer") - } - return r.decoder.decodePath(r.offset, path, rv) + return r.decoder.DecodePath(r.offset, path, v) } // Err provides a way to check whether there was an error during the lookup @@ -113,7 +100,7 @@ func (r Result) Found() bool { // passed to (*Reader).LookupOffset. It can also be used as a unique // identifier for the data record in the particular database to cache the data // record across lookups. Note that while the offset uniquely identifies the -// data record, other data in Result may differ between lookups. The offset +// data record, other data in Result may differ between lookups. The offset // is only valid for the current database version. If you update the database // file, you must invalidate any cache associated with the previous version. 
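+//
+// A sketch of offset-keyed caching (City is an illustrative record type; any
+// cache keyed by the offset works):
+//
+//	cache := map[uintptr]City{}
+//	res := db.Lookup(ip)
+//	if res.Found() {
+//		if _, ok := cache[res.Offset()]; !ok {
+//			var c City
+//			if err := res.Decode(&c); err == nil {
+//				cache[res.Offset()] = c
+//			}
+//		}
+//	}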
func (r Result) Offset() uintptr { diff --git a/traverse.go b/traverse.go index b9a6acd..34f6cd2 100644 --- a/traverse.go +++ b/traverse.go @@ -5,6 +5,8 @@ import ( // comment to prevent gofumpt from randomly moving iter. "iter" "net/netip" + + "github.com/oschwald/maxminddb-golang/v2/internal/mmdberrors" ) // Internal structure used to keep track of nodes we still need to visit. @@ -30,14 +32,18 @@ type NetworksOption func(*networkOptions) // IncludeAliasedNetworks is an option for Networks and NetworksWithin // that makes them iterate over aliases of the IPv4 subtree in an IPv6 // database, e.g., ::ffff:0:0/96, 2001::/32, and 2002::/16. -func IncludeAliasedNetworks(networks *networkOptions) { - networks.includeAliasedNetworks = true +func IncludeAliasedNetworks() NetworksOption { + return func(networks *networkOptions) { + networks.includeAliasedNetworks = true + } } // IncludeNetworksWithoutData is an option for Networks and NetworksWithin // that makes them include networks without any data in the iteration. -func IncludeNetworksWithoutData(networks *networkOptions) { - networks.includeEmptyNetworks = true +func IncludeNetworksWithoutData() NetworksOption { + return func(networks *networkOptions) { + networks.includeEmptyNetworks = true + } } // Networks returns an iterator that can be used to traverse the networks in @@ -95,7 +101,14 @@ func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) stopBit += 96 } - pointer, bit := r.traverseTree(ip, 0, stopBit) + pointer, bit, err := r.traverseTree(ip, 0, stopBit) + if err != nil { + yield(Result{ + ip: ip, + err: err, + }) + return + } prefix, err := netIP.Prefix(bit) if err != nil { @@ -166,7 +179,7 @@ func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) ip: displayAddr, prefixLen: uint8(node.bit), } - res.err = newInvalidDatabaseError( + res.err = mmdberrors.NewInvalidDatabaseError( "invalid search tree at %s", res.Prefix()) yield(res) @@ -176,7 +189,7 @@ func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) ipRight[node.bit>>3] |= 1 << (7 - (node.bit % 8)) offset := node.pointer * r.nodeOffsetMult - rightPointer := r.nodeReader.readRight(offset) + rightPointer := readNodeBySize(r.buffer, offset, 1, r.Metadata.RecordSize) node.bit++ nodes = append(nodes, netNode{ @@ -185,7 +198,7 @@ func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) bit: node.bit, }) - node.pointer = r.nodeReader.readLeft(offset) + node.pointer = readNodeBySize(r.buffer, offset, 0, r.Metadata.RecordSize) } } } diff --git a/traverse_test.go b/traverse_test.go index 5fbbb88..5340978 100644 --- a/traverse_test.go +++ b/traverse_test.go @@ -227,7 +227,7 @@ var tests = []networkTest{ "2002:101:110::/44", "2002:101:120::/48", }, - Options: []NetworksOption{IncludeAliasedNetworks}, + Options: []NetworksOption{IncludeAliasedNetworks()}, }, { Network: "::/0", @@ -281,7 +281,7 @@ var tests = []networkTest{ "1.64.0.0/10", "1.128.0.0/9", }, - Options: []NetworksOption{IncludeNetworksWithoutData}, + Options: []NetworksOption{IncludeNetworksWithoutData()}, }, { Network: "1.1.1.16/28", diff --git a/verifier.go b/verifier.go index 335cb1b..3c00287 100644 --- a/verifier.go +++ b/verifier.go @@ -1,17 +1,32 @@ package maxminddb import ( - "reflect" "runtime" + + "github.com/oschwald/maxminddb-golang/v2/internal/mmdberrors" ) type verifier struct { reader *Reader } -// Verify checks that the database is valid. It validates the search tree, -// the data section, and the metadata section. 
This verifier is stricter than -// the specification and may return errors on databases that are readable. +// Verify performs comprehensive validation of the MaxMind DB file. +// +// This method validates: +// - Metadata section: format versions, required fields, and value constraints +// - Search tree: traverses all networks to verify tree structure integrity +// - Data section separator: validates the 16-byte separator between tree and data +// - Data section: verifies all data records referenced by the search tree +// +// The verifier is stricter than the MaxMind DB specification and may return +// errors on some databases that are still readable by normal operations. +// This method is useful for: +// - Validating database files after download or generation +// - Debugging database corruption issues +// - Ensuring database integrity in critical applications +// +// Note: Verification traverses the entire database and may be slow on large files. +// The method is thread-safe and can be called on an active Reader. func (r *Reader) Verify() error { v := verifier{r} if err := v.verifyMetadata(); err != nil { @@ -96,7 +111,7 @@ func (v *verifier) verifyDatabase() error { return err } - return v.verifyDataSection(offsets) + return v.reader.decoder.VerifyDataSection(offsets) } func (v *verifier) verifySearchTree() (map[uint]bool, error) { @@ -118,66 +133,11 @@ func (v *verifier) verifyDataSectionSeparator() error { for _, b := range separator { if b != 0 { - return newInvalidDatabaseError("unexpected byte in data separator: %v", separator) - } - } - return nil -} - -func (v *verifier) verifyDataSection(offsets map[uint]bool) error { - pointerCount := len(offsets) - - decoder := v.reader.decoder - - var offset uint - bufferLen := uint(len(decoder.buffer)) - for offset < bufferLen { - var data any - rv := reflect.ValueOf(&data) - newOffset, err := decoder.decode(offset, rv, 0) - if err != nil { - return newInvalidDatabaseError( - "received decoding error (%v) at offset of %v", - err, - offset, - ) - } - if newOffset <= offset { - return newInvalidDatabaseError( - "data section offset unexpectedly went from %v to %v", - offset, - newOffset, + return mmdberrors.NewInvalidDatabaseError( + "unexpected byte in data separator: %v", + separator, ) } - - pointer := offset - - if _, ok := offsets[pointer]; !ok { - return newInvalidDatabaseError( - "found data (%v) at %v that the search tree does not point to", - data, - pointer, - ) - } - delete(offsets, pointer) - - offset = newOffset - } - - if offset != bufferLen { - return newInvalidDatabaseError( - "unexpected data at the end of the data section (last offset: %v, end: %v)", - offset, - bufferLen, - ) - } - - if len(offsets) != 0 { - return newInvalidDatabaseError( - "found %v pointers (of %v) in the search tree that we did not see in the data section", - len(offsets), - pointerCount, - ) } return nil } @@ -187,7 +147,7 @@ func testError( expected any, actual any, ) error { - return newInvalidDatabaseError( + return mmdberrors.NewInvalidDatabaseError( "%v - Expected: %v Actual: %v", field, expected,