Skip to content

Commit

Permalink
encoder: Better support for Unix File Descriptors
Browse files Browse the repository at this point in the history
This fixes a bug where UnixFDs that are inside
structs and variants aren't handled properly. It fixes the decoder too.

Fixes #223
  • Loading branch information
wdouglass authored and jsouthworth committed Sep 8, 2021
1 parent cd3ee85 commit e5edbf7
Show file tree
Hide file tree
Showing 10 changed files with 259 additions and 100 deletions.
2 changes: 1 addition & 1 deletion conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ func TestCloseBeforeSignal(t *testing.T) {
FieldPath: MakeVariant(ObjectPath("/baz")),
},
}
err = msg.EncodeTo(pipewriter, binary.LittleEndian)
_, err = msg.EncodeTo(pipewriter, binary.LittleEndian)
if err != nil {
t.Fatal(err)
}
Expand Down
10 changes: 8 additions & 2 deletions decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@ type decoder struct {
in io.Reader
order binary.ByteOrder
pos int
fds []int
}

// newDecoder returns a new decoder that reads values from in. The input is
// expected to be in the given byte order.
func newDecoder(in io.Reader, order binary.ByteOrder) *decoder {
func newDecoder(in io.Reader, order binary.ByteOrder, fds []int) *decoder {
dec := new(decoder)
dec.in = in
dec.order = order
dec.fds = fds
return dec
}

Expand Down Expand Up @@ -161,7 +163,11 @@ func (dec *decoder) decode(s string, depth int) interface{} {
variant.value = dec.decode(sig.str, depth+1)
return variant
case 'h':
return UnixFDIndex(dec.decode("u", depth).(uint32))
idx := dec.decode("u", depth).(uint32)
if int(idx) < len(dec.fds) {
return UnixFD(dec.fds[idx])
}
return UnixFDIndex(idx)
case 'a':
if len(s) > 1 && s[1] == '{' {
ksig := s[2:3]
Expand Down
4 changes: 2 additions & 2 deletions decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ func TestDecodeArrayEmptyStruct(t *testing.T) {
},
serial: 0x00000003,
}
err := msg.EncodeTo(buf, binary.LittleEndian)
_, err := msg.EncodeTo(buf, binary.LittleEndian)
if err != nil {
t.Fatal(err)
}
msg, err = DecodeMessage(buf)
msg, err = DecodeMessage(buf, make([]int, 0))
if err != nil {
t.Fatal(err)
}
Expand Down
24 changes: 18 additions & 6 deletions encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,26 @@ import (
// An encoder encodes values to the D-Bus wire format.
type encoder struct {
out io.Writer
fds []int
order binary.ByteOrder
pos int
}

// NewEncoder returns a new encoder that writes to out in the given byte order.
func newEncoder(out io.Writer, order binary.ByteOrder) *encoder {
return newEncoderAtOffset(out, 0, order)
func newEncoder(out io.Writer, order binary.ByteOrder, fds []int) *encoder {
enc := newEncoderAtOffset(out, 0, order, fds)
return enc
}

// newEncoderAtOffset returns a new encoder that writes to out in the given
// byte order. Specify the offset to initialize pos for proper alignment
// computation.
func newEncoderAtOffset(out io.Writer, offset int, order binary.ByteOrder) *encoder {
func newEncoderAtOffset(out io.Writer, offset int, order binary.ByteOrder, fds []int) *encoder {
enc := new(encoder)
enc.out = out
enc.order = order
enc.pos = offset
enc.fds = fds
return enc
}

Expand Down Expand Up @@ -102,7 +105,14 @@ func (enc *encoder) encode(v reflect.Value, depth int) {
enc.binwrite(uint16(v.Uint()))
enc.pos += 2
case reflect.Int, reflect.Int32:
enc.binwrite(int32(v.Int()))
if v.Type() == unixFDType {
fd := v.Int()
idx := len(enc.fds)
enc.fds = append(enc.fds, int(fd))
enc.binwrite(uint32(idx))
} else {
enc.binwrite(int32(v.Int()))
}
enc.pos += 4
case reflect.Uint, reflect.Uint32:
enc.binwrite(uint32(v.Uint()))
Expand Down Expand Up @@ -147,7 +157,7 @@ func (enc *encoder) encode(v reflect.Value, depth int) {
offset := enc.pos + n + enc.padding(n, alignment(v.Type().Elem()))

var buf bytes.Buffer
bufenc := newEncoderAtOffset(&buf, offset, enc.order)
bufenc := newEncoderAtOffset(&buf, offset, enc.order, enc.fds)

for i := 0; i < v.Len(); i++ {
bufenc.encode(v.Index(i), depth+1)
Expand All @@ -157,6 +167,7 @@ func (enc *encoder) encode(v reflect.Value, depth int) {
panic(FormatError("input exceeds array size limitation"))
}

enc.fds = bufenc.fds
enc.encode(reflect.ValueOf(uint32(buf.Len())), depth)
length := buf.Len()
enc.align(alignment(v.Type().Elem()))
Expand Down Expand Up @@ -202,12 +213,13 @@ func (enc *encoder) encode(v reflect.Value, depth int) {
offset := enc.pos + n + enc.padding(n, 8)

var buf bytes.Buffer
bufenc := newEncoderAtOffset(&buf, offset, enc.order)
bufenc := newEncoderAtOffset(&buf, offset, enc.order, enc.fds)
for _, k := range keys {
bufenc.align(8)
bufenc.encode(k, depth+2)
bufenc.encode(v.MapIndex(k), depth+2)
}
enc.fds = bufenc.fds
enc.encode(reflect.ValueOf(uint32(buf.Len())), depth)
length := buf.Len()
enc.align(8)
Expand Down
75 changes: 45 additions & 30 deletions encoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ func TestEncodeArrayOfMaps(t *testing.T) {
for _, order := range []binary.ByteOrder{binary.LittleEndian, binary.BigEndian} {
for _, tt := range tests {
buf := new(bytes.Buffer)
enc := newEncoder(buf, order)
fds := make([]int, 0)
enc := newEncoder(buf, order, fds)
enc.Encode(tt.vs...)

dec := newDecoder(buf, order)
dec := newDecoder(buf, order, enc.fds)
v, err := dec.Decode(SignatureOf(tt.vs...))
if err != nil {
t.Errorf("%q: decode (%v) failed: %v", tt.name, order, err)
Expand All @@ -60,14 +61,15 @@ func TestEncodeArrayOfMaps(t *testing.T) {
func TestEncodeMapStringInterface(t *testing.T) {
val := map[string]interface{}{"foo": "bar"}
buf := new(bytes.Buffer)
fds := make([]int, 0)
order := binary.LittleEndian
enc := newEncoder(buf, binary.LittleEndian)
enc := newEncoder(buf, binary.LittleEndian, fds)
err := enc.Encode(val)
if err != nil {
t.Fatal(err)
}

dec := newDecoder(buf, order)
dec := newDecoder(buf, order, enc.fds)
v, err := dec.Decode(SignatureOf(val))
if err != nil {
t.Fatal(err)
Expand All @@ -85,14 +87,15 @@ type empty interface{}
func TestEncodeMapStringNamedInterface(t *testing.T) {
val := map[string]empty{"foo": "bar"}
buf := new(bytes.Buffer)
fds := make([]int, 0)
order := binary.LittleEndian
enc := newEncoder(buf, binary.LittleEndian)
enc := newEncoder(buf, binary.LittleEndian, fds)
err := enc.Encode(val)
if err != nil {
t.Fatal(err)
}

dec := newDecoder(buf, order)
dec := newDecoder(buf, order, enc.fds)
v, err := dec.Decode(SignatureOf(val))
if err != nil {
t.Fatal(err)
Expand All @@ -116,14 +119,15 @@ func (fooimpl) Foo() {}
func TestEncodeMapStringNonEmptyInterface(t *testing.T) {
val := map[string]fooer{"foo": fooimpl("bar")}
buf := new(bytes.Buffer)
fds := make([]int, 0)
order := binary.LittleEndian
enc := newEncoder(buf, binary.LittleEndian)
enc := newEncoder(buf, binary.LittleEndian, fds)
err := enc.Encode(val)
if err != nil {
t.Fatal(err)
}

dec := newDecoder(buf, order)
dec := newDecoder(buf, order, enc.fds)
v, err := dec.Decode(SignatureOf(val))
if err != nil {
t.Fatal(err)
Expand All @@ -138,14 +142,15 @@ func TestEncodeMapStringNonEmptyInterface(t *testing.T) {
func TestEncodeSliceInterface(t *testing.T) {
val := []interface{}{"foo", "bar"}
buf := new(bytes.Buffer)
fds := make([]int, 0)
order := binary.LittleEndian
enc := newEncoder(buf, binary.LittleEndian)
enc := newEncoder(buf, binary.LittleEndian, fds)
err := enc.Encode(val)
if err != nil {
t.Fatal(err)
}

dec := newDecoder(buf, order)
dec := newDecoder(buf, order, enc.fds)
v, err := dec.Decode(SignatureOf(val))
if err != nil {
t.Fatal(err)
Expand All @@ -161,14 +166,15 @@ func TestEncodeSliceInterface(t *testing.T) {
func TestEncodeSliceNamedInterface(t *testing.T) {
val := []empty{"foo", "bar"}
buf := new(bytes.Buffer)
fds := make([]int, 0)
order := binary.LittleEndian
enc := newEncoder(buf, binary.LittleEndian)
enc := newEncoder(buf, binary.LittleEndian, fds)
err := enc.Encode(val)
if err != nil {
t.Fatal(err)
}

dec := newDecoder(buf, order)
dec := newDecoder(buf, order, enc.fds)
v, err := dec.Decode(SignatureOf(val))
if err != nil {
t.Fatal(err)
Expand All @@ -194,14 +200,15 @@ func TestEncodeNestedInterface(t *testing.T) {
},
}
buf := new(bytes.Buffer)
fds := make([]int, 0)
order := binary.LittleEndian
enc := newEncoder(buf, binary.LittleEndian)
enc := newEncoder(buf, binary.LittleEndian, fds)
err := enc.Encode(val)
if err != nil {
t.Fatal(err)
}

dec := newDecoder(buf, order)
dec := newDecoder(buf, order, enc.fds)
v, err := dec.Decode(SignatureOf(val))
if err != nil {
t.Fatal(err)
Expand All @@ -217,14 +224,15 @@ func TestEncodeNestedInterface(t *testing.T) {
func TestEncodeInt(t *testing.T) {
val := 10
buf := new(bytes.Buffer)
fds := make([]int, 0)
order := binary.LittleEndian
enc := newEncoder(buf, binary.LittleEndian)
enc := newEncoder(buf, binary.LittleEndian, fds)
err := enc.Encode(val)
if err != nil {
t.Fatal(err)
}

dec := newDecoder(buf, order)
dec := newDecoder(buf, order, enc.fds)
v, err := dec.Decode(SignatureOf(val))
if err != nil {
t.Fatal(err)
Expand All @@ -240,14 +248,15 @@ func TestEncodeInt(t *testing.T) {
func TestEncodeIntToNonCovertible(t *testing.T) {
val := 150
buf := new(bytes.Buffer)
fds := make([]int, 0)
order := binary.LittleEndian
enc := newEncoder(buf, binary.LittleEndian)
enc := newEncoder(buf, binary.LittleEndian, fds)
err := enc.Encode(val)
if err != nil {
t.Fatal(err)
}

dec := newDecoder(buf, order)
dec := newDecoder(buf, order, enc.fds)
v, err := dec.Decode(SignatureOf(val))
if err != nil {
t.Fatal(err)
Expand All @@ -263,14 +272,15 @@ func TestEncodeIntToNonCovertible(t *testing.T) {
func TestEncodeUint(t *testing.T) {
val := uint(10)
buf := new(bytes.Buffer)
fds := make([]int, 0)
order := binary.LittleEndian
enc := newEncoder(buf, binary.LittleEndian)
enc := newEncoder(buf, binary.LittleEndian, fds)
err := enc.Encode(val)
if err != nil {
t.Fatal(err)
}

dec := newDecoder(buf, order)
dec := newDecoder(buf, order, enc.fds)
v, err := dec.Decode(SignatureOf(val))
if err != nil {
t.Fatal(err)
Expand All @@ -286,14 +296,15 @@ func TestEncodeUint(t *testing.T) {
func TestEncodeUintToNonCovertible(t *testing.T) {
val := uint(10)
buf := new(bytes.Buffer)
fds := make([]int, 0)
order := binary.LittleEndian
enc := newEncoder(buf, binary.LittleEndian)
enc := newEncoder(buf, binary.LittleEndian, fds)
err := enc.Encode(val)
if err != nil {
t.Fatal(err)
}

dec := newDecoder(buf, order)
dec := newDecoder(buf, order, enc.fds)
v, err := dec.Decode(SignatureOf(val))
if err != nil {
t.Fatal(err)
Expand All @@ -310,14 +321,15 @@ type boolean bool
func TestEncodeOfAssignableType(t *testing.T) {
val := boolean(true)
buf := new(bytes.Buffer)
fds := make([]int, 0)
order := binary.LittleEndian
enc := newEncoder(buf, binary.LittleEndian)
enc := newEncoder(buf, binary.LittleEndian, fds)
err := enc.Encode(val)
if err != nil {
t.Fatal(err)
}

dec := newDecoder(buf, order)
dec := newDecoder(buf, order, enc.fds)
v, err := dec.Decode(SignatureOf(val))
if err != nil {
t.Fatal(err)
Expand All @@ -344,14 +356,15 @@ func TestEncodeVariant(t *testing.T) {
},
}
buf := new(bytes.Buffer)
fds := make([]int, 0)
order := binary.LittleEndian
enc := newEncoder(buf, binary.LittleEndian)
enc := newEncoder(buf, binary.LittleEndian, fds)
err := enc.Encode(src)
if err != nil {
t.Fatal(err)
}

dec := newDecoder(buf, order)
dec := newDecoder(buf, order, enc.fds)
v, err := dec.Decode(SignatureOf(src))
if err != nil {
t.Fatal(err)
Expand All @@ -369,14 +382,15 @@ func TestEncodeVariantToList(t *testing.T) {
"foo": []interface{}{"a", "b", "c"},
}
buf := new(bytes.Buffer)
fds := make([]int, 0)
order := binary.LittleEndian
enc := newEncoder(buf, binary.LittleEndian)
enc := newEncoder(buf, binary.LittleEndian, fds)
err := enc.Encode(src)
if err != nil {
t.Fatal(err)
}

dec := newDecoder(buf, order)
dec := newDecoder(buf, order, enc.fds)
v, err := dec.Decode(SignatureOf(src))
if err != nil {
t.Fatal(err)
Expand All @@ -394,14 +408,15 @@ func TestEncodeVariantToUint64(t *testing.T) {
"foo": uint64(10),
}
buf := new(bytes.Buffer)
fds := make([]int, 0)
order := binary.LittleEndian
enc := newEncoder(buf, binary.LittleEndian)
enc := newEncoder(buf, binary.LittleEndian, fds)
err := enc.Encode(src)
if err != nil {
t.Fatal(err)
}

dec := newDecoder(buf, order)
dec := newDecoder(buf, order, enc.fds)
v, err := dec.Decode(SignatureOf(src))
if err != nil {
t.Fatal(err)
Expand Down
Loading

0 comments on commit e5edbf7

Please sign in to comment.