]>
Commit | Line | Data |
---|---|---|
1 | package msgpack | |
2 | ||
3 | import ( | |
4 | "errors" | |
5 | "fmt" | |
6 | "reflect" | |
7 | ||
8 | "github.com/vmihailenco/msgpack/codes" | |
9 | ) | |
10 | ||
// mapElemsAllocLimit caps the capacity hint passed to make when a map is
// pre-allocated from a length read off the wire, so a large decoded length
// does not translate into an equally large up-front allocation.
const mapElemsAllocLimit = 1e4

// Cached reflect types for the map shapes that have dedicated typed decoders
// (map[string]string and map[string]interface{}).
var (
	mapStringStringPtrType = reflect.TypeOf((*map[string]string)(nil))
	mapStringStringType    = mapStringStringPtrType.Elem()

	mapStringInterfacePtrType = reflect.TypeOf((*map[string]interface{})(nil))
	mapStringInterfaceType    = mapStringInterfacePtrType.Elem()
)

// errInvalidCode is an internal sentinel returned by _mapLen when a code is
// not a map header; expandInvalidCodeMapLenError turns it into a message that
// names the offending code.
var errInvalidCode = errors.New("invalid code")
20 | ||
21 | func decodeMapValue(d *Decoder, v reflect.Value) error { | |
22 | size, err := d.DecodeMapLen() | |
23 | if err != nil { | |
24 | return err | |
25 | } | |
26 | ||
27 | typ := v.Type() | |
28 | if size == -1 { | |
29 | v.Set(reflect.Zero(typ)) | |
30 | return nil | |
31 | } | |
32 | ||
33 | if v.IsNil() { | |
34 | v.Set(reflect.MakeMap(typ)) | |
35 | } | |
36 | if size == 0 { | |
37 | return nil | |
38 | } | |
39 | ||
40 | return decodeMapValueSize(d, v, size) | |
41 | } | |
42 | ||
43 | func decodeMapValueSize(d *Decoder, v reflect.Value, size int) error { | |
44 | typ := v.Type() | |
45 | keyType := typ.Key() | |
46 | valueType := typ.Elem() | |
47 | ||
48 | for i := 0; i < size; i++ { | |
49 | mk := reflect.New(keyType).Elem() | |
50 | if err := d.DecodeValue(mk); err != nil { | |
51 | return err | |
52 | } | |
53 | ||
54 | mv := reflect.New(valueType).Elem() | |
55 | if err := d.DecodeValue(mv); err != nil { | |
56 | return err | |
57 | } | |
58 | ||
59 | v.SetMapIndex(mk, mv) | |
60 | } | |
61 | ||
62 | return nil | |
63 | } | |
64 | ||
65 | // DecodeMapLen decodes map length. Length is -1 when map is nil. | |
66 | func (d *Decoder) DecodeMapLen() (int, error) { | |
67 | c, err := d.readCode() | |
68 | if err != nil { | |
69 | return 0, err | |
70 | } | |
71 | ||
72 | if codes.IsExt(c) { | |
73 | if err = d.skipExtHeader(c); err != nil { | |
74 | return 0, err | |
75 | } | |
76 | ||
77 | c, err = d.readCode() | |
78 | if err != nil { | |
79 | return 0, err | |
80 | } | |
81 | } | |
82 | return d.mapLen(c) | |
83 | } | |
84 | ||
85 | func (d *Decoder) mapLen(c codes.Code) (int, error) { | |
86 | size, err := d._mapLen(c) | |
87 | err = expandInvalidCodeMapLenError(c, err) | |
88 | return size, err | |
89 | } | |
90 | ||
91 | func (d *Decoder) _mapLen(c codes.Code) (int, error) { | |
92 | if c == codes.Nil { | |
93 | return -1, nil | |
94 | } | |
95 | if c >= codes.FixedMapLow && c <= codes.FixedMapHigh { | |
96 | return int(c & codes.FixedMapMask), nil | |
97 | } | |
98 | if c == codes.Map16 { | |
99 | size, err := d.uint16() | |
100 | return int(size), err | |
101 | } | |
102 | if c == codes.Map32 { | |
103 | size, err := d.uint32() | |
104 | return int(size), err | |
105 | } | |
106 | return 0, errInvalidCode | |
107 | } | |
108 | ||
109 | func expandInvalidCodeMapLenError(c codes.Code, err error) error { | |
110 | if err == errInvalidCode { | |
111 | return fmt.Errorf("msgpack: invalid code=%x decoding map length", c) | |
112 | } | |
113 | return err | |
114 | } | |
115 | ||
116 | func decodeMapStringStringValue(d *Decoder, v reflect.Value) error { | |
117 | mptr := v.Addr().Convert(mapStringStringPtrType).Interface().(*map[string]string) | |
118 | return d.decodeMapStringStringPtr(mptr) | |
119 | } | |
120 | ||
121 | func (d *Decoder) decodeMapStringStringPtr(ptr *map[string]string) error { | |
122 | size, err := d.DecodeMapLen() | |
123 | if err != nil { | |
124 | return err | |
125 | } | |
126 | if size == -1 { | |
127 | *ptr = nil | |
128 | return nil | |
129 | } | |
130 | ||
131 | m := *ptr | |
132 | if m == nil { | |
133 | *ptr = make(map[string]string, min(size, mapElemsAllocLimit)) | |
134 | m = *ptr | |
135 | } | |
136 | ||
137 | for i := 0; i < size; i++ { | |
138 | mk, err := d.DecodeString() | |
139 | if err != nil { | |
140 | return err | |
141 | } | |
142 | mv, err := d.DecodeString() | |
143 | if err != nil { | |
144 | return err | |
145 | } | |
146 | m[mk] = mv | |
147 | } | |
148 | ||
149 | return nil | |
150 | } | |
151 | ||
152 | func decodeMapStringInterfaceValue(d *Decoder, v reflect.Value) error { | |
153 | ptr := v.Addr().Convert(mapStringInterfacePtrType).Interface().(*map[string]interface{}) | |
154 | return d.decodeMapStringInterfacePtr(ptr) | |
155 | } | |
156 | ||
157 | func (d *Decoder) decodeMapStringInterfacePtr(ptr *map[string]interface{}) error { | |
158 | n, err := d.DecodeMapLen() | |
159 | if err != nil { | |
160 | return err | |
161 | } | |
162 | if n == -1 { | |
163 | *ptr = nil | |
164 | return nil | |
165 | } | |
166 | ||
167 | m := *ptr | |
168 | if m == nil { | |
169 | *ptr = make(map[string]interface{}, min(n, mapElemsAllocLimit)) | |
170 | m = *ptr | |
171 | } | |
172 | ||
173 | for i := 0; i < n; i++ { | |
174 | mk, err := d.DecodeString() | |
175 | if err != nil { | |
176 | return err | |
177 | } | |
178 | mv, err := d.decodeInterfaceCond() | |
179 | if err != nil { | |
180 | return err | |
181 | } | |
182 | m[mk] = mv | |
183 | } | |
184 | ||
185 | return nil | |
186 | } | |
187 | ||
188 | func (d *Decoder) DecodeMap() (interface{}, error) { | |
189 | if d.decodeMapFunc != nil { | |
190 | return d.decodeMapFunc(d) | |
191 | } | |
192 | ||
193 | size, err := d.DecodeMapLen() | |
194 | if err != nil { | |
195 | return nil, err | |
196 | } | |
197 | if size == -1 { | |
198 | return nil, nil | |
199 | } | |
200 | if size == 0 { | |
201 | return make(map[string]interface{}), nil | |
202 | } | |
203 | ||
204 | code, err := d.PeekCode() | |
205 | if err != nil { | |
206 | return nil, err | |
207 | } | |
208 | ||
209 | if codes.IsString(code) { | |
210 | return d.decodeMapStringInterfaceSize(size) | |
211 | } | |
212 | ||
213 | key, err := d.decodeInterfaceCond() | |
214 | if err != nil { | |
215 | return nil, err | |
216 | } | |
217 | ||
218 | value, err := d.decodeInterfaceCond() | |
219 | if err != nil { | |
220 | return nil, err | |
221 | } | |
222 | ||
223 | keyType := reflect.TypeOf(key) | |
224 | valueType := reflect.TypeOf(value) | |
225 | mapType := reflect.MapOf(keyType, valueType) | |
226 | mapValue := reflect.MakeMap(mapType) | |
227 | ||
228 | mapValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(value)) | |
229 | size-- | |
230 | ||
231 | err = decodeMapValueSize(d, mapValue, size) | |
232 | if err != nil { | |
233 | return nil, err | |
234 | } | |
235 | ||
236 | return mapValue.Interface(), nil | |
237 | } | |
238 | ||
239 | func (d *Decoder) decodeMapStringInterfaceSize(size int) (map[string]interface{}, error) { | |
240 | m := make(map[string]interface{}, min(size, mapElemsAllocLimit)) | |
241 | for i := 0; i < size; i++ { | |
242 | mk, err := d.DecodeString() | |
243 | if err != nil { | |
244 | return nil, err | |
245 | } | |
246 | mv, err := d.decodeInterfaceCond() | |
247 | if err != nil { | |
248 | return nil, err | |
249 | } | |
250 | m[mk] = mv | |
251 | } | |
252 | return m, nil | |
253 | } | |
254 | ||
255 | func (d *Decoder) skipMap(c codes.Code) error { | |
256 | n, err := d.mapLen(c) | |
257 | if err != nil { | |
258 | return err | |
259 | } | |
260 | for i := 0; i < n; i++ { | |
261 | if err := d.Skip(); err != nil { | |
262 | return err | |
263 | } | |
264 | if err := d.Skip(); err != nil { | |
265 | return err | |
266 | } | |
267 | } | |
268 | return nil | |
269 | } | |
270 | ||
271 | func decodeStructValue(d *Decoder, v reflect.Value) error { | |
272 | c, err := d.readCode() | |
273 | if err != nil { | |
274 | return err | |
275 | } | |
276 | ||
277 | var isArray bool | |
278 | ||
279 | n, err := d._mapLen(c) | |
280 | if err != nil { | |
281 | var err2 error | |
282 | n, err2 = d.arrayLen(c) | |
283 | if err2 != nil { | |
284 | return expandInvalidCodeMapLenError(c, err) | |
285 | } | |
286 | isArray = true | |
287 | } | |
288 | if n == -1 { | |
289 | if err = mustSet(v); err != nil { | |
290 | return err | |
291 | } | |
292 | v.Set(reflect.Zero(v.Type())) | |
293 | return nil | |
294 | } | |
295 | ||
296 | var fields *fields | |
297 | if d.useJSONTag { | |
298 | fields = jsonStructs.Fields(v.Type()) | |
299 | } else { | |
300 | fields = structs.Fields(v.Type()) | |
301 | } | |
302 | ||
303 | if isArray { | |
304 | for i, f := range fields.List { | |
305 | if i >= n { | |
306 | break | |
307 | } | |
308 | if err := f.DecodeValue(d, v); err != nil { | |
309 | return err | |
310 | } | |
311 | } | |
312 | // Skip extra values. | |
313 | for i := len(fields.List); i < n; i++ { | |
314 | if err := d.Skip(); err != nil { | |
315 | return err | |
316 | } | |
317 | } | |
318 | return nil | |
319 | } | |
320 | ||
321 | for i := 0; i < n; i++ { | |
322 | name, err := d.DecodeString() | |
323 | if err != nil { | |
324 | return err | |
325 | } | |
326 | if f := fields.Table[name]; f != nil { | |
327 | if err := f.DecodeValue(d, v); err != nil { | |
328 | return err | |
329 | } | |
330 | } else { | |
331 | if err := d.Skip(); err != nil { | |
332 | return err | |
333 | } | |
334 | } | |
335 | } | |
336 | ||
337 | return nil | |
338 | } |