]>
Commit | Line | Data |
---|---|---|
bae9f6d2 JC |
1 | package copystructure |
2 | ||
3 | import ( | |
4 | "errors" | |
5 | "reflect" | |
6 | "sync" | |
7 | ||
8 | "github.com/mitchellh/reflectwalk" | |
9 | ) | |
10 | ||
11 | // Copy returns a deep copy of v. | |
12 | func Copy(v interface{}) (interface{}, error) { | |
13 | return Config{}.Copy(v) | |
14 | } | |
15 | ||
// CopierFunc is a function that knows how to deep copy a specific type.
// Register these globally with the Copiers variable.
type CopierFunc func(interface{}) (interface{}, error)

// Copiers is a map of types that behave specially when they are copied.
// If a type is found in this map while deep copying, this function
// will be called to copy it instead of attempting to copy all fields.
//
// The key should be the type, obtained using: reflect.TypeOf(value with type).
//
// It is unsafe to write to this map after Copies have started. If you
// are writing to this map while also copying, wrap all modifications to
// this map as well as to Copy in a mutex.
var Copiers map[reflect.Type]CopierFunc = make(map[reflect.Type]CopierFunc)
30 | ||
// Must is a helper that wraps a call to a function returning
// (interface{}, error) and panics if the error is non-nil. It is
// intended for use in variable initializations and should only be used
// when a copy error is a crashing case.
func Must(v interface{}, err error) interface{} {
	if err == nil {
		return v
	}
	panic("copy error: " + err.Error())
}
42 | ||
// errPointerRequired is returned by Config.Copy when Lock is enabled but
// the value to copy is not a pointer (locking needs the original, not a
// copy of it).
var errPointerRequired = errors.New("Copy argument must be a pointer when Lock is true")

