]>
Commit | Line | Data |
---|---|---|
bae9f6d2 JC |
1 | package xmlutil |
2 | ||
3 | import ( | |
863486a6 | 4 | "bytes" |
bae9f6d2 JC |
5 | "encoding/base64" |
6 | "encoding/xml" | |
7 | "fmt" | |
8 | "io" | |
9 | "reflect" | |
10 | "strconv" | |
11 | "strings" | |
12 | "time" | |
15c0b25d | 13 | |
863486a6 | 14 | "github.com/aws/aws-sdk-go/aws/awserr" |
15c0b25d | 15 | "github.com/aws/aws-sdk-go/private/protocol" |
bae9f6d2 JC |
16 | ) |
17 | ||
863486a6 AG |
18 | // UnmarshalXMLError unmarshals the XML error from the stream into the value |
19 | // type specified. The value must be a pointer. If the message fails to | |
20 | // unmarshal, the message content will be included in the returned error as a | |
21 | // awserr.UnmarshalError. | |
22 | func UnmarshalXMLError(v interface{}, stream io.Reader) error { | |
23 | var errBuf bytes.Buffer | |
24 | body := io.TeeReader(stream, &errBuf) | |
25 | ||
26 | err := xml.NewDecoder(body).Decode(v) | |
27 | if err != nil && err != io.EOF { | |
28 | return awserr.NewUnmarshalError(err, | |
29 | "failed to unmarshal error message", errBuf.Bytes()) | |
30 | } | |
31 | ||
32 | return nil | |
33 | } | |
34 | ||
bae9f6d2 JC |
35 | // UnmarshalXML deserializes an xml.Decoder into the container v. V |
36 | // needs to match the shape of the XML expected to be decoded. | |
37 | // If the shape doesn't match unmarshaling will fail. | |
38 | func UnmarshalXML(v interface{}, d *xml.Decoder, wrapper string) error { | |
39 | n, err := XMLToStruct(d, nil) | |
40 | if err != nil { | |
41 | return err | |
42 | } | |
43 | if n.Children != nil { | |
44 | for _, root := range n.Children { | |
45 | for _, c := range root { | |
46 | if wrappedChild, ok := c.Children[wrapper]; ok { | |
47 | c = wrappedChild[0] // pull out wrapped element | |
48 | } | |
49 | ||
50 | err = parse(reflect.ValueOf(v), c, "") | |
51 | if err != nil { | |
52 | if err == io.EOF { | |
53 | return nil | |
54 | } | |
55 | return err | |
56 | } | |
57 | } | |
58 | } | |
59 | return nil | |
60 | } | |
61 | return nil | |
62 | } | |
63 | ||
64 | // parse deserializes any value from the XMLNode. The type tag is used to infer the type, or reflect | |
65 | // will be used to determine the type from r. | |
66 | func parse(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { | |
67 | rtype := r.Type() | |
68 | if rtype.Kind() == reflect.Ptr { | |
69 | rtype = rtype.Elem() // check kind of actual element type | |
70 | } | |
71 | ||
72 | t := tag.Get("type") | |
73 | if t == "" { | |
74 | switch rtype.Kind() { | |
75 | case reflect.Struct: | |
15c0b25d AP |
76 | // also it can't be a time object |
77 | if _, ok := r.Interface().(*time.Time); !ok { | |
78 | t = "structure" | |
79 | } | |
bae9f6d2 | 80 | case reflect.Slice: |
15c0b25d AP |
81 | // also it can't be a byte slice |
82 | if _, ok := r.Interface().([]byte); !ok { | |
83 | t = "list" | |
84 | } | |
bae9f6d2 JC |
85 | case reflect.Map: |
86 | t = "map" | |
87 | } | |
88 | } | |
89 | ||
90 | switch t { | |
91 | case "structure": | |
92 | if field, ok := rtype.FieldByName("_"); ok { | |
93 | tag = field.Tag | |
94 | } | |
95 | return parseStruct(r, node, tag) | |
96 | case "list": | |
97 | return parseList(r, node, tag) | |
98 | case "map": | |
99 | return parseMap(r, node, tag) | |
100 | default: | |
101 | return parseScalar(r, node, tag) | |
102 | } | |
103 | } | |
104 | ||
105 | // parseStruct deserializes a structure and its fields from an XMLNode. Any nested | |
106 | // types in the structure will also be deserialized. | |
107 | func parseStruct(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { | |
108 | t := r.Type() | |
109 | if r.Kind() == reflect.Ptr { | |
110 | if r.IsNil() { // create the structure if it's nil | |
111 | s := reflect.New(r.Type().Elem()) | |
112 | r.Set(s) | |
113 | r = s | |
114 | } | |
115 | ||
116 | r = r.Elem() | |
117 | t = t.Elem() | |
118 | } | |
119 | ||
120 | // unwrap any payloads | |
121 | if payload := tag.Get("payload"); payload != "" { | |
122 | field, _ := t.FieldByName(payload) | |
123 | return parseStruct(r.FieldByName(payload), node, field.Tag) | |
124 | } | |
125 | ||
126 | for i := 0; i < t.NumField(); i++ { | |
127 | field := t.Field(i) | |
128 | if c := field.Name[0:1]; strings.ToLower(c) == c { | |
129 | continue // ignore unexported fields | |
130 | } | |
131 | ||
132 | // figure out what this field is called | |
133 | name := field.Name | |
134 | if field.Tag.Get("flattened") != "" && field.Tag.Get("locationNameList") != "" { | |
135 | name = field.Tag.Get("locationNameList") | |
136 | } else if locName := field.Tag.Get("locationName"); locName != "" { | |
137 | name = locName | |
138 | } | |
139 | ||
140 | // try to find the field by name in elements | |
141 | elems := node.Children[name] | |
142 | ||
143 | if elems == nil { // try to find the field in attributes | |
144 | if val, ok := node.findElem(name); ok { | |
145 | elems = []*XMLNode{{Text: val}} | |
146 | } | |
147 | } | |
148 | ||
149 | member := r.FieldByName(field.Name) | |
150 | for _, elem := range elems { | |
151 | err := parse(member, elem, field.Tag) | |
152 | if err != nil { | |
153 | return err | |
154 | } | |
155 | } | |
156 | } | |
157 | return nil | |
158 | } | |
159 | ||
160 | // parseList deserializes a list of values from an XML node. Each list entry | |
161 | // will also be deserialized. | |
162 | func parseList(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { | |
163 | t := r.Type() | |
164 | ||
165 | if tag.Get("flattened") == "" { // look at all item entries | |
166 | mname := "member" | |
167 | if name := tag.Get("locationNameList"); name != "" { | |
168 | mname = name | |
169 | } | |
170 | ||
171 | if Children, ok := node.Children[mname]; ok { | |
172 | if r.IsNil() { | |
173 | r.Set(reflect.MakeSlice(t, len(Children), len(Children))) | |
174 | } | |
175 | ||
176 | for i, c := range Children { | |
177 | err := parse(r.Index(i), c, "") | |
178 | if err != nil { | |
179 | return err | |
180 | } | |
181 | } | |
182 | } | |
183 | } else { // flattened list means this is a single element | |
184 | if r.IsNil() { | |
185 | r.Set(reflect.MakeSlice(t, 0, 0)) | |
186 | } | |
187 | ||
188 | childR := reflect.Zero(t.Elem()) | |
189 | r.Set(reflect.Append(r, childR)) | |
190 | err := parse(r.Index(r.Len()-1), node, "") | |
191 | if err != nil { | |
192 | return err | |
193 | } | |
194 | } | |
195 | ||
196 | return nil | |
197 | } | |
198 | ||
199 | // parseMap deserializes a map from an XMLNode. The direct children of the XMLNode | |
200 | // will also be deserialized as map entries. | |
201 | func parseMap(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { | |
202 | if r.IsNil() { | |
203 | r.Set(reflect.MakeMap(r.Type())) | |
204 | } | |
205 | ||
206 | if tag.Get("flattened") == "" { // look at all child entries | |
207 | for _, entry := range node.Children["entry"] { | |
208 | parseMapEntry(r, entry, tag) | |
209 | } | |
210 | } else { // this element is itself an entry | |
211 | parseMapEntry(r, node, tag) | |
212 | } | |
213 | ||
214 | return nil | |
215 | } | |
216 | ||
217 | // parseMapEntry deserializes a map entry from a XML node. | |
218 | func parseMapEntry(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { | |
219 | kname, vname := "key", "value" | |
220 | if n := tag.Get("locationNameKey"); n != "" { | |
221 | kname = n | |
222 | } | |
223 | if n := tag.Get("locationNameValue"); n != "" { | |
224 | vname = n | |
225 | } | |
226 | ||
227 | keys, ok := node.Children[kname] | |
228 | values := node.Children[vname] | |
229 | if ok { | |
230 | for i, key := range keys { | |
231 | keyR := reflect.ValueOf(key.Text) | |
232 | value := values[i] | |
233 | valueR := reflect.New(r.Type().Elem()).Elem() | |
234 | ||
235 | parse(valueR, value, "") | |
236 | r.SetMapIndex(keyR, valueR) | |
237 | } | |
238 | } | |
239 | return nil | |
240 | } | |
241 | ||
242 | // parseScaller deserializes an XMLNode value into a concrete type based on the | |
243 | // interface type of r. | |
244 | // | |
245 | // Error is returned if the deserialization fails due to invalid type conversion, | |
246 | // or unsupported interface type. | |
247 | func parseScalar(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { | |
248 | switch r.Interface().(type) { | |
249 | case *string: | |
250 | r.Set(reflect.ValueOf(&node.Text)) | |
251 | return nil | |
252 | case []byte: | |
253 | b, err := base64.StdEncoding.DecodeString(node.Text) | |
254 | if err != nil { | |
255 | return err | |
256 | } | |
257 | r.Set(reflect.ValueOf(b)) | |
258 | case *bool: | |
259 | v, err := strconv.ParseBool(node.Text) | |
260 | if err != nil { | |
261 | return err | |
262 | } | |
263 | r.Set(reflect.ValueOf(&v)) | |
264 | case *int64: | |
265 | v, err := strconv.ParseInt(node.Text, 10, 64) | |
266 | if err != nil { | |
267 | return err | |
268 | } | |
269 | r.Set(reflect.ValueOf(&v)) | |
270 | case *float64: | |
271 | v, err := strconv.ParseFloat(node.Text, 64) | |
272 | if err != nil { | |
273 | return err | |
274 | } | |
275 | r.Set(reflect.ValueOf(&v)) | |
276 | case *time.Time: | |
15c0b25d AP |
277 | format := tag.Get("timestampFormat") |
278 | if len(format) == 0 { | |
279 | format = protocol.ISO8601TimeFormatName | |
280 | } | |
281 | ||
282 | t, err := protocol.ParseTime(format, node.Text) | |
bae9f6d2 JC |
283 | if err != nil { |
284 | return err | |
285 | } | |
286 | r.Set(reflect.ValueOf(&t)) | |
287 | default: | |
288 | return fmt.Errorf("unsupported value: %v (%s)", r.Interface(), r.Type()) | |
289 | } | |
290 | return nil | |
291 | } |