package msgpack

import (
	"bytes"
	"fmt"
	"reflect"
	"sync"

	"github.com/vmihailenco/msgpack/codes"
)

// extTypes maps a registered ext id to the Go type that extInterface
// instantiates when it meets that id while decoding.
var extTypes = make(map[int8]reflect.Type)

// bufferPool recycles the scratch buffers in which makeExtEncoder stages an
// ext payload, so that the payload length is known before the header is
// written.
var bufferPool = &sync.Pool{
	New: func() interface{} {
		return new(bytes.Buffer)
	},
}

// RegisterExt records a type, identified by a value for that type,
// under the provided id. That id will identify the concrete type of a value
// sent or received as an interface variable. Only types that will be
// transferred as implementations of interface values need to be registered.
//
// If value is a pointer, its element type is used; in every case both the
// type and the pointer to it are registered under id.
//
// Expecting to be used only during initialization, it panics if value is an
// untyped nil or if id is already registered.
func RegisterExt(id int8, value interface{}) {
	typ := reflect.TypeOf(value)
	if typ == nil {
		// reflect.TypeOf yields nil for an untyped nil; fail with a clear
		// message instead of a nil dereference on typ.Kind().
		panic(fmt.Errorf("msgpack: can't register nil value as ext with id=%d", id))
	}
	if typ.Kind() == reflect.Ptr {
		typ = typ.Elem()
	}
	ptr := reflect.PtrTo(typ)

	if _, ok := extTypes[id]; ok {
		panic(fmt.Errorf("msgpack: ext with id=%d is already registered", id))
	}

	registerExt(id, ptr, getEncoder(ptr), getDecoder(ptr))
	registerExt(id, typ, getEncoder(typ), getDecoder(typ))
}

// registerExt installs the ext-wrapped encoder and decoder for typ under id.
// A nil enc or dec is skipped; extTypes[id] is only set when dec is non-nil.
func registerExt(id int8, typ reflect.Type, enc encoderFunc, dec decoderFunc) {
	if enc != nil {
		typEncMap[typ] = makeExtEncoder(id, enc)
	}
	if dec != nil {
		extTypes[id] = typ
		typDecMap[typ] = makeExtDecoder(id, dec)
	}
}

// EncodeExtHeader writes an ext header: the length marker produced by
// encodeExtLen for length, followed by typeId as a single byte.
func (e *Encoder) EncodeExtHeader(typeId int8, length int) error {
	if err := e.encodeExtLen(length); err != nil {
		return err
	}
	return e.w.WriteByte(byte(typeId))
}

// makeExtEncoder wraps enc so that its output is emitted as an ext value
// tagged with typeId: an ext header followed by the bytes enc produced.
func makeExtEncoder(typeId int8, enc encoderFunc) encoderFunc {
	return func(e *Encoder, v reflect.Value) error {
		buf := bufferPool.Get().(*bytes.Buffer)
		defer bufferPool.Put(buf)
		buf.Reset()

		// Redirect the encoder into buf while enc runs: the header carries
		// the payload length, so the payload has to be produced first.
		oldw := e.w
		e.w = buf
		err := enc(e, v)
		e.w = oldw

		if err != nil {
			return err
		}

		if err := e.EncodeExtHeader(typeId, buf.Len()); err != nil {
			return err
		}
		return e.write(buf.Bytes())
	}
}

// makeExtDecoder wraps dec so that it accepts a value either bare or wrapped
// in an ext header. When the next code is not an ext code, dec is called
// directly. Otherwise the header is consumed, its type id must equal typeId,
// and the payload length is stored in d.extLen before dec runs.
func makeExtDecoder(typeId int8, dec decoderFunc) decoderFunc {
	return func(d *Decoder, v reflect.Value) error {
		c, err := d.PeekCode()
		if err != nil {
			return err
		}

		if !codes.IsExt(c) {
			return dec(d, v)
		}

		id, extLen, err := d.DecodeExtHeader()
		if err != nil {
			return err
		}

		if id != typeId {
			// Report the ext type id that was read, not the format code c.
			return fmt.Errorf("msgpack: got ext type=%d, wanted %d", id, typeId)
		}

		// NOTE(review): unlike extInterface, d.extLen is not reset to 0 after
		// dec returns — confirm that this is intended.
		d.extLen = extLen
		return dec(d, v)
	}
}

