Skip to content

Commit

Permalink
fix tinyint marshal, unmarshal
Browse files Browse the repository at this point in the history
  • Loading branch information
illia-li committed Sep 27, 2024
1 parent 25c97f8 commit 40a54cb
Show file tree
Hide file tree
Showing 6 changed files with 735 additions and 96 deletions.
131 changes: 35 additions & 96 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"github.com/gocql/gocql/marshal/tinyint"
"math"
"math/big"
"math/bits"
Expand All @@ -30,14 +31,32 @@ var (
ErrorUDTUnavailable = errors.New("UDT are not available on protocols less than 3, please update config")
)

// Marshaler is the interface implemented by objects that can marshal
// themselves into values understood by Cassandra.
// Marshaler is an interface for marshalling objects according to the CQL binary protocol.
// Initially, each value of the 'CQL binary protocol' consist of <value_len> and <value_data>.
// <value_len> can be 'unset'(-2), 'nil'(-1), 'zero'(0) or any value up to 2147483647.
// In 'unset', 'nil' and 'zero' cases <value_data> is not present.
// Basically, 'unset' is applicable only to columns, but there may be exceptions.
// The current version of 'gocql' writes <value_len> through logic after marshaling functions,
// so you need to tell this logic about these cases:
// 1. In 'unset' case - you need to put gocql.UnsetValue instead of value.
// 2. In 'nil' case - your Marshaller implementation should return []byte==nil.
// 3. In 'zero' case - your Marshaller implementation should return initiated []byte with len==0.
//
// All CQL DB`s have proprietary value coding features, which you need to consider.
// CQL binary protocol info:https://github.com/apache/cassandra/tree/trunk/doc
type Marshaler interface {
MarshalCQL(info TypeInfo) ([]byte, error)
}

// Unmarshaler is the interface implemented by objects that can unmarshal
// a Cassandra specific description of themselves.
// Unmarshaler is an interface for unmarshalling objects according to the CQL binary protocol.
// Initially, each value of the 'CQL binary protocol' consist of <value_len> and <value_data>.
// <value_len> can be 'nil'(-1), 'zero'(0) or any value up to 2147483647.
// In 'nil' and 'zero' cases <value_data> is not present.
// The current version of 'gocql' reads <value_len> through logic before unmarshalling functions,
// so your Unmarshaller implementation will receive:
// in the 'nil' case - []byte==nil,
// in the 'zero' case - initiated []byte with len==0.
// CQL binary protocol info:https://github.com/apache/cassandra/tree/trunk/doc
type Unmarshaler interface {
UnmarshalCQL(info TypeInfo, data []byte) error
}
Expand Down Expand Up @@ -115,7 +134,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
case TypeBoolean:
return marshalBool(info, value)
case TypeTinyInt:
return marshalTinyInt(info, value)
return marshalTinyInt(value)
case TypeSmallInt:
return marshalSmallInt(info, value)
case TypeInt:
Expand Down Expand Up @@ -225,7 +244,7 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
case TypeSmallInt:
return unmarshalSmallInt(info, data, value)
case TypeTinyInt:
return unmarshalTinyInt(info, data, value)
return unmarshalTinyInt(data, value)
case TypeFloat:
return unmarshalFloat(info, data, value)
case TypeDouble:
Expand Down Expand Up @@ -438,88 +457,12 @@ func marshalSmallInt(info TypeInfo, value interface{}) ([]byte, error) {
return nil, marshalErrorf("can not marshal %T into %s", value, info)
}

