]>
Commit | Line | Data |
---|---|---|
1 | package xmlutil | |
2 | ||
3 | import ( | |
4 | "encoding/base64" | |
5 | "encoding/xml" | |
6 | "fmt" | |
7 | "io" | |
8 | "reflect" | |
9 | "strconv" | |
10 | "strings" | |
11 | "time" | |
12 | ||
13 | "github.com/aws/aws-sdk-go/private/protocol" | |
14 | ) | |
15 | ||
16 | // UnmarshalXML deserializes an xml.Decoder into the container v. V | |
17 | // needs to match the shape of the XML expected to be decoded. | |
18 | // If the shape doesn't match unmarshaling will fail. | |
19 | func UnmarshalXML(v interface{}, d *xml.Decoder, wrapper string) error { | |
20 | n, err := XMLToStruct(d, nil) | |
21 | if err != nil { | |
22 | return err | |
23 | } | |
24 | if n.Children != nil { | |
25 | for _, root := range n.Children { | |
26 | for _, c := range root { | |
27 | if wrappedChild, ok := c.Children[wrapper]; ok { | |
28 | c = wrappedChild[0] // pull out wrapped element | |
29 | } | |
30 | ||
31 | err = parse(reflect.ValueOf(v), c, "") | |
32 | if err != nil { | |
33 | if err == io.EOF { | |
34 | return nil | |
35 | } | |
36 | return err | |
37 | } | |
38 | } | |
39 | } | |
40 | return nil | |
41 | } | |
42 | return nil | |
43 | } | |
44 | ||
45 | // parse deserializes any value from the XMLNode. The type tag is used to infer the type, or reflect | |
46 | // will be used to determine the type from r. | |
47 | func parse(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { | |
48 | rtype := r.Type() | |
49 | if rtype.Kind() == reflect.Ptr { | |
50 | rtype = rtype.Elem() // check kind of actual element type | |
51 | } | |
52 | ||
53 | t := tag.Get("type") | |
54 | if t == "" { | |
55 | switch rtype.Kind() { | |
56 | case reflect.Struct: | |
57 | // also it can't be a time object | |
58 | if _, ok := r.Interface().(*time.Time); !ok { | |
59 | t = "structure" | |
60 | } | |
61 | case reflect.Slice: | |
62 | // also it can't be a byte slice | |
63 | if _, ok := r.Interface().([]byte); !ok { | |
64 | t = "list" | |
65 | } | |
66 | case reflect.Map: | |
67 | t = "map" | |
68 | } | |
69 | } | |
70 | ||
71 | switch t { | |
72 | case "structure": | |
73 | if field, ok := rtype.FieldByName("_"); ok { | |
74 | tag = field.Tag | |
75 | } | |
76 | return parseStruct(r, node, tag) | |
77 | case "list": | |
78 | return parseList(r, node, tag) | |
79 | case "map": | |
80 | return parseMap(r, node, tag) | |
81 | default: | |
82 | return parseScalar(r, node, tag) | |
83 | } | |
84 | } | |
85 | ||
86 | // parseStruct deserializes a structure and its fields from an XMLNode. Any nested | |
87 | // types in the structure will also be deserialized. | |
88 | func parseStruct(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { | |
89 | t := r.Type() | |
90 | if r.Kind() == reflect.Ptr { | |
91 | if r.IsNil() { // create the structure if it's nil | |
92 | s := reflect.New(r.Type().Elem()) | |
93 | r.Set(s) | |
94 | r = s | |
95 | } | |
96 | ||
97 | r = r.Elem() | |
98 | t = t.Elem() | |
99 | } | |
100 | ||
101 | // unwrap any payloads | |
102 | if payload := tag.Get("payload"); payload != "" { | |
103 | field, _ := t.FieldByName(payload) | |
104 | return parseStruct(r.FieldByName(payload), node, field.Tag) | |
105 | } | |
106 | ||
107 | for i := 0; i < t.NumField(); i++ { | |
108 | field := t.Field(i) | |
109 | if c := field.Name[0:1]; strings.ToLower(c) == c { | |
110 | continue // ignore unexported fields | |
111 | } | |
112 | ||
113 | // figure out what this field is called | |
114 | name := field.Name | |
115 | if field.Tag.Get("flattened") != "" && field.Tag.Get("locationNameList") != "" { | |
116 | name = field.Tag.Get("locationNameList") | |
117 | } else if locName := field.Tag.Get("locationName"); locName != "" { | |
118 | name = locName | |
119 | } | |
120 | ||
121 | // try to find the field by name in elements | |
122 | elems := node.Children[name] | |
123 | ||
124 | if elems == nil { // try to find the field in attributes | |
125 | if val, ok := node.findElem(name); ok { | |
126 | elems = []*XMLNode{{Text: val}} | |
127 | } | |
128 | } | |
129 | ||
130 | member := r.FieldByName(field.Name) | |
131 | for _, elem := range elems { | |
132 | err := parse(member, elem, field.Tag) | |
133 | if err != nil { | |
134 | return err | |
135 | } | |
136 | } | |
137 | } | |
138 | return nil | |
139 | } | |
140 | ||
141 | // parseList deserializes a list of values from an XML node. Each list entry | |
142 | // will also be deserialized. | |
143 | func parseList(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { | |
144 | t := r.Type() | |
145 | ||
146 | if tag.Get("flattened") == "" { // look at all item entries | |
147 | mname := "member" | |
148 | if name := tag.Get("locationNameList"); name != "" { | |
149 | mname = name | |
150 | } | |
151 | ||
152 | if Children, ok := node.Children[mname]; ok { | |
153 | if r.IsNil() { | |
154 | r.Set(reflect.MakeSlice(t, len(Children), len(Children))) | |
155 | } | |
156 | ||
157 | for i, c := range Children { | |
158 | err := parse(r.Index(i), c, "") | |
159 | if err != nil { | |
160 | return err | |
161 | } | |
162 | } | |
163 | } | |
164 | } else { // flattened list means this is a single element | |
165 | if r.IsNil() { | |
166 | r.Set(reflect.MakeSlice(t, 0, 0)) | |
167 | } | |
168 | ||
169 | childR := reflect.Zero(t.Elem()) | |
170 | r.Set(reflect.Append(r, childR)) | |
171 | err := parse(r.Index(r.Len()-1), node, "") | |
172 | if err != nil { | |
173 | return err | |
174 | } | |
175 | } | |
176 | ||
177 | return nil | |
178 | } | |
179 | ||
180 | // parseMap deserializes a map from an XMLNode. The direct children of the XMLNode | |
181 | // will also be deserialized as map entries. | |
182 | func parseMap(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { | |
183 | if r.IsNil() { | |
184 | r.Set(reflect.MakeMap(r.Type())) | |
185 | } | |
186 | ||
187 | if tag.Get("flattened") == "" { // look at all child entries | |
188 | for _, entry := range node.Children["entry"] { | |
189 | parseMapEntry(r, entry, tag) | |
190 | } | |
191 | } else { // this element is itself an entry | |
192 | parseMapEntry(r, node, tag) | |
193 | } | |
194 | ||
195 | return nil | |
196 | } | |
197 | ||
198 | // parseMapEntry deserializes a map entry from a XML node. | |
199 | func parseMapEntry(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { | |
200 | kname, vname := "key", "value" | |
201 | if n := tag.Get("locationNameKey"); n != "" { | |
202 | kname = n | |
203 | } | |
204 | if n := tag.Get("locationNameValue"); n != "" { | |
205 | vname = n | |
206 | } | |
207 | ||
208 | keys, ok := node.Children[kname] | |
209 | values := node.Children[vname] | |
210 | if ok { | |
211 | for i, key := range keys { | |
212 | keyR := reflect.ValueOf(key.Text) | |
213 | value := values[i] | |
214 | valueR := reflect.New(r.Type().Elem()).Elem() | |
215 | ||
216 | parse(valueR, value, "") | |
217 | r.SetMapIndex(keyR, valueR) | |
218 | } | |
219 | } | |
220 | return nil | |
221 | } | |
222 | ||
223 | // parseScaller deserializes an XMLNode value into a concrete type based on the | |
224 | // interface type of r. | |
225 | // | |
226 | // Error is returned if the deserialization fails due to invalid type conversion, | |
227 | // or unsupported interface type. | |
228 | func parseScalar(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { | |
229 | switch r.Interface().(type) { | |
230 | case *string: | |
231 | r.Set(reflect.ValueOf(&node.Text)) | |
232 | return nil | |
233 | case []byte: | |
234 | b, err := base64.StdEncoding.DecodeString(node.Text) | |
235 | if err != nil { | |
236 | return err | |
237 | } | |
238 | r.Set(reflect.ValueOf(b)) | |
239 | case *bool: | |
240 | v, err := strconv.ParseBool(node.Text) | |
241 | if err != nil { | |
242 | return err | |
243 | } | |
244 | r.Set(reflect.ValueOf(&v)) | |
245 | case *int64: | |
246 | v, err := strconv.ParseInt(node.Text, 10, 64) | |
247 | if err != nil { | |
248 | return err | |
249 | } | |
250 | r.Set(reflect.ValueOf(&v)) | |
251 | case *float64: | |
252 | v, err := strconv.ParseFloat(node.Text, 64) | |
253 | if err != nil { | |
254 | return err | |
255 | } | |
256 | r.Set(reflect.ValueOf(&v)) | |
257 | case *time.Time: | |
258 | format := tag.Get("timestampFormat") | |
259 | if len(format) == 0 { | |
260 | format = protocol.ISO8601TimeFormatName | |
261 | } | |
262 | ||
263 | t, err := protocol.ParseTime(format, node.Text) | |
264 | if err != nil { | |
265 | return err | |
266 | } | |
267 | r.Set(reflect.ValueOf(&t)) | |
268 | default: | |
269 | return fmt.Errorf("unsupported value: %v (%s)", r.Interface(), r.Type()) | |
270 | } | |
271 | return nil | |
272 | } |