// encodeExtLen writes the format code (and, for the Ext8/Ext16/Ext32 forms,
// the length bytes) describing an ext payload of l bytes. Lengths 1, 2, 4, 8
// and 16 use the fixext codes; everything else uses the smallest of Ext8,
// Ext16 and Ext32 whose bound l is below.
//
// l is converted with uint8/uint16/uint32 without a range check, so a value
// that is negative or does not fit in 32 bits is truncated.
func (e *Encoder) encodeExtLen(l int) error {
	switch l {
	case 1:
		return e.writeCode(codes.FixExt1)
	case 2:
		return e.writeCode(codes.FixExt2)
	case 4:
		return e.writeCode(codes.FixExt4)
	case 8:
		return e.writeCode(codes.FixExt8)
	case 16:
		return e.writeCode(codes.FixExt16)
	}

	switch {
	case l < 256:
		return e.write1(codes.Ext8, uint8(l))
	case l < 65536:
		return e.write2(codes.Ext16, uint16(l))
	default:
		return e.write4(codes.Ext32, uint32(l))
	}
}

// parseExtLen returns the ext payload length announced by code c. For the
// fixext codes the length is implied by the code; for Ext8, Ext16 and Ext32
// it is read from the decoder. Any other code yields an error.
func (d *Decoder) parseExtLen(c codes.Code) (int, error) {
	switch c {
	case codes.FixExt1:
		return 1, nil
	case codes.FixExt2:
		return 2, nil
	case codes.FixExt4:
		return 4, nil
	case codes.FixExt8:
		return 8, nil
	case codes.FixExt16:
		return 16, nil
	case codes.Ext8:
		n, err := d.uint8()
		return int(n), err
	case codes.Ext16:
		n, err := d.uint16()
		return int(n), err
	case codes.Ext32:
		n, err := d.uint32()
		return int(n), err
	default:
		return 0, fmt.Errorf("msgpack: invalid code=%x decoding ext length", c)
	}
}

// decodeExtHeader reads the rest of an ext header whose format code c has
// already been consumed: first the payload length, then the one-byte type id.
func (d *Decoder) decodeExtHeader(c codes.Code) (int8, int, error) {
	length, err := d.parseExtLen(c)
	if err != nil {
		return 0, 0, err
	}

	typeId, err := d.readCode()
	if err != nil {
		return 0, 0, err
	}

	return int8(typeId), length, nil
}

// DecodeExtHeader reads a complete ext header, including the leading format
// code, and returns the ext type id and the payload length.
func (d *Decoder) DecodeExtHeader() (typeId int8, length int, err error) {
	c, err := d.readCode()
	if err != nil {
		return 0, 0, err
	}
	return d.decodeExtHeader(c)
}

// extInterface decodes an ext value whose format code c has already been
// consumed. It looks up the type registered for the ext id, allocates a new
// value of that type, decodes into it with d.extLen set to the payload
// length, and returns the pointer produced by reflect.New. An id with no
// entry in extTypes is an error.
func (d *Decoder) extInterface(c codes.Code) (interface{}, error) {
	extId, extLen, err := d.decodeExtHeader(c)
	if err != nil {
		return nil, err
	}

	typ, ok := extTypes[extId]
	if !ok {
		return nil, fmt.Errorf("msgpack: unregistered ext id=%d", extId)
	}

	v := reflect.New(typ)

	// extLen is only meaningful while this payload is being decoded.
	d.extLen = extLen
	err = d.DecodeValue(v.Elem())
	d.extLen = 0
	if err != nil {
		return nil, err
	}

	return v.Interface(), nil
}

// skipExt discards an ext value whose format code c has already been
// consumed: the length is parsed, then the type byte and the payload are
// skipped together.
func (d *Decoder) skipExt(c codes.Code) error {
	n, err := d.parseExtLen(c)
	if err != nil {
		return err
	}
	// +1 accounts for the ext type byte that precedes the payload.
	return d.skipN(n + 1)
}

// skipExtHeader discards the remainder of an ext header whose format code c
// has already been consumed, leaving the decoder at the start of the payload.
func (d *Decoder) skipExtHeader(c codes.Code) error {
	// What follows c is the length field (extHeaderLen(c) bytes) and then
	// one ext type byte. All of them are thrown away, so they are simply
	// read one byte at a time.
	n := extHeaderLen(c) + 1
	for i := 0; i < n; i++ {
		if _, err := d.readCode(); err != nil {
			return err
		}
	}
	return nil
}

// extHeaderLen returns how many length bytes follow ext format code c:
// 1, 2 or 4 for Ext8, Ext16 and Ext32, and 0 for any other code.
func extHeaderLen(c codes.Code) int {
	switch c {
	case codes.Ext8:
		return 1
	case codes.Ext16:
		return 2
	case codes.Ext32:
		return 4
	}
	return 0
}