func marshalTinyInt(info TypeInfo, value interface{}) ([]byte, error) {
switch v := value.(type) {
case Marshaler:
return v.MarshalCQL(info)
case unsetColumn:
return nil, nil
case int8:
return []byte{byte(v)}, nil
case uint8:
return []byte{byte(v)}, nil
case int16:
if v > math.MaxInt8 || v < math.MinInt8 {
return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
}
return []byte{byte(v)}, nil
case uint16:
if v > math.MaxUint8 {
return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
}
return []byte{byte(v)}, nil
case int:
if v > math.MaxInt8 || v < math.MinInt8 {
return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
}
return []byte{byte(v)}, nil
case int32:
if v > math.MaxInt8 || v < math.MinInt8 {
return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
}
return []byte{byte(v)}, nil
case int64:
if v > math.MaxInt8 || v < math.MinInt8 {
return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
}
return []byte{byte(v)}, nil
case uint:
if v > math.MaxUint8 {
return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
}
return []byte{byte(v)}, nil
case uint32:
if v > math.MaxUint8 {
return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
}
return []byte{byte(v)}, nil
case uint64:
if v > math.MaxUint8 {
return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
}
return []byte{byte(v)}, nil
case string:
n, err := strconv.ParseInt(v, 10, 8)
if err != nil {
return nil, marshalErrorf("can not marshal %T into %s: %v", value, info, err)
}
return []byte{byte(n)}, nil
}

if value == nil {
return nil, nil
}

switch rv := reflect.ValueOf(value); rv.Type().Kind() {
case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8:
v := rv.Int()
if v > math.MaxInt8 || v < math.MinInt8 {
return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
}
return []byte{byte(v)}, nil
case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8:
v := rv.Uint()
if v > math.MaxUint8 {
return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
}
return []byte{byte(v)}, nil
case reflect.Ptr:
if rv.IsNil() {
return nil, nil
}
func marshalTinyInt(value interface{}) ([]byte, error) {
data, err := tinyint.Marshal(value)
if err != nil {
return nil, MarshalError(err.(Error).Error())
}

return nil, marshalErrorf("can not marshal %T into %s", value, info)
return data, nil
}

func marshalInt(info TypeInfo, value interface{}) ([]byte, error) {
Expand Down Expand Up @@ -619,13 +562,6 @@ func decShort(p []byte) int16 {
return int16(p[0])<<8 | int16(p[1])
}

func decTiny(p []byte) int8 {
if len(p) != 1 {
return 0
}
return int8(p[0])
}

func marshalBigInt(info TypeInfo, value interface{}) ([]byte, error) {
switch v := value.(type) {
case Marshaler:
Expand Down Expand Up @@ -715,8 +651,11 @@ func unmarshalSmallInt(info TypeInfo, data []byte, value interface{}) error {
return unmarshalIntlike(info, int64(decShort(data)), data, value)
}

func unmarshalTinyInt(info TypeInfo, data []byte, value interface{}) error {
return unmarshalIntlike(info, int64(decTiny(data)), data, value)
func unmarshalTinyInt(data []byte, value interface{}) error {
if err := tinyint.Unmarshal(data, value); err != nil {
return UnmarshalError(err.(Error).Error())
}
return nil
}

func unmarshalVarint(info TypeInfo, data []byte, value interface{}) error {
Expand Down
7 changes: 7 additions & 0 deletions marshal/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package marshal

import "errors"

var ErrorGOCQL = errors.New("gocql error")
var ErrorMarshal = errors.Join(ErrorGOCQL, errors.New("marshal error"))
var ErrorUnmarshal = errors.Join(ErrorGOCQL, errors.New("unmarshal error"))
72 changes: 72 additions & 0 deletions marshal/tinyint/marshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package tinyint

import (
"math/big"
"reflect"
)

func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case int8:
return EncInt8(v)
case int32:
return EncInt32(v)
case int16:
return EncInt16(v)
case int64:
return EncInt64(v)
case int:
return EncInt(v)

case uint8:
return EncUint8(v)
case uint16:
return EncUint16(v)
case uint32:
return EncUint32(v)
case uint64:
return EncUint64(v)
case uint:
return EncUint(v)

case big.Int:
return EncBigInt(v)
case string:
return EncString(v)

case *int8:
return EncInt8R(v)
case *int16:
return EncInt16R(v)
case *int32:
return EncInt32R(v)
case *int64:
return EncInt64R(v)
case *int:
return EncIntR(v)

case *uint8:
return EncUint8R(v)
case *uint16:
return EncUint16R(v)
case *uint32:
return EncUint32R(v)
case *uint64:
return EncUint64R(v)
case *uint:
return EncUintR(v)

case *big.Int:
return EncBigIntR(v)
case *string:
return EncStringR(v)
default:
rv := reflect.TypeOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(reflect.ValueOf(v))
}
return EncReflectR(reflect.ValueOf(v))
}
}
Loading

0 comments on commit 40a54cb

Please sign in to comment.