]>
Commit | Line | Data |
---|---|---|
1 | package msgpack | |
2 | ||
3 | import ( | |
4 | "bytes" | |
5 | "fmt" | |
6 | "reflect" | |
7 | "sync" | |
8 | ||
9 | "github.com/vmihailenco/msgpack/codes" | |
10 | ) | |
11 | ||
var (
	// extTypes maps an ext type id to the Go type registered for it.
	// It is written by registerExt and read by extInterface; it is a plain
	// map with no locking, so registration is expected to finish before use.
	extTypes = make(map[int8]reflect.Type)

	// bufferPool recycles scratch *bytes.Buffer values; makeExtEncoder uses
	// them to stage an ext payload before its length is known.
	bufferPool = &sync.Pool{
		New: func() interface{} { return new(bytes.Buffer) },
	}
)
19 | ||
20 | // RegisterExt records a type, identified by a value for that type, | |
21 | // under the provided id. That id will identify the concrete type of a value | |
22 | // sent or received as an interface variable. Only types that will be | |
23 | // transferred as implementations of interface values need to be registered. | |
24 | // Expecting to be used only during initialization, it panics if the mapping | |
25 | // between types and ids is not a bijection. | |
26 | func RegisterExt(id int8, value interface{}) { | |
27 | typ := reflect.TypeOf(value) | |
28 | if typ.Kind() == reflect.Ptr { | |
29 | typ = typ.Elem() | |
30 | } | |
31 | ptr := reflect.PtrTo(typ) | |
32 | ||
33 | if _, ok := extTypes[id]; ok { | |
34 | panic(fmt.Errorf("msgpack: ext with id=%d is already registered", id)) | |
35 | } | |
36 | ||
37 | registerExt(id, ptr, getEncoder(ptr), getDecoder(ptr)) | |
38 | registerExt(id, typ, getEncoder(typ), getDecoder(typ)) | |
39 | } | |
40 | ||
41 | func registerExt(id int8, typ reflect.Type, enc encoderFunc, dec decoderFunc) { | |
42 | if dec != nil { | |
43 | extTypes[id] = typ | |
44 | } | |
45 | if enc != nil { | |
46 | typEncMap[typ] = makeExtEncoder(id, enc) | |
47 | } | |
48 | if dec != nil { | |
49 | typDecMap[typ] = makeExtDecoder(id, dec) | |
50 | } | |
51 | } | |
52 | ||
53 | func (e *Encoder) EncodeExtHeader(typeId int8, length int) error { | |
54 | if err := e.encodeExtLen(length); err != nil { | |
55 | return err | |
56 | } | |
57 | if err := e.w.WriteByte(byte(typeId)); err != nil { | |
58 | return err | |
59 | } | |
60 | return nil | |
61 | } | |
62 | ||
63 | func makeExtEncoder(typeId int8, enc encoderFunc) encoderFunc { | |
64 | return func(e *Encoder, v reflect.Value) error { | |
65 | buf := bufferPool.Get().(*bytes.Buffer) | |
66 | defer bufferPool.Put(buf) | |
67 | buf.Reset() | |
68 | ||
69 | oldw := e.w | |
70 | e.w = buf | |
71 | err := enc(e, v) | |
72 | e.w = oldw | |
73 | ||
74 | if err != nil { | |
75 | return err | |
76 | } | |
77 | ||
78 | err = e.EncodeExtHeader(typeId, buf.Len()) | |
79 | if err != nil { | |
80 | return err | |
81 | } | |
82 | return e.write(buf.Bytes()) | |
83 | } | |
84 | } | |
85 | ||
86 | func makeExtDecoder(typeId int8, dec decoderFunc) decoderFunc { | |
87 | return func(d *Decoder, v reflect.Value) error { | |
88 | c, err := d.PeekCode() | |
89 | if err != nil { | |
90 | return err | |
91 | } | |
92 | ||
93 | if !codes.IsExt(c) { | |
94 | return dec(d, v) | |
95 | } | |
96 | ||
97 | id, extLen, err := d.DecodeExtHeader() | |
98 | if err != nil { | |
99 | return err | |
100 | } | |
101 | ||
102 | if id != typeId { | |
103 | return fmt.Errorf("msgpack: got ext type=%d, wanted %d", int8(c), typeId) | |
104 | } | |
105 | ||
106 | d.extLen = extLen | |
107 | return dec(d, v) | |
108 | } | |
109 | } | |
110 | ||
111 | func (e *Encoder) encodeExtLen(l int) error { | |
112 | switch l { | |
113 | case 1: | |
114 | return e.writeCode(codes.FixExt1) | |
115 | case 2: | |
116 | return e.writeCode(codes.FixExt2) | |
117 | case 4: | |
118 | return e.writeCode(codes.FixExt4) | |
119 | case 8: | |
120 | return e.writeCode(codes.FixExt8) | |
121 | case 16: | |
122 | return e.writeCode(codes.FixExt16) | |
123 | } | |
124 | if l < 256 { | |
125 | return e.write1(codes.Ext8, uint8(l)) | |
126 | } | |
127 | if l < 65536 { | |
128 | return e.write2(codes.Ext16, uint16(l)) | |
129 | } | |
130 | return e.write4(codes.Ext32, uint32(l)) | |
131 | } | |
132 | ||
133 | func (d *Decoder) parseExtLen(c codes.Code) (int, error) { | |
134 | switch c { | |
135 | case codes.FixExt1: | |
136 | return 1, nil | |
137 | case codes.FixExt2: | |
138 | return 2, nil | |
139 | case codes.FixExt4: | |
140 | return 4, nil | |
141 | case codes.FixExt8: | |
142 | return 8, nil | |
143 | case codes.FixExt16: | |
144 | return 16, nil | |
145 | case codes.Ext8: | |
146 | n, err := d.uint8() | |
147 | return int(n), err | |
148 | case codes.Ext16: | |
149 | n, err := d.uint16() | |
150 | return int(n), err | |
151 | case codes.Ext32: | |
152 | n, err := d.uint32() | |
153 | return int(n), err | |
154 | default: | |
155 | return 0, fmt.Errorf("msgpack: invalid code=%x decoding ext length", c) | |
156 | } | |
157 | } | |
158 | ||
159 | func (d *Decoder) decodeExtHeader(c codes.Code) (int8, int, error) { | |
160 | length, err := d.parseExtLen(c) | |
161 | if err != nil { | |
162 | return 0, 0, err | |
163 | } | |
164 | ||
165 | typeId, err := d.readCode() | |
166 | if err != nil { | |
167 | return 0, 0, err | |
168 | } | |
169 | ||
170 | return int8(typeId), length, nil | |
171 | } | |
172 | ||
173 | func (d *Decoder) DecodeExtHeader() (typeId int8, length int, err error) { | |
174 | c, err := d.readCode() | |
175 | if err != nil { | |
176 | return | |
177 | } | |
178 | return d.decodeExtHeader(c) | |
179 | } | |
180 | ||
181 | func (d *Decoder) extInterface(c codes.Code) (interface{}, error) { | |
182 | extId, extLen, err := d.decodeExtHeader(c) | |
183 | if err != nil { | |
184 | return nil, err | |
185 | } | |
186 | ||
187 | typ, ok := extTypes[extId] | |
188 | if !ok { | |
189 | return nil, fmt.Errorf("msgpack: unregistered ext id=%d", extId) | |
190 | } | |
191 | ||
192 | v := reflect.New(typ) | |
193 | ||
194 | d.extLen = extLen | |
195 | err = d.DecodeValue(v.Elem()) | |
196 | d.extLen = 0 | |
197 | if err != nil { | |
198 | return nil, err | |
199 | } | |
200 | ||
201 | return v.Interface(), nil | |
202 | } | |
203 | ||
204 | func (d *Decoder) skipExt(c codes.Code) error { | |
205 | n, err := d.parseExtLen(c) | |
206 | if err != nil { | |
207 | return err | |
208 | } | |
209 | return d.skipN(n + 1) | |
210 | } | |
211 | ||
212 | func (d *Decoder) skipExtHeader(c codes.Code) error { | |
213 | // Read ext type. | |
214 | _, err := d.readCode() | |
215 | if err != nil { | |
216 | return err | |
217 | } | |
218 | // Read ext body len. | |
219 | for i := 0; i < extHeaderLen(c); i++ { | |
220 | _, err := d.readCode() | |
221 | if err != nil { | |
222 | return err | |
223 | } | |
224 | } | |
225 | return nil | |
226 | } | |
227 | ||
228 | func extHeaderLen(c codes.Code) int { | |
229 | switch c { | |
230 | case codes.Ext8: | |
231 | return 1 | |
232 | case codes.Ext16: | |
233 | return 2 | |
234 | case codes.Ext32: | |
235 | return 4 | |
236 | } | |
237 | return 0 | |
238 | } |