// Config customizes the behavior of a single copy operation.
type Config struct {
	// Lock any types that are a sync.Locker and are not a mutex while copying.
	// If there is an RLocker method, use that to get the sync.Locker.
	Lock bool

	// Copiers is a map of types associated with a CopierFunc. Use the global
	// Copiers map if this is nil.
	Copiers map[reflect.Type]CopierFunc
}
54 | ||
55 | func (c Config) Copy(v interface{}) (interface{}, error) { | |
56 | if c.Lock && reflect.ValueOf(v).Kind() != reflect.Ptr { | |
57 | return nil, errPointerRequired | |
58 | } | |
59 | ||
60 | w := new(walker) | |
61 | if c.Lock { | |
62 | w.useLocks = true | |
63 | } | |
64 | ||
65 | if c.Copiers == nil { | |
66 | c.Copiers = Copiers | |
67 | } | |
68 | ||
69 | err := reflectwalk.Walk(v, w) | |
70 | if err != nil { | |
71 | return nil, err | |
72 | } | |
73 | ||
74 | // Get the result. If the result is nil, then we want to turn it | |
75 | // into a typed nil if we can. | |
76 | result := w.Result | |
77 | if result == nil { | |
78 | val := reflect.ValueOf(v) | |
79 | result = reflect.Indirect(reflect.New(val.Type())).Interface() | |
80 | } | |
81 | ||
82 | return result, nil | |
83 | } | |
84 | ||
// ifaceKey returns the key used to index interface types we've seen:
// the number of pointers occupies the upper 32 bits and the walk depth
// the lower 32 bits. This is cheap to compute, easy to match against
// the current depth, and avoids initializing and cleaning up nested
// maps or slices.
func ifaceKey(pointers, depth int) uint64 {
	const depthBits = 32
	return uint64(pointers)<<depthBits | uint64(depth)
}
92 | ||
// walker implements the reflectwalk callback interfaces. As the walk
// descends it mirrors the original structure, building the deep copy on
// a value stack (vals) and a stack of open containers (cs) that Exit
// commits values into.
type walker struct {
	// Result holds the copied value once the walk completes.
	Result interface{}

	// depth is the current walk depth. ignoreDepth, when non-zero, is the
	// depth at which the walker started ignoring values (see ignore and
	// ignoring).
	depth       int
	ignoreDepth int
	// vals is the working stack of copied values; cs is the stack of
	// containers (maps, slices, structs) currently being filled in.
	vals []reflect.Value
	cs   []reflect.Value

	// This stores the number of pointers we've walked over, indexed by depth.
	ps []int

	// If an interface is indirected by a pointer, we need to know the type of
	// interface to create when creating the new value. Store the interface
	// types here, indexed by both the walk depth and the number of pointers
	// already seen at that depth. Use ifaceKey to calculate the proper uint64
	// value.
	ifaceTypes map[uint64]reflect.Type

	// any locks we've taken, indexed by depth
	locks []sync.Locker
	// take locks while walking the structure
	useLocks bool
}
116 | ||
117 | func (w *walker) Enter(l reflectwalk.Location) error { | |
118 | w.depth++ | |
119 | ||
120 | // ensure we have enough elements to index via w.depth | |
121 | for w.depth >= len(w.locks) { | |
122 | w.locks = append(w.locks, nil) | |
123 | } | |
124 | ||
125 | for len(w.ps) < w.depth+1 { | |
126 | w.ps = append(w.ps, 0) | |
127 | } | |
128 | ||
129 | return nil | |
130 | } | |
131 | ||
// Exit is called by reflectwalk when leaving a location in the walked
// structure. Completed values are popped off the value stack here and
// written into their parent container (map, slice, or struct), and any
// per-depth state (locks, pointer counts, interface types) is cleaned up.
func (w *walker) Exit(l reflectwalk.Location) error {
	// Release any lock taken at this depth (see walker.lock) once we're
	// done with this level.
	locker := w.locks[w.depth]
	w.locks[w.depth] = nil
	if locker != nil {
		defer locker.Unlock()
	}

	// clear out pointers and interfaces as we exit the stack
	w.ps[w.depth] = 0

	// Drop interface-type records registered at this depth; the lower
	// 32 bits of each key hold the depth (see ifaceKey).
	for k := range w.ifaceTypes {
		mask := uint64(^uint32(0))
		if k&mask == uint64(w.depth) {
			delete(w.ifaceTypes, k)
		}
	}

	w.depth--
	// Once we rise back above the depth where ignoring started, stop
	// ignoring (see ignore/ignoring).
	if w.ignoreDepth > w.depth {
		w.ignoreDepth = 0
	}

	if w.ignoring() {
		return nil
	}

	switch l {
	case reflectwalk.Map:
		fallthrough
	case reflectwalk.Slice:
		// Pop map off our container
		w.cs = w.cs[:len(w.cs)-1]
	case reflectwalk.MapValue:
		// Pop off the key and value
		mv := w.valPop()
		mk := w.valPop()
		m := w.cs[len(w.cs)-1]

		// If mv is the zero value, SetMapIndex deletes the key from the map,
		// or in this case never adds it. We need to create a properly typed
		// zero value so that this key can be set.
		if !mv.IsValid() {
			mv = reflect.Zero(m.Type().Elem())
		}
		m.SetMapIndex(mk, mv)
	case reflectwalk.SliceElem:
		// Pop off the value and the index and set it on the slice
		v := w.valPop()
		i := w.valPop().Interface().(int)
		if v.IsValid() {
			s := w.cs[len(w.cs)-1]
			se := s.Index(i)
			if se.CanSet() {
				se.Set(v)
			}
		}
	case reflectwalk.Struct:
		w.replacePointerMaybe()

		// Remove the struct from the container stack
		w.cs = w.cs[:len(w.cs)-1]
	case reflectwalk.StructField:
		// Pop off the value and the field
		v := w.valPop()
		f := w.valPop().Interface().(reflect.StructField)
		if v.IsValid() {
			s := w.cs[len(w.cs)-1]
			sf := reflect.Indirect(s).FieldByName(f.Name)

			if sf.CanSet() {
				sf.Set(v)
			}
		}
	case reflectwalk.WalkLoc:
		// Clear out the slices for GC
		w.cs = nil
		w.vals = nil
	}

	return nil
}
213 | ||
214 | func (w *walker) Map(m reflect.Value) error { | |
215 | if w.ignoring() { | |
216 | return nil | |
217 | } | |
218 | w.lock(m) | |
219 | ||
220 | // Create the map. If the map itself is nil, then just make a nil map | |
221 | var newMap reflect.Value | |
222 | if m.IsNil() { | |
223 | newMap = reflect.Indirect(reflect.New(m.Type())) | |
224 | } else { | |
225 | newMap = reflect.MakeMap(m.Type()) | |
226 | } | |
227 | ||
228 | w.cs = append(w.cs, newMap) | |
229 | w.valPush(newMap) | |
230 | return nil | |
231 | } | |
232 | ||
// MapElem is called by reflectwalk for each key/value pair of a map.
// Nothing happens here: the copied key and value are pushed onto the
// value stack by deeper callbacks and committed to the new map in Exit
// (the reflectwalk.MapValue case).
func (w *walker) MapElem(m, k, v reflect.Value) error {
	return nil
}
236 | ||
// PointerEnter is called by reflectwalk when entering a value that may
// be behind a pointer. When v is true the walker is descending through
// a pointer, so the pointer count for the current depth is incremented
// (consumed by pointerPeek and replacePointerMaybe).
func (w *walker) PointerEnter(v bool) error {
	if v {
		w.ps[w.depth]++
	}
	return nil
}
243 | ||
// PointerExit mirrors PointerEnter: when v is true, one pointer
// indirection at the current depth has been fully walked, so the
// per-depth pointer count is decremented.
func (w *walker) PointerExit(v bool) error {
	if v {
		w.ps[w.depth]--
	}
	return nil
}
250 | ||
// Interface records the dynamic type of an interface value seen at the
// current depth and pointer count. replacePointerMaybe looks these
// records up to rebuild interface wrappers when reconstructing pointer
// chains.
func (w *walker) Interface(v reflect.Value) error {
	if !v.IsValid() {
		return nil
	}
	// Allocate the map on first use.
	if w.ifaceTypes == nil {
		w.ifaceTypes = make(map[uint64]reflect.Type)
	}

	w.ifaceTypes[ifaceKey(w.ps[w.depth], w.depth)] = v.Type()
	return nil
}
262 | ||
263 | func (w *walker) Primitive(v reflect.Value) error { | |
264 | if w.ignoring() { | |
265 | return nil | |
266 | } | |
267 | w.lock(v) | |
268 | ||
269 | // IsValid verifies the v is non-zero and CanInterface verifies | |
270 | // that we're allowed to read this value (unexported fields). | |
271 | var newV reflect.Value | |
272 | if v.IsValid() && v.CanInterface() { | |
273 | newV = reflect.New(v.Type()) | |
274 | newV.Elem().Set(v) | |
275 | } | |
276 | ||
277 | w.valPush(newV) | |
278 | w.replacePointerMaybe() | |
279 | return nil | |
280 | } | |
281 | ||
282 | func (w *walker) Slice(s reflect.Value) error { | |
283 | if w.ignoring() { | |
284 | return nil | |
285 | } | |
286 | w.lock(s) | |
287 | ||
288 | var newS reflect.Value | |
289 | if s.IsNil() { | |
290 | newS = reflect.Indirect(reflect.New(s.Type())) | |
291 | } else { | |
292 | newS = reflect.MakeSlice(s.Type(), s.Len(), s.Cap()) | |
293 | } | |
294 | ||
295 | w.cs = append(w.cs, newS) | |
296 | w.valPush(newS) | |
297 | return nil | |
298 | } | |
299 | ||
// SliceElem is called by reflectwalk for each slice element. The copied
// element is produced by deeper callbacks and written into the new slice
// in Exit; here we only record the index beneath it on the value stack.
func (w *walker) SliceElem(i int, elem reflect.Value) error {
	if w.ignoring() {
		return nil
	}

	// We don't write the slice here because elem might still be
	// arbitrarily complex. Just record the index and continue on.
	w.valPush(reflect.ValueOf(i))

	return nil
}
311 | ||
// Struct is called by reflectwalk when a struct value is encountered.
// If a CopierFunc is registered for the struct's type, that copier
// produces the copy and everything deeper is ignored; otherwise a fresh
// struct is allocated and reflectwalk descends to copy each field.
func (w *walker) Struct(s reflect.Value) error {
	if w.ignoring() {
		return nil
	}
	w.lock(s)

	var v reflect.Value
	// NOTE(review): only the package-level Copiers map is consulted here;
	// a Config.Copiers map set by the caller never reaches this point —
	// confirm whether per-Config copiers are intended to apply.
	if c, ok := Copiers[s.Type()]; ok {
		// We have a Copier for this struct, so we use that copier to
		// get the copy, and we ignore anything deeper than this.
		w.ignoreDepth = w.depth

		dup, err := c(s.Interface())
		if err != nil {
			return err
		}

		v = reflect.ValueOf(dup)
	} else {
		// No copier, we copy ourselves and allow reflectwalk to guide
		// us deeper into the structure for copying.
		v = reflect.New(s.Type())
	}

	// Push the value onto the value stack for setting the struct field,
	// and add the struct itself to the containers stack in case we walk
	// deeper so that its own fields can be modified.
	w.valPush(v)
	w.cs = append(w.cs, v)

	return nil
}
344 | ||
// StructField is called by reflectwalk for each field of a struct.
// Unexported fields are skipped entirely; for exported fields the
// StructField descriptor is pushed so Exit can set the copied value on
// the right field by name.
func (w *walker) StructField(f reflect.StructField, v reflect.Value) error {
	if w.ignoring() {
		return nil
	}

	// If PkgPath is non-empty, this is a private (unexported) field.
	// We do not set this unexported since the Go runtime doesn't allow us.
	if f.PkgPath != "" {
		return reflectwalk.SkipEntry
	}

	// Push the field onto the stack, we'll handle it when we exit
	// the struct field in Exit...
	w.valPush(reflect.ValueOf(f))
	return nil
}
361 | ||
// ignore marks the current depth so that all further values at or below
// it are skipped until the walk exits back above this depth (Exit resets
// ignoreDepth once w.depth drops below it).
func (w *walker) ignore() {
	w.ignoreDepth = w.depth
}
366 | ||
// ignoring reports whether the walker is currently at or below a depth
// marked for ignoring (set by ignore, or by Struct when a CopierFunc
// already produced the copy).
func (w *walker) ignoring() bool {
	return w.ignoreDepth > 0 && w.depth >= w.ignoreDepth
}
370 | ||
// pointerPeek reports whether the value at the current depth is behind
// at least one pointer indirection (as tracked by PointerEnter and
// PointerExit).
func (w *walker) pointerPeek() bool {
	return w.ps[w.depth] > 0
}
374 | ||
375 | func (w *walker) valPop() reflect.Value { | |
376 | result := w.vals[len(w.vals)-1] | |
377 | w.vals = w.vals[:len(w.vals)-1] | |
378 | ||
379 | // If we're out of values, that means we popped everything off. In | |
380 | // this case, we reset the result so the next pushed value becomes | |
381 | // the result. | |
382 | if len(w.vals) == 0 { | |
383 | w.Result = nil | |
384 | } | |
385 | ||
386 | return result | |
387 | } | |
388 | ||
// valPush pushes v onto the value stack. The first valid value pushed
// while Result is unset becomes the walk result, since it corresponds to
// the outermost value being copied.
func (w *walker) valPush(v reflect.Value) {
	w.vals = append(w.vals, v)

	// If we haven't set the result yet, then this is the result since
	// it is the first (outermost) value we're seeing.
	if w.Result == nil && v.IsValid() {
		w.Result = v.Interface()
	}
}
398 | ||
// replacePointerMaybe adjusts the value on top of the value stack to
// match the original's indirections: when no pointers were walked at the
// current depth, the freshly allocated pointer (from Primitive/Struct)
// is dereferenced; otherwise pointer wrappers — and any interface types
// recorded by Interface along the way — are rebuilt around the value.
func (w *walker) replacePointerMaybe() {
	// Determine the last pointer value. If it is NOT a pointer, then
	// we need to push that onto the stack.
	if !w.pointerPeek() {
		w.valPush(reflect.Indirect(w.valPop()))
		return
	}

	v := w.valPop()
	// Re-wrap one level of indirection per remaining pointer at this
	// depth. If an interface type was recorded for this pointer count
	// (see Interface/ifaceKey), wrap the value in that interface type
	// before taking its address.
	for i := 1; i < w.ps[w.depth]; i++ {
		if iType, ok := w.ifaceTypes[ifaceKey(w.ps[w.depth]-i, w.depth)]; ok {
			iface := reflect.New(iType).Elem()
			iface.Set(v)
			v = iface
		}

		p := reflect.New(v.Type())
		p.Elem().Set(v)
		v = p
	}

	w.valPush(v)
}
422 | ||
423 | // if this value is a Locker, lock it and add it to the locks slice | |
424 | func (w *walker) lock(v reflect.Value) { | |
425 | if !w.useLocks { | |
426 | return | |
427 | } | |
428 | ||
429 | if !v.IsValid() || !v.CanInterface() { | |
430 | return | |
431 | } | |
432 | ||
433 | type rlocker interface { | |
434 | RLocker() sync.Locker | |
435 | } | |
436 | ||
437 | var locker sync.Locker | |
438 | ||
439 | // We can't call Interface() on a value directly, since that requires | |
440 | // a copy. This is OK, since the pointer to a value which is a sync.Locker | |
441 | // is also a sync.Locker. | |
442 | if v.Kind() == reflect.Ptr { | |
443 | switch l := v.Interface().(type) { | |
444 | case rlocker: | |
445 | // don't lock a mutex directly | |
446 | if _, ok := l.(*sync.RWMutex); !ok { | |
447 | locker = l.RLocker() | |
448 | } | |
449 | case sync.Locker: | |
450 | locker = l | |
451 | } | |
452 | } else if v.CanAddr() { | |
453 | switch l := v.Addr().Interface().(type) { | |
454 | case rlocker: | |
455 | // don't lock a mutex directly | |
456 | if _, ok := l.(*sync.RWMutex); !ok { | |
457 | locker = l.RLocker() | |
458 | } | |
459 | case sync.Locker: | |
460 | locker = l | |
461 | } | |
462 | } | |
463 | ||
464 | // still no callable locker | |
465 | if locker == nil { | |
466 | return | |
467 | } | |
468 | ||
469 | // don't lock a mutex directly | |
470 | switch locker.(type) { | |
471 | case *sync.Mutex, *sync.RWMutex: | |
472 | return | |
473 | } | |
474 | ||
475 | locker.Lock() | |
476 | w.locks[w.depth] = locker | |
477 | } |