diff --git a/conn_test.go b/conn_test.go index b73fddd..c9dc2e3 100644 --- a/conn_test.go +++ b/conn_test.go @@ -206,7 +206,7 @@ func TestCloseBeforeSignal(t *testing.T) { FieldPath: MakeVariant(ObjectPath("/baz")), }, } - err = msg.EncodeTo(pipewriter, binary.LittleEndian) + _, err = msg.EncodeTo(pipewriter, binary.LittleEndian) if err != nil { t.Fatal(err) } diff --git a/decoder.go b/decoder.go index 1e3966a..89bfed9 100644 --- a/decoder.go +++ b/decoder.go @@ -10,14 +10,16 @@ type decoder struct { in io.Reader order binary.ByteOrder pos int + fds []int } // newDecoder returns a new decoder that reads values from in. The input is // expected to be in the given byte order. -func newDecoder(in io.Reader, order binary.ByteOrder) *decoder { +func newDecoder(in io.Reader, order binary.ByteOrder, fds []int) *decoder { dec := new(decoder) dec.in = in dec.order = order + dec.fds = fds return dec } @@ -161,7 +163,11 @@ func (dec *decoder) decode(s string, depth int) interface{} { variant.value = dec.decode(sig.str, depth+1) return variant case 'h': - return UnixFDIndex(dec.decode("u", depth).(uint32)) + idx := dec.decode("u", depth).(uint32) + if int(idx) < len(dec.fds) { + return UnixFD(dec.fds[idx]) + } + return UnixFDIndex(idx) case 'a': if len(s) > 1 && s[1] == '{' { ksig := s[2:3] diff --git a/decoder_test.go b/decoder_test.go index 2170b91..665d398 100644 --- a/decoder_test.go +++ b/decoder_test.go @@ -61,11 +61,11 @@ func TestDecodeArrayEmptyStruct(t *testing.T) { }, serial: 0x00000003, } - err := msg.EncodeTo(buf, binary.LittleEndian) + _, err := msg.EncodeTo(buf, binary.LittleEndian) if err != nil { t.Fatal(err) } - msg, err = DecodeMessage(buf) + msg, err = DecodeMessage(buf, make([]int, 0)) if err != nil { t.Fatal(err) } diff --git a/encoder.go b/encoder.go index 296de00..015b26c 100644 --- a/encoder.go +++ b/encoder.go @@ -12,23 +12,26 @@ import ( // An encoder encodes values to the D-Bus wire format. type encoder struct { out io.Writer + fds []int order binary.ByteOrder pos int } // NewEncoder returns a new encoder that writes to out in the given byte order. -func newEncoder(out io.Writer, order binary.ByteOrder) *encoder { - return newEncoderAtOffset(out, 0, order) +func newEncoder(out io.Writer, order binary.ByteOrder, fds []int) *encoder { + enc := newEncoderAtOffset(out, 0, order, fds) + return enc } // newEncoderAtOffset returns a new encoder that writes to out in the given // byte order. Specify the offset to initialize pos for proper alignment // computation. -func newEncoderAtOffset(out io.Writer, offset int, order binary.ByteOrder) *encoder { +func newEncoderAtOffset(out io.Writer, offset int, order binary.ByteOrder, fds []int) *encoder { enc := new(encoder) enc.out = out enc.order = order enc.pos = offset + enc.fds = fds return enc } @@ -102,7 +105,14 @@ func (enc *encoder) encode(v reflect.Value, depth int) { enc.binwrite(uint16(v.Uint())) enc.pos += 2 case reflect.Int, reflect.Int32: - enc.binwrite(int32(v.Int())) + if v.Type() == unixFDType { + fd := v.Int() + idx := len(enc.fds) + enc.fds = append(enc.fds, int(fd)) + enc.binwrite(uint32(idx)) + } else { + enc.binwrite(int32(v.Int())) + } enc.pos += 4 case reflect.Uint, reflect.Uint32: enc.binwrite(uint32(v.Uint())) @@ -147,7 +157,7 @@ func (enc *encoder) encode(v reflect.Value, depth int) { offset := enc.pos + n + enc.padding(n, alignment(v.Type().Elem())) var buf bytes.Buffer - bufenc := newEncoderAtOffset(&buf, offset, enc.order) + bufenc := newEncoderAtOffset(&buf, offset, enc.order, enc.fds) for i := 0; i < v.Len(); i++ { bufenc.encode(v.Index(i), depth+1) @@ -157,6 +167,7 @@ func (enc *encoder) encode(v reflect.Value, depth int) { panic(FormatError("input exceeds array size limitation")) } + enc.fds = bufenc.fds enc.encode(reflect.ValueOf(uint32(buf.Len())), depth) length := buf.Len() enc.align(alignment(v.Type().Elem())) @@ -202,12 +213,13 @@ func (enc *encoder) encode(v reflect.Value, depth int) { offset := enc.pos + n + enc.padding(n, 8) var buf bytes.Buffer - bufenc := newEncoderAtOffset(&buf, offset, enc.order) + bufenc := newEncoderAtOffset(&buf, offset, enc.order, enc.fds) for _, k := range keys { bufenc.align(8) bufenc.encode(k, depth+2) bufenc.encode(v.MapIndex(k), depth+2) } + enc.fds = bufenc.fds enc.encode(reflect.ValueOf(uint32(buf.Len())), depth) length := buf.Len() enc.align(8) diff --git a/encoder_test.go b/encoder_test.go index 9338640..c38b7e0 100644 --- a/encoder_test.go +++ b/encoder_test.go @@ -40,10 +40,11 @@ func TestEncodeArrayOfMaps(t *testing.T) { for _, order := range []binary.ByteOrder{binary.LittleEndian, binary.BigEndian} { for _, tt := range tests { buf := new(bytes.Buffer) - enc := newEncoder(buf, order) + fds := make([]int, 0) + enc := newEncoder(buf, order, fds) enc.Encode(tt.vs...) - dec := newDecoder(buf, order) + dec := newDecoder(buf, order, enc.fds) v, err := dec.Decode(SignatureOf(tt.vs...)) if err != nil { t.Errorf("%q: decode (%v) failed: %v", tt.name, order, err) @@ -60,14 +61,15 @@ func TestEncodeArrayOfMaps(t *testing.T) { func TestEncodeMapStringInterface(t *testing.T) { val := map[string]interface{}{"foo": "bar"} buf := new(bytes.Buffer) + fds := make([]int, 0) order := binary.LittleEndian - enc := newEncoder(buf, binary.LittleEndian) + enc := newEncoder(buf, binary.LittleEndian, fds) err := enc.Encode(val) if err != nil { t.Fatal(err) } - dec := newDecoder(buf, order) + dec := newDecoder(buf, order, enc.fds) v, err := dec.Decode(SignatureOf(val)) if err != nil { t.Fatal(err) @@ -85,14 +87,15 @@ type empty interface{} func TestEncodeMapStringNamedInterface(t *testing.T) { val := map[string]empty{"foo": "bar"} buf := new(bytes.Buffer) + fds := make([]int, 0) order := binary.LittleEndian - enc := newEncoder(buf, binary.LittleEndian) + enc := newEncoder(buf, binary.LittleEndian, fds) err := enc.Encode(val) if err != nil { t.Fatal(err) } - dec := newDecoder(buf, order) + dec := newDecoder(buf, order, enc.fds) v, err := dec.Decode(SignatureOf(val)) if err != nil { t.Fatal(err) @@ -116,14 +119,15 @@ func (fooimpl) Foo() {} func TestEncodeMapStringNonEmptyInterface(t *testing.T) { val := map[string]fooer{"foo": fooimpl("bar")} buf := new(bytes.Buffer) + fds := make([]int, 0) order := binary.LittleEndian - enc := newEncoder(buf, binary.LittleEndian) + enc := newEncoder(buf, binary.LittleEndian, fds) err := enc.Encode(val) if err != nil { t.Fatal(err) } - dec := newDecoder(buf, order) + dec := newDecoder(buf, order, enc.fds) v, err := dec.Decode(SignatureOf(val)) if err != nil { t.Fatal(err) @@ -138,14 +142,15 @@ func TestEncodeMapStringNonEmptyInterface(t *testing.T) { func TestEncodeSliceInterface(t *testing.T) { val := []interface{}{"foo", "bar"} buf := new(bytes.Buffer) + fds := make([]int, 0) order := binary.LittleEndian - enc := newEncoder(buf, binary.LittleEndian) + enc := newEncoder(buf, binary.LittleEndian, fds) err := enc.Encode(val) if err != nil { t.Fatal(err) } - dec := newDecoder(buf, order) + dec := newDecoder(buf, order, enc.fds) v, err := dec.Decode(SignatureOf(val)) if err != nil { t.Fatal(err) @@ -161,14 +166,15 @@ func TestEncodeSliceInterface(t *testing.T) { func TestEncodeSliceNamedInterface(t *testing.T) { val := []empty{"foo", "bar"} buf := new(bytes.Buffer) + fds := make([]int, 0) order := binary.LittleEndian - enc := newEncoder(buf, binary.LittleEndian) + enc := newEncoder(buf, binary.LittleEndian, fds) err := enc.Encode(val) if err != nil { t.Fatal(err) } - dec := newDecoder(buf, order) + dec := newDecoder(buf, order, enc.fds) v, err := dec.Decode(SignatureOf(val)) if err != nil { t.Fatal(err) @@ -194,14 +200,15 @@ func TestEncodeNestedInterface(t *testing.T) { }, } buf := new(bytes.Buffer) + fds := make([]int, 0) order := binary.LittleEndian - enc := newEncoder(buf, binary.LittleEndian) + enc := newEncoder(buf, binary.LittleEndian, fds) err := enc.Encode(val) if err != nil { t.Fatal(err) } - dec := newDecoder(buf, order) + dec := newDecoder(buf, order, enc.fds) v, err := dec.Decode(SignatureOf(val)) if err != nil { t.Fatal(err) @@ -217,14 +224,15 @@ func TestEncodeNestedInterface(t *testing.T) { func TestEncodeInt(t *testing.T) { val := 10 buf := new(bytes.Buffer) + fds := make([]int, 0) order := binary.LittleEndian - enc := newEncoder(buf, binary.LittleEndian) + enc := newEncoder(buf, binary.LittleEndian, fds) err := enc.Encode(val) if err != nil { t.Fatal(err) } - dec := newDecoder(buf, order) + dec := newDecoder(buf, order, enc.fds) v, err := dec.Decode(SignatureOf(val)) if err != nil { t.Fatal(err) @@ -240,14 +248,15 @@ func TestEncodeInt(t *testing.T) { func TestEncodeIntToNonCovertible(t *testing.T) { val := 150 buf := new(bytes.Buffer) + fds := make([]int, 0) order := binary.LittleEndian - enc := newEncoder(buf, binary.LittleEndian) + enc := newEncoder(buf, binary.LittleEndian, fds) err := enc.Encode(val) if err != nil { t.Fatal(err) } - dec := newDecoder(buf, order) + dec := newDecoder(buf, order, enc.fds) v, err := dec.Decode(SignatureOf(val)) if err != nil { t.Fatal(err) @@ -263,14 +272,15 @@ func TestEncodeIntToNonCovertible(t *testing.T) { func TestEncodeUint(t *testing.T) { val := uint(10) buf := new(bytes.Buffer) + fds := make([]int, 0) order := binary.LittleEndian - enc := newEncoder(buf, binary.LittleEndian) + enc := newEncoder(buf, binary.LittleEndian, fds) err := enc.Encode(val) if err != nil { t.Fatal(err) } - dec := newDecoder(buf, order) + dec := newDecoder(buf, order, enc.fds) v, err := dec.Decode(SignatureOf(val)) if err != nil { t.Fatal(err) @@ -286,14 +296,15 @@ func TestEncodeUint(t *testing.T) { func TestEncodeUintToNonCovertible(t *testing.T) { val := uint(10) buf := new(bytes.Buffer) + fds := make([]int, 0) order := binary.LittleEndian - enc := newEncoder(buf, binary.LittleEndian) + enc := newEncoder(buf, binary.LittleEndian, fds) err := enc.Encode(val) if err != nil { t.Fatal(err) } - dec := newDecoder(buf, order) + dec := newDecoder(buf, order, enc.fds) v, err := dec.Decode(SignatureOf(val)) if err != nil { t.Fatal(err) @@ -310,14 +321,15 @@ type boolean bool func TestEncodeOfAssignableType(t *testing.T) { val := boolean(true) buf := new(bytes.Buffer) + fds := make([]int, 0) order := binary.LittleEndian - enc := newEncoder(buf, binary.LittleEndian) + enc := newEncoder(buf, binary.LittleEndian, fds) err := enc.Encode(val) if err != nil { t.Fatal(err) } - dec := newDecoder(buf, order) + dec := newDecoder(buf, order, enc.fds) v, err := dec.Decode(SignatureOf(val)) if err != nil { t.Fatal(err) @@ -344,14 +356,15 @@ func TestEncodeVariant(t *testing.T) { }, } buf := new(bytes.Buffer) + fds := make([]int, 0) order := binary.LittleEndian - enc := newEncoder(buf, binary.LittleEndian) + enc := newEncoder(buf, binary.LittleEndian, fds) err := enc.Encode(src) if err != nil { t.Fatal(err) } - dec := newDecoder(buf, order) + dec := newDecoder(buf, order, enc.fds) v, err := dec.Decode(SignatureOf(src)) if err != nil { t.Fatal(err) @@ -369,14 +382,15 @@ func TestEncodeVariantToList(t *testing.T) { "foo": []interface{}{"a", "b", "c"}, } buf := new(bytes.Buffer) + fds := make([]int, 0) order := binary.LittleEndian - enc := newEncoder(buf, binary.LittleEndian) + enc := newEncoder(buf, binary.LittleEndian, fds) err := enc.Encode(src) if err != nil { t.Fatal(err) } - dec := newDecoder(buf, order) + dec := newDecoder(buf, order, enc.fds) v, err := dec.Decode(SignatureOf(src)) if err != nil { t.Fatal(err) @@ -394,14 +408,15 @@ func TestEncodeVariantToUint64(t *testing.T) { "foo": uint64(10), } buf := new(bytes.Buffer) + fds := make([]int, 0) order := binary.LittleEndian - enc := newEncoder(buf, binary.LittleEndian) + enc := newEncoder(buf, binary.LittleEndian, fds) err := enc.Encode(src) if err != nil { t.Fatal(err) } - dec := newDecoder(buf, order) + dec := newDecoder(buf, order, enc.fds) v, err := dec.Decode(SignatureOf(src)) if err != nil { t.Fatal(err) diff --git a/message.go b/message.go index cfdecb4..338e5ba 100644 --- a/message.go +++ b/message.go @@ -122,7 +122,7 @@ type header struct { // from the given reader. The byte order is figured out from the first byte. // The possibly returned error can be an error of the underlying reader, an // InvalidMessageError or a FormatError. -func DecodeMessage(rd io.Reader) (msg *Message, err error) { +func DecodeMessage(rd io.Reader, fds []int) (msg *Message, err error) { var order binary.ByteOrder var hlength, length uint32 var typ, flags, proto byte @@ -142,7 +142,7 @@ func DecodeMessage(rd io.Reader) (msg *Message, err error) { return nil, InvalidMessageError("invalid byte order") } - dec := newDecoder(rd, order) + dec := newDecoder(rd, order, fds) dec.pos = 1 msg = new(Message) @@ -166,7 +166,7 @@ func DecodeMessage(rd io.Reader) (msg *Message, err error) { if hlength+length+16 > 1<<27 { return nil, InvalidMessageError("message is too long") } - dec = newDecoder(io.MultiReader(bytes.NewBuffer(b), rd), order) + dec = newDecoder(io.MultiReader(bytes.NewBuffer(b), rd), order, fds) dec.pos = 12 vs, err = dec.Decode(Signature{"a(yv)"}) if err != nil { @@ -196,7 +196,7 @@ func DecodeMessage(rd io.Reader) (msg *Message, err error) { sig, _ := msg.Headers[FieldSignature].value.(Signature) if sig.str != "" { buf := bytes.NewBuffer(body) - dec = newDecoder(buf, order) + dec = newDecoder(buf, order, fds) vs, err := dec.Decode(sig) if err != nil { return nil, err @@ -207,12 +207,27 @@ func DecodeMessage(rd io.Reader) (msg *Message, err error) { return } +type nullwriter struct{} + +func (nullwriter) Write(p []byte) (cnt int, err error) { + return len(p), nil +} + +func (msg *Message) CountFds() (int, error) { + if len(msg.Body) == 0 { + return 0, nil + } + enc := newEncoder(nullwriter{}, nativeEndian, make([]int, 0)) + err := enc.Encode(msg.Body...) + return len(enc.fds), err +} + // EncodeTo encodes and sends a message to the given writer. The byte order must // be either binary.LittleEndian or binary.BigEndian. If the message is not // valid or an error occurs when writing, an error is returned. -func (msg *Message) EncodeTo(out io.Writer, order binary.ByteOrder) (err error) { - if err = msg.IsValid(); err != nil { - return +func (msg *Message) EncodeTo(out io.Writer, order binary.ByteOrder) (fds []int, err error) { + if err := msg.IsValid(); err != nil { + return make([]int, 0), err } var vs [7]interface{} switch order { @@ -221,10 +236,11 @@ func (msg *Message) EncodeTo(out io.Writer, order binary.ByteOrder) (err error) case binary.BigEndian: vs[0] = byte('B') default: - return errors.New("dbus: invalid byte order") + return make([]int, 0), errors.New("dbus: invalid byte order") } body := new(bytes.Buffer) - enc := newEncoder(body, order) + fds = make([]int, 0) + enc := newEncoder(body, order, fds) if len(msg.Body) != 0 { err = enc.Encode(msg.Body...) if err != nil { @@ -242,7 +258,7 @@ func (msg *Message) EncodeTo(out io.Writer, order binary.ByteOrder) (err error) } vs[6] = headers var buf bytes.Buffer - enc = newEncoder(&buf, order) + enc = newEncoder(&buf, order, enc.fds) err = enc.Encode(vs[:]...) if err != nil { return @@ -250,12 +266,12 @@ func (msg *Message) EncodeTo(out io.Writer, order binary.ByteOrder) (err error) enc.align(8) body.WriteTo(&buf) if buf.Len() > 1<<27 { - return InvalidMessageError("message is too long") + return make([]int, 0), InvalidMessageError("message is too long") } if _, err := buf.WriteTo(out); err != nil { - return err + return make([]int, 0), err } - return nil + return enc.fds, nil } // IsValid checks whether msg is a valid message and returns an diff --git a/proto_test.go b/proto_test.go index 364723e..4638aa2 100644 --- a/proto_test.go +++ b/proto_test.go @@ -84,7 +84,8 @@ var protoTests = []struct { func TestProto(t *testing.T) { for i, v := range protoTests { buf := new(bytes.Buffer) - bigEnc := newEncoder(buf, binary.BigEndian) + fds := make([]int, 0) + bigEnc := newEncoder(buf, binary.BigEndian, fds) bigEnc.Encode(v.vs...) marshalled := buf.Bytes() if !bytes.Equal(marshalled, v.bigEndian) { @@ -92,7 +93,8 @@ func TestProto(t *testing.T) { v.bigEndian) } buf.Reset() - litEnc := newEncoder(buf, binary.LittleEndian) + fds = make([]int, 0) + litEnc := newEncoder(buf, binary.LittleEndian, fds) litEnc.Encode(v.vs...) marshalled = buf.Bytes() if !bytes.Equal(marshalled, v.littleEndian) { @@ -105,7 +107,7 @@ func TestProto(t *testing.T) { unmarshalled = reflect.Append(unmarshalled, reflect.New(reflect.TypeOf(v.vs[i]))) } - bigDec := newDecoder(bytes.NewReader(v.bigEndian), binary.BigEndian) + bigDec := newDecoder(bytes.NewReader(v.bigEndian), binary.BigEndian, make([]int, 0)) vs, err := bigDec.Decode(SignatureOf(v.vs...)) if err != nil { t.Errorf("test %d (unmarshal be): %s\n", i+1, err) @@ -114,7 +116,7 @@ func TestProto(t *testing.T) { if !reflect.DeepEqual(vs, v.vs) { t.Errorf("test %d (unmarshal be): got %#v, but expected %#v\n", i+1, vs, v.vs) } - litDec := newDecoder(bytes.NewReader(v.littleEndian), binary.LittleEndian) + litDec := newDecoder(bytes.NewReader(v.littleEndian), binary.LittleEndian, make([]int, 0)) vs, err = litDec.Decode(SignatureOf(v.vs...)) if err != nil { t.Errorf("test %d (unmarshal le): %s\n", i+1, err) @@ -134,9 +136,10 @@ func TestProtoMap(t *testing.T) { } var n map[string]uint8 buf := new(bytes.Buffer) - enc := newEncoder(buf, binary.LittleEndian) + fds := make([]int, 0) + enc := newEncoder(buf, binary.LittleEndian, fds) enc.Encode(m) - dec := newDecoder(buf, binary.LittleEndian) + dec := newDecoder(buf, binary.LittleEndian, enc.fds) vs, err := dec.Decode(Signature{"a{sy}"}) if err != nil { t.Fatal(err) @@ -156,9 +159,10 @@ func TestProtoVariantStruct(t *testing.T) { B int16 }{1, 2}) buf := new(bytes.Buffer) - enc := newEncoder(buf, binary.LittleEndian) + fds := make([]int, 0) + enc := newEncoder(buf, binary.LittleEndian, fds) enc.Encode(v) - dec := newDecoder(buf, binary.LittleEndian) + dec := newDecoder(buf, binary.LittleEndian, enc.fds) vs, err := dec.Decode(Signature{"v"}) if err != nil { t.Fatal(err) @@ -186,9 +190,10 @@ func TestProtoStructTag(t *testing.T) { bar1.A = 234 bar2.C = 345 buf := new(bytes.Buffer) - enc := newEncoder(buf, binary.LittleEndian) + fds := make([]int, 0) + enc := newEncoder(buf, binary.LittleEndian, fds) enc.Encode(bar1) - dec := newDecoder(buf, binary.LittleEndian) + dec := newDecoder(buf, binary.LittleEndian, enc.fds) vs, err := dec.Decode(Signature{"(ii)"}) if err != nil { t.Fatal(err) @@ -248,11 +253,11 @@ func TestMessage(t *testing.T) { FieldMember: MakeVariant("baz"), } message.Body = make([]interface{}, 0) - err := message.EncodeTo(buf, binary.LittleEndian) + _, err := message.EncodeTo(buf, binary.LittleEndian) if err != nil { t.Error(err) } - _, err = DecodeMessage(buf) + _, err = DecodeMessage(buf, make([]int, 0)) if err != nil { t.Error(err) } @@ -260,7 +265,7 @@ func TestMessage(t *testing.T) { func TestProtoStructInterfaces(t *testing.T) { b := []byte{42} - vs, err := newDecoder(bytes.NewReader(b), binary.LittleEndian).Decode(Signature{"(y)"}) + vs, err := newDecoder(bytes.NewReader(b), binary.LittleEndian, make([]int, 0)).Decode(Signature{"(y)"}) if err != nil { t.Fatal(err) } @@ -312,7 +317,7 @@ func BenchmarkDecodeMessageSmall(b *testing.B) { b.StopTimer() buf := new(bytes.Buffer) - err = smallMessage.EncodeTo(buf, binary.LittleEndian) + _, err = smallMessage.EncodeTo(buf, binary.LittleEndian) if err != nil { b.Fatal(err) } @@ -320,7 +325,7 @@ func BenchmarkDecodeMessageSmall(b *testing.B) { b.StartTimer() for i := 0; i < b.N; i++ { rd = bytes.NewReader(decoded) - _, err = DecodeMessage(rd) + _, err = DecodeMessage(rd, make([]int, 0)) if err != nil { b.Fatal(err) } @@ -333,7 +338,7 @@ func BenchmarkDecodeMessageBig(b *testing.B) { b.StopTimer() buf := new(bytes.Buffer) - err = bigMessage.EncodeTo(buf, binary.LittleEndian) + _, err = bigMessage.EncodeTo(buf, binary.LittleEndian) if err != nil { b.Fatal(err) } @@ -341,7 +346,7 @@ func BenchmarkDecodeMessageBig(b *testing.B) { b.StartTimer() for i := 0; i < b.N; i++ { rd = bytes.NewReader(decoded) - _, err = DecodeMessage(rd) + _, err = DecodeMessage(rd, make([]int, 0)) if err != nil { b.Fatal(err) } @@ -351,7 +356,7 @@ func BenchmarkDecodeMessageBig(b *testing.B) { func BenchmarkEncodeMessageSmall(b *testing.B) { var err error for i := 0; i < b.N; i++ { - err = smallMessage.EncodeTo(ioutil.Discard, binary.LittleEndian) + _, err = smallMessage.EncodeTo(ioutil.Discard, binary.LittleEndian) if err != nil { b.Fatal(err) } @@ -361,7 +366,7 @@ func BenchmarkEncodeMessageSmall(b *testing.B) { func BenchmarkEncodeMessageBig(b *testing.B) { var err error for i := 0; i < b.N; i++ { - err = bigMessage.EncodeTo(ioutil.Discard, binary.LittleEndian) + _, err = bigMessage.EncodeTo(ioutil.Discard, binary.LittleEndian) if err != nil { b.Fatal(err) } diff --git a/transport_generic.go b/transport_generic.go index 718a1ff..aeb2be8 100644 --- a/transport_generic.go +++ b/transport_generic.go @@ -37,14 +37,17 @@ func (t genericTransport) SupportsUnixFDs() bool { func (t genericTransport) EnableUnixFDs() {} func (t genericTransport) ReadMessage() (*Message, error) { - return DecodeMessage(t) + return DecodeMessage(t, make([]int, 0)) } func (t genericTransport) SendMessage(msg *Message) error { - for _, v := range msg.Body { - if _, ok := v.(UnixFD); ok { - return errors.New("dbus: unix fd passing not enabled") - } + fds, err := msg.CountFds() + if err != nil { + return err } - return msg.EncodeTo(t, nativeEndian) + if fds != 0 { + return errors.New("dbus: unix fd passing not enabled") + } + _, err = msg.EncodeTo(t, nativeEndian) + return err } diff --git a/transport_unix.go b/transport_unix.go index 8ecbee8..5f73cb5 100644 --- a/transport_unix.go +++ b/transport_unix.go @@ -113,7 +113,7 @@ func (t *unixTransport) ReadMessage() (*Message, error) { if _, err := io.ReadFull(t.rdr, headerdata[4:]); err != nil { return nil, err } - dec := newDecoder(bytes.NewBuffer(headerdata), order) + dec := newDecoder(bytes.NewBuffer(headerdata), order, make([]int, 0)) dec.pos = 12 vs, err := dec.Decode(Signature{"a(yv)"}) if err != nil { @@ -147,7 +147,7 @@ func (t *unixTransport) ReadMessage() (*Message, error) { if err != nil { return nil, err } - msg, err := DecodeMessage(bytes.NewBuffer(all)) + msg, err := DecodeMessage(bytes.NewBuffer(all), fds) if err != nil { return nil, err } @@ -175,27 +175,25 @@ func (t *unixTransport) ReadMessage() (*Message, error) { } return msg, nil } - return DecodeMessage(bytes.NewBuffer(all)) + return DecodeMessage(bytes.NewBuffer(all), make([]int, 0)) } func (t *unixTransport) SendMessage(msg *Message) error { - fds := make([]int, 0) - for i, v := range msg.Body { - if fd, ok := v.(UnixFD); ok { - msg.Body[i] = UnixFDIndex(len(fds)) - fds = append(fds, int(fd)) - } + fdcnt, err := msg.CountFds() + if err != nil { + return err } - if len(fds) != 0 { + if fdcnt != 0 { if !t.hasUnixFDs { return errors.New("dbus: unix fd passing not enabled") } - msg.Headers[FieldUnixFDs] = MakeVariant(uint32(len(fds))) - oob := syscall.UnixRights(fds...) + msg.Headers[FieldUnixFDs] = MakeVariant(uint32(fdcnt)) buf := new(bytes.Buffer) - if err := msg.EncodeTo(buf, nativeEndian); err != nil { + fds, err := msg.EncodeTo(buf, nativeEndian) + if err != nil { return err } + oob := syscall.UnixRights(fds...) n, oobn, err := t.UnixConn.WriteMsgUnix(buf.Bytes(), oob, nil) if err != nil { return err @@ -204,7 +202,7 @@ func (t *unixTransport) SendMessage(msg *Message) error { return io.ErrShortWrite } } else { - if err := msg.EncodeTo(t, nativeEndian); err != nil { + if _, err := msg.EncodeTo(t, nativeEndian); err != nil { return err } } diff --git a/transport_unix_test.go b/transport_unix_test.go index b1d7bdf..5b053ba 100644 --- a/transport_unix_test.go +++ b/transport_unix_test.go @@ -8,10 +8,68 @@ import ( const testString = `This is a test! This text should be read from the file that is created by this test.` -type unixFDTest struct{} +type unixFDTest struct { + t *testing.T +} + +func (t unixFDTest) Testfd(fd UnixFD) (string, *Error) { + var b [4096]byte + file := os.NewFile(uintptr(fd), "testfile") + defer file.Close() + n, err := file.Read(b[:]) + if err != nil { + return "", &Error{"com.github.guelfey.test.Error", nil} + } + return string(b[:n]), nil +} + +func (t unixFDTest) Testvariant(v Variant) (string, *Error) { + var b [4096]byte + fd := v.Value().(UnixFD) + file := os.NewFile(uintptr(fd), "testfile") + defer file.Close() + n, err := file.Read(b[:]) + if err != nil { + return "", &Error{"com.github.guelfey.test.Error", nil} + } + return string(b[:n]), nil +} + +type unixfdContainer struct { + Fd UnixFD +} -func (t unixFDTest) Test(fd UnixFD) (string, *Error) { +func (t unixFDTest) Teststruct(s unixfdContainer) (string, *Error) { var b [4096]byte + file := os.NewFile(uintptr(s.Fd), "testfile") + defer file.Close() + n, err := file.Read(b[:]) + if err != nil { + return "", &Error{"com.github.guelfey.test.Error", nil} + } + return string(b[:n]), nil +} + +func (t unixFDTest) Testvariantstruct(vs Variant) (string, *Error) { + var b [4096]byte + s := vs.Value().([]interface{}) + u := s[0].(UnixFD) + file := os.NewFile(uintptr(u), "testfile") + defer file.Close() + n, err := file.Read(b[:]) + if err != nil { + return "", &Error{"com.github.guelfey.test.Error", nil} + } + return string(b[:n]), nil +} + +type variantContainer struct { + V Variant +} + +func (t unixFDTest) Teststructvariant(sv variantContainer) (string, *Error) { + var b [4096]byte + fd := sv.V.Value().(UnixFD) file := os.NewFile(uintptr(fd), "testfile") defer file.Close() n, err := file.Read(b[:]) @@ -32,19 +90,65 @@ func TestUnixFDs(t *testing.T) { t.Fatal(err) } defer w.Close() - if _, err := w.Write([]byte(testString)); err != nil { - t.Fatal(err) - } name := conn.Names()[0] - test := unixFDTest{} + test := unixFDTest{t} conn.Export(test, "/com/github/guelfey/test", "com.github.guelfey.test") var s string obj := conn.Object(name, "/com/github/guelfey/test") - err = obj.Call("com.github.guelfey.test.Test", 0, UnixFD(r.Fd())).Store(&s) + + if _, err := w.Write([]byte(testString)); err != nil { + t.Fatal(err) + } + err = obj.Call("com.github.guelfey.test.Testfd", 0, UnixFD(r.Fd())).Store(&s) if err != nil { t.Fatal(err) } if s != testString { t.Fatal("got", s, "wanted", testString) } + + if _, err := w.Write([]byte(testString)); err != nil { + t.Fatal(err) + } + err = obj.Call("com.github.guelfey.test.Testvariant", 0, MakeVariant(UnixFD(r.Fd()))).Store(&s) + if err != nil { + t.Fatal(err) + } + if s != testString { + t.Fatal("got", s, "wanted", testString) + } + + if _, err := w.Write([]byte(testString)); err != nil { + t.Fatal(err) + } + err = obj.Call("com.github.guelfey.test.Teststruct", 0, unixfdContainer{UnixFD(r.Fd())}).Store(&s) + if err != nil { + t.Fatal(err) + } + if s != testString { + t.Fatal("got", s, "wanted", testString) + } + + if _, err := w.Write([]byte(testString)); err != nil { + t.Fatal(err) + } + err = obj.Call("com.github.guelfey.test.Testvariantstruct", 0, MakeVariant(unixfdContainer{UnixFD(r.Fd())})).Store(&s) + if err != nil { + t.Fatal(err) + } + if s != testString { + t.Fatal("got", s, "wanted", testString) + } + + if _, err := w.Write([]byte(testString)); err != nil { + t.Fatal(err) + } + err = obj.Call("com.github.guelfey.test.Teststructvariant", 0, variantContainer{MakeVariant(UnixFD(r.Fd()))}).Store(&s) + if err != nil { + t.Fatal(err) + } + if s != testString { + t.Fatal("got", s, "wanted", testString) + } + }