package copystructure

import (
	"errors"
	"reflect"
	"sync"

	"github.com/mitchellh/reflectwalk"
)

// Copy returns a deep copy of v.
func Copy(v interface{}) (interface{}, error) {
	return Config{}.Copy(v)
}
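
// A minimal usage sketch (the map value below is illustrative, not part of
// this package):
//
//	original := map[string]interface{}{"a": []int{1, 2}}
//	dup, err := Copy(original)
//	if err != nil {
//		// handle the error
//	}
//	copied := dup.(map[string]interface{})
//	copied["a"].([]int)[0] = 99 // the original slice is unchanged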

// CopierFunc is a function that knows how to deep copy a specific type.
// Register these globally with the Copiers variable.
type CopierFunc func(interface{}) (interface{}, error)

// Copiers is a map of types that behave specially when they are copied.
// If a type is found in this map while deep copying, the associated
// CopierFunc will be called to copy it instead of attempting to copy all
// fields.
//
// The key is the reflect.Type of the value, obtained from reflect.TypeOf.
//
// It is unsafe to write to this map after copies have started. If you
// are writing to this map while also copying, wrap all modifications to
// this map as well as to Copy in a mutex.
var Copiers map[reflect.Type]CopierFunc = make(map[reflect.Type]CopierFunc)
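
// As a hedged sketch of registering a CopierFunc, here is how a type that
// should not be field-walked (bytes.Buffer is used illustratively) could be
// copied by duplicating its contents instead:
//
//	func init() {
//		Copiers[reflect.TypeOf(bytes.Buffer{})] = func(v interface{}) (interface{}, error) {
//			b := v.(bytes.Buffer)
//			var dup bytes.Buffer
//			dup.Write(b.Bytes())
//			return dup, nil
//		}
//	}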

// Must is a helper that wraps a call to a function returning
// (interface{}, error) and panics if the error is non-nil. It is intended
// for use in variable initializations and should only be used when a copy
// error should be a crashing case.
func Must(v interface{}, err error) interface{} {
	if err != nil {
		panic("copy error: " + err.Error())
	}

	return v
}
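
// A hedged example of Must in a variable initialization (defaultTags is an
// assumed value, not part of this package):
//
//	var defaultTags = map[string]string{"env": "prod"}
//	var tagsCopy = Must(Copy(defaultTags)).(map[string]string)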

var errPointerRequired = errors.New("Copy argument must be a pointer when Lock is true")

type Config struct {
	// Lock any types that implement sync.Locker and are not themselves a
	// mutex while copying. If the value has an RLocker method (as
	// *sync.RWMutex does), that is used to obtain the sync.Locker.
	Lock bool

	// Copiers is a map of types associated with a CopierFunc. Use the global
	// Copiers map if this is nil.
	Copiers map[reflect.Type]CopierFunc
}

func (c Config) Copy(v interface{}) (interface{}, error) {
	if c.Lock && reflect.ValueOf(v).Kind() != reflect.Ptr {
		return nil, errPointerRequired
	}

	w := new(walker)
	if c.Lock {
		w.useLocks = true
	}

	if c.Copiers == nil {
		c.Copiers = Copiers
	}

	// Hand the resolved copier set to the walker so that per-Config
	// copiers are honored rather than only the global Copiers map.
	w.copiers = c.Copiers

	err := reflectwalk.Walk(v, w)
	if err != nil {
		return nil, err
	}

	// Get the result. If the result is nil, then we want to turn it
	// into a typed zero value (a typed nil for pointer types) if we can.
	result := w.Result
	if result == nil {
		val := reflect.ValueOf(v)
		result = reflect.Indirect(reflect.New(val.Type())).Interface()
	}

	return result, nil
}
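
// A minimal sketch of using Config with Lock enabled. Lock requires a pointer
// argument, so the value is passed by address; the Data type and its embedded
// mutex are illustrative:
//
//	type Data struct {
//		sync.Mutex
//		Value string
//	}
//
//	d := &Data{Value: "x"}
//	// The embedded mutex makes *Data a sync.Locker, so it is held while
//	// the structure is walked.
//	dup, err := Config{Lock: true}.Copy(d)
//	if err == nil {
//		copied := dup.(*Data)
//		_ = copied
//	}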

// Return the key used to index interface types we've seen. Store the number
// of pointers in the upper 32 bits, and the depth in the lower 32 bits. This
// is easy to calculate, easy to match a key with our current depth, and we
// don't need to deal with initializing and cleaning up nested maps or slices.
func ifaceKey(pointers, depth int) uint64 {
	return uint64(pointers)<<32 | uint64(depth)
}
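
// For example, ifaceKey(2, 3) yields 0x0000000200000003: two pointers seen at
// depth three, so matching "any key at depth 3" only requires comparing the
// low 32 bits.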

type walker struct {
	Result interface{}

	// copiers is the set of custom copier functions to consult while
	// walking; Config.Copy sets it from Config.Copiers.
	copiers map[reflect.Type]CopierFunc

	depth       int
	ignoreDepth int
	vals        []reflect.Value
	cs          []reflect.Value

	// This stores the number of pointers we've walked over, indexed by depth.
	ps []int

	// If an interface is indirected by a pointer, we need to know the type of
	// interface to create when creating the new value. Store the interface
	// types here, indexed by both the walk depth and the number of pointers
	// already seen at that depth. Use ifaceKey to calculate the proper uint64
	// value.
	ifaceTypes map[uint64]reflect.Type

	// any locks we've taken, indexed by depth
	locks []sync.Locker
	// take locks while walking the structure
	useLocks bool
}

func (w *walker) Enter(l reflectwalk.Location) error {
	w.depth++

	// ensure we have enough elements to index via w.depth
	for w.depth >= len(w.locks) {
		w.locks = append(w.locks, nil)
	}

	for len(w.ps) < w.depth+1 {
		w.ps = append(w.ps, 0)
	}

	return nil
}

func (w *walker) Exit(l reflectwalk.Location) error {
	locker := w.locks[w.depth]
	w.locks[w.depth] = nil
	if locker != nil {
		defer locker.Unlock()
	}

	// clear out pointers and interfaces as we exit the stack
	w.ps[w.depth] = 0

	for k := range w.ifaceTypes {
		mask := uint64(^uint32(0))
		if k&mask == uint64(w.depth) {
			delete(w.ifaceTypes, k)
		}
	}

	w.depth--
	if w.ignoreDepth > w.depth {
		w.ignoreDepth = 0
	}

	if w.ignoring() {
		return nil
	}

	switch l {
	case reflectwalk.Array:
		fallthrough
	case reflectwalk.Map:
		fallthrough
	case reflectwalk.Slice:
		w.replacePointerMaybe()

		// Pop the container (array, map, or slice) off our stack
		w.cs = w.cs[:len(w.cs)-1]
	case reflectwalk.MapValue:
		// Pop off the key and value
		mv := w.valPop()
		mk := w.valPop()
		m := w.cs[len(w.cs)-1]

		// If mv is the zero value, SetMapIndex deletes the key from the map,
		// or in this case never adds it. We need to create a properly typed
		// zero value so that this key can be set.
		if !mv.IsValid() {
			mv = reflect.Zero(m.Elem().Type().Elem())
		}
		m.Elem().SetMapIndex(mk, mv)
	case reflectwalk.ArrayElem:
		// Pop off the value and the index and set it on the array
		v := w.valPop()
		i := w.valPop().Interface().(int)
		if v.IsValid() {
			a := w.cs[len(w.cs)-1]
			ae := a.Elem().Index(i) // storing array as pointer on stack - so need Elem() call
			if ae.CanSet() {
				ae.Set(v)
			}
		}
	case reflectwalk.SliceElem:
		// Pop off the value and the index and set it on the slice
		v := w.valPop()
		i := w.valPop().Interface().(int)
		if v.IsValid() {
			s := w.cs[len(w.cs)-1]
			se := s.Elem().Index(i)
			if se.CanSet() {
				se.Set(v)
			}
		}
	case reflectwalk.Struct:
		w.replacePointerMaybe()

		// Remove the struct from the container stack
		w.cs = w.cs[:len(w.cs)-1]
	case reflectwalk.StructField:
		// Pop off the value and the field
		v := w.valPop()
		f := w.valPop().Interface().(reflect.StructField)
		if v.IsValid() {
			s := w.cs[len(w.cs)-1]
			sf := reflect.Indirect(s).FieldByName(f.Name)

			if sf.CanSet() {
				sf.Set(v)
			}
		}
	case reflectwalk.WalkLoc:
		// Clear out the slices for GC
		w.cs = nil
		w.vals = nil
	}

	return nil
}

func (w *walker) Map(m reflect.Value) error {
	if w.ignoring() {
		return nil
	}
	w.lock(m)

	// Create the map. If the map itself is nil, then just make a nil map
	var newMap reflect.Value
	if m.IsNil() {
		newMap = reflect.New(m.Type())
	} else {
		newMap = wrapPtr(reflect.MakeMap(m.Type()))
	}

	w.cs = append(w.cs, newMap)
	w.valPush(newMap)
	return nil
}

func (w *walker) MapElem(m, k, v reflect.Value) error {
	return nil
}

func (w *walker) PointerEnter(v bool) error {
	if v {
		w.ps[w.depth]++
	}
	return nil
}

func (w *walker) PointerExit(v bool) error {
	if v {
		w.ps[w.depth]--
	}
	return nil
}

func (w *walker) Interface(v reflect.Value) error {
	if !v.IsValid() {
		return nil
	}
	if w.ifaceTypes == nil {
		w.ifaceTypes = make(map[uint64]reflect.Type)
	}

	w.ifaceTypes[ifaceKey(w.ps[w.depth], w.depth)] = v.Type()
	return nil
}

func (w *walker) Primitive(v reflect.Value) error {
	if w.ignoring() {
		return nil
	}
	w.lock(v)

	// IsValid verifies that v is not the zero Value, and CanInterface
	// verifies that we're allowed to read this value (unexported fields).
	var newV reflect.Value
	if v.IsValid() && v.CanInterface() {
		newV = reflect.New(v.Type())
		newV.Elem().Set(v)
	}

	w.valPush(newV)
	w.replacePointerMaybe()
	return nil
}

func (w *walker) Slice(s reflect.Value) error {
	if w.ignoring() {
		return nil
	}
	w.lock(s)

	var newS reflect.Value
	if s.IsNil() {
		newS = reflect.New(s.Type())
	} else {
		newS = wrapPtr(reflect.MakeSlice(s.Type(), s.Len(), s.Cap()))
	}

	w.cs = append(w.cs, newS)
	w.valPush(newS)
	return nil
}

func (w *walker) SliceElem(i int, elem reflect.Value) error {
	if w.ignoring() {
		return nil
	}

	// We don't write the slice here because elem might still be
	// arbitrarily complex. Just record the index and continue on.
	w.valPush(reflect.ValueOf(i))

	return nil
}

func (w *walker) Array(a reflect.Value) error {
	if w.ignoring() {
		return nil
	}
	w.lock(a)

	newA := reflect.New(a.Type())

	w.cs = append(w.cs, newA)
	w.valPush(newA)
	return nil
}

func (w *walker) ArrayElem(i int, elem reflect.Value) error {
	if w.ignoring() {
		return nil
	}

	// We don't write the array here because elem might still be
	// arbitrarily complex. Just record the index and continue on.
	w.valPush(reflect.ValueOf(i))

	return nil
}

func (w *walker) Struct(s reflect.Value) error {
	if w.ignoring() {
		return nil
	}
	w.lock(s)

	var v reflect.Value
	if c, ok := w.copiers[s.Type()]; ok {
		// We have a Copier for this struct, so we use that copier to
		// get the copy, and we ignore anything deeper than this.
		w.ignoreDepth = w.depth

		dup, err := c(s.Interface())
		if err != nil {
			return err
		}

		// We need to put a pointer to the value on the value stack,
		// so allocate a new pointer and set it.
		v = reflect.New(s.Type())
		reflect.Indirect(v).Set(reflect.ValueOf(dup))
	} else {
		// No copier, so we copy ourselves and allow reflectwalk to guide
		// us deeper into the structure for copying.
		v = reflect.New(s.Type())
	}

	// Push the value onto the value stack for setting the struct field,
	// and add the struct itself to the containers stack in case we walk
	// deeper so that its own fields can be modified.
	w.valPush(v)
	w.cs = append(w.cs, v)

	return nil
}

func (w *walker) StructField(f reflect.StructField, v reflect.Value) error {
	if w.ignoring() {
		return nil
	}

	// If PkgPath is non-empty, this is an unexported (private) field.
	// We don't set it, since the reflect package doesn't allow writes to
	// unexported fields.
	if f.PkgPath != "" {
		return reflectwalk.SkipEntry
	}

	// Push the field onto the stack; we'll handle it when we exit
	// the struct field in Exit...
	w.valPush(reflect.ValueOf(f))
	return nil
}

// ignore causes the walker to ignore any more values until we exit this one
// (everything at or below the current depth).
func (w *walker) ignore() {
	w.ignoreDepth = w.depth
}

func (w *walker) ignoring() bool {
	return w.ignoreDepth > 0 && w.depth >= w.ignoreDepth
}

func (w *walker) pointerPeek() bool {
	return w.ps[w.depth] > 0
}

func (w *walker) valPop() reflect.Value {
	result := w.vals[len(w.vals)-1]
	w.vals = w.vals[:len(w.vals)-1]

	// If we're out of values, that means we popped everything off. In
	// this case, we reset the result so the next pushed value becomes
	// the result.
	if len(w.vals) == 0 {
		w.Result = nil
	}

	return result
}

func (w *walker) valPush(v reflect.Value) {
	w.vals = append(w.vals, v)

	// If we haven't set the result yet, then this is the result since
	// it is the first (outermost) value we're seeing.
	if w.Result == nil && v.IsValid() {
		w.Result = v.Interface()
	}
}

func (w *walker) replacePointerMaybe() {
	// Determine the last pointer value. If it is NOT a pointer, then
	// we need to push that onto the stack.
	if !w.pointerPeek() {
		w.valPush(reflect.Indirect(w.valPop()))
		return
	}

	v := w.valPop()

	// If the expected type is a pointer to an interface of any depth,
	// such as *interface{}, **interface{}, etc., then we need to convert
	// the value "v" from *CONCRETE to *interface{} so the types match for
	// Set.
	//
	// For example, if v has type *Foo where Foo is a struct, v becomes
	// *interface{} instead. This only happens if we have an interface
	// expectation at this depth.
	//
	// For more info, see GH-16
	if iType, ok := w.ifaceTypes[ifaceKey(w.ps[w.depth], w.depth)]; ok && iType.Kind() == reflect.Interface {
		y := reflect.New(iType)           // Create *interface{}
		y.Elem().Set(reflect.Indirect(v)) // Assign "Foo" to interface{} (dereferenced)
		v = y                             // v is now typed *interface{} (where *v = Foo)
	}

	for i := 1; i < w.ps[w.depth]; i++ {
		if iType, ok := w.ifaceTypes[ifaceKey(w.ps[w.depth]-i, w.depth)]; ok {
			iface := reflect.New(iType).Elem()
			iface.Set(v)
			v = iface
		}

		p := reflect.New(v.Type())
		p.Elem().Set(v)
		v = p
	}

	w.valPush(v)
}

// if this value is a Locker, lock it and add it to the locks slice
func (w *walker) lock(v reflect.Value) {
	if !w.useLocks {
		return
	}

	if !v.IsValid() || !v.CanInterface() {
		return
	}

	type rlocker interface {
		RLocker() sync.Locker
	}

	var locker sync.Locker

	// We can't usefully lock a copy, and calling Interface() on a
	// non-pointer value produces one. This is OK, since the pointer to a
	// value which is a sync.Locker is also a sync.Locker.
	if v.Kind() == reflect.Ptr {
		switch l := v.Interface().(type) {
		case rlocker:
			// don't lock a mutex directly
			if _, ok := l.(*sync.RWMutex); !ok {
				locker = l.RLocker()
			}
		case sync.Locker:
			locker = l
		}
	} else if v.CanAddr() {
		switch l := v.Addr().Interface().(type) {
		case rlocker:
			// don't lock a mutex directly
			if _, ok := l.(*sync.RWMutex); !ok {
				locker = l.RLocker()
			}
		case sync.Locker:
			locker = l
		}
	}

	// still no callable locker
	if locker == nil {
		return
	}

	// don't lock a mutex directly
	switch locker.(type) {
	case *sync.Mutex, *sync.RWMutex:
		return
	}

	locker.Lock()
	w.locks[w.depth] = locker
}

// wrapPtr is a helper that takes v and always makes it *v. copystructure
// stores things internally as pointers until the last moment before unwrapping.
func wrapPtr(v reflect.Value) reflect.Value {
	if !v.IsValid() {
		return v
	}
	vPtr := reflect.New(v.Type())
	vPtr.Elem().Set(v)
	return vPtr
}