aboutsummaryrefslogblamecommitdiffhomepage
path: root/vendor/github.com/vmihailenco/msgpack/decode_value.go
blob: 7b858b5fc5c277b531734f263c11c03b868fe9e7 (plain) (tree)









































































































































































































































                                                                                      
package msgpack

import (
	"errors"
	"fmt"
	"reflect"
)

// Cached reflect.Type descriptors compared against when selecting
// specialised decoders.
var (
	// interfaceType is the type descriptor of the empty interface.
	interfaceType = reflect.TypeOf((*interface{})(nil)).Elem()
	// stringType is the type descriptor of the built-in string type.
	stringType = reflect.TypeOf((*string)(nil)).Elem()
)

// valueDecoders is the fallback decoder table used by getDecoder. It is
// indexed by reflect.Kind and populated in init.
var valueDecoders []decoderFunc

// init fills valueDecoders with one generic decoder per reflect.Kind.
//
// Every kind from reflect.Bool up to reflect.UnsafePointer has an explicit
// entry so that a lookup never yields a nil decoderFunc; kinds that cannot be
// decoded map to decodeUnsupportedValue, which reports an error.
// NOTE(review): the table is presumably built in init rather than in a
// variable initializer to avoid an initialization cycle with getDecoder —
// confirm before changing.
func init() {
	valueDecoders = []decoderFunc{
		reflect.Bool:          decodeBoolValue,
		reflect.Int:           decodeInt64Value,
		reflect.Int8:          decodeInt64Value,
		reflect.Int16:         decodeInt64Value,
		reflect.Int32:         decodeInt64Value,
		reflect.Int64:         decodeInt64Value,
		reflect.Uint:          decodeUint64Value,
		reflect.Uint8:         decodeUint64Value,
		reflect.Uint16:        decodeUint64Value,
		reflect.Uint32:        decodeUint64Value,
		reflect.Uint64:        decodeUint64Value,
		// Without this entry the slot for Uintptr would be nil and calling
		// the decoder returned by getDecoder would panic.
		reflect.Uintptr:       decodeUnsupportedValue,
		reflect.Float32:       decodeFloat32Value,
		reflect.Float64:       decodeFloat64Value,
		reflect.Complex64:     decodeUnsupportedValue,
		reflect.Complex128:    decodeUnsupportedValue,
		reflect.Array:         decodeArrayValue,
		reflect.Chan:          decodeUnsupportedValue,
		reflect.Func:          decodeUnsupportedValue,
		reflect.Interface:     decodeInterfaceValue,
		reflect.Map:           decodeMapValue,
		reflect.Ptr:           decodeUnsupportedValue,
		reflect.Slice:         decodeSliceValue,
		reflect.String:        decodeStringValue,
		reflect.Struct:        decodeStructValue,
		reflect.UnsafePointer: decodeUnsupportedValue,
	}
}

// mustSet reports whether v can be assigned to. It returns nil when
// v.CanSet() is true and a descriptive error naming v's type otherwise.
func mustSet(v reflect.Value) error {
	if v.CanSet() {
		return nil
	}
	return fmt.Errorf("msgpack: Decode(nonsettable %s)", v.Type())
}

// getDecoder selects the decoderFunc for typ. Candidates are tried in this
// order:
//
//  1. a decoder found in typDecMap for exactly this type;
//  2. typ itself implementing the custom-decoder or unmarshaler interface;
//  3. for non-pointer types, *typ implementing one of those interfaces
//     (these decoders need an addressable value);
//  4. kind-specific fast paths (pointers, byte/string slices, byte arrays,
//     string-keyed maps);
//  5. the generic per-kind entry in valueDecoders.
func getDecoder(typ reflect.Type) decoderFunc {
	kind := typ.Kind()

	if registered, found := typDecMap[typ]; found {
		return registered
	}

	switch {
	case typ.Implements(customDecoderType):
		return decodeCustomValue
	case typ.Implements(unmarshalerType):
		return unmarshalValue
	}

	if kind == reflect.Ptr {
		return ptrDecoderFunc(typ)
	}

	// The type is not a pointer: its methods may be declared on the pointer
	// receiver, in which case the value has to be addressable at decode time.
	ptrTyp := reflect.PtrTo(typ)
	switch {
	case ptrTyp.Implements(customDecoderType):
		return decodeCustomValueAddr
	case ptrTyp.Implements(unmarshalerType):
		return unmarshalValueAddr
	}

	switch kind {
	case reflect.Slice:
		elemTyp := typ.Elem()
		if elemTyp.Kind() == reflect.Uint8 {
			return decodeBytesValue
		}
		if elemTyp == stringType {
			return decodeStringSliceValue
		}
	case reflect.Array:
		if typ.Elem().Kind() == reflect.Uint8 {
			return decodeByteArrayValue
		}
	case reflect.Map:
		if typ.Key() != stringType {
			break
		}
		valTyp := typ.Elem()
		if valTyp == stringType {
			return decodeMapStringStringValue
		}
		if valTyp == interfaceType {
			return decodeMapStringInterfaceValue
		}
	}

	return valueDecoders[kind]
}

// ptrDecoderFunc builds a decoder for the pointer type typ by wrapping the
// decoder of its element type.
//
// The returned decoder behaves as follows:
//   - when d.hasNilCode() is true, v must be settable; a non-nil v is reset
//     to the zero (nil) pointer and d.DecodeNil() is returned;
//   - otherwise a nil v is replaced with a freshly allocated element (v must
//     be settable for that) and the element decoder runs on v.Elem().
func ptrDecoderFunc(typ reflect.Type) decoderFunc {
	elemDecoder := getDecoder(typ.Elem())

	return func(d *Decoder, v reflect.Value) error {
		if !d.hasNilCode() {
			if v.IsNil() {
				if err := mustSet(v); err != nil {
					return err
				}
				fresh := reflect.New(v.Type().Elem())
				v.Set(fresh)
			}
			return elemDecoder(d, v.Elem())
		}

		if err := mustSet(v); err != nil {
			return err
		}
		if !v.IsNil() {
			v.Set(reflect.Zero(v.Type()))
		}
		return d.DecodeNil()
	}
}

// decodeCustomValueAddr decodes into a value whose pointer type implements
// the custom-decoder interface. It takes v's address and delegates to
// decodeCustomValue; a value that is not addressable yields an error.
func decodeCustomValueAddr(d *Decoder, v reflect.Value) error {
	if v.CanAddr() {
		return decodeCustomValue(d, v.Addr())
	}
	return fmt.Errorf("msgpack: Decode(nonaddressable %T)", v.Interface())
}

// decodeCustomValue decodes into v by calling its DecodeMsgpack method.
//
// When d.hasNilCode() is true the work is handed to d.decodeNilValue.
// Otherwise a nil v is first replaced with a newly allocated element so the
// method has a receiver to fill in. That replacement requires v to be
// settable; a non-settable nil value produces the mustSet error instead of
// a reflect panic, matching what ptrDecoderFunc does.
func decodeCustomValue(d *Decoder, v reflect.Value) error {
	if d.hasNilCode() {
		return d.decodeNilValue(v)
	}

	if v.IsNil() {
		// Guard the Set below: reflect panics on non-settable values.
		if err := mustSet(v); err != nil {
			return err
		}
		v.Set(reflect.New(v.Type().Elem()))
	}

	decoder := v.Interface().(CustomDecoder)
	return decoder.DecodeMsgpack(d)
}

// unmarshalValueAddr decodes into a value whose pointer type implements the
// unmarshaler interface. It takes v's address and delegates to
// unmarshalValue; a value that is not addressable yields an error.
func unmarshalValueAddr(d *Decoder, v reflect.Value) error {
	if v.CanAddr() {
		return unmarshalValue(d, v.Addr())
	}
	return fmt.Errorf("msgpack: Decode(nonaddressable %T)", v.Interface())
}

// unmarshalValue decodes into v by handing raw bytes to its UnmarshalMsgpack
// method.
//
// When d.hasNilCode() is true the work is handed to d.decodeNilValue.
// Otherwise a nil v is replaced with a newly allocated element (v must be
// settable for that), and the bytes given to the method are obtained in one
// of two ways:
//   - d.extLen != 0: exactly d.extLen bytes are read with d.readN;
//   - otherwise: d.rec is set to a fresh buffer and d.Skip() is called
//     (presumably Skip records the bytes it consumes into d.rec — confirm
//     against the reader implementation).
//
// d.rec is always reset to nil before returning once it has been assigned,
// including when d.Skip() fails, so no stale buffer is left on the decoder.
func unmarshalValue(d *Decoder, v reflect.Value) error {
	if d.hasNilCode() {
		return d.decodeNilValue(v)
	}

	if v.IsNil() {
		// Guard the Set below: reflect panics on non-settable values.
		if err := mustSet(v); err != nil {
			return err
		}
		v.Set(reflect.New(v.Type().Elem()))
	}

	if d.extLen != 0 {
		b, err := d.readN(d.extLen)
		if err != nil {
			return err
		}
		d.rec = b
	} else {
		d.rec = makeBuffer()
		if err := d.Skip(); err != nil {
			// Do not leave the buffer attached to the decoder on failure.
			d.rec = nil
			return err
		}
	}

	unmarshaler := v.Interface().(Unmarshaler)
	err := unmarshaler.UnmarshalMsgpack(d.rec)
	d.rec = nil
	return err
}

// decodeBoolValue reads a bool with d.DecodeBool and stores it in v.
// The value is decoded first; only then is v checked for settability, so a
// non-settable v still consumes the encoded bool before the error is
// returned.
func decodeBoolValue(d *Decoder, v reflect.Value) error {
	decoded, err := d.DecodeBool()
	if err == nil {
		err = mustSet(v)
	}
	if err != nil {
		return err
	}

	v.SetBool(decoded)
	return nil
}

// decodeInterfaceValue decodes into an interface-kind value.
//
// A nil interface is filled through d.interfaceValue. For a non-nil
// interface the contained value is decoded in place with d.DecodeValue,
// except when that value is not addressable and d.hasNilCode() is true: then
// v is reset to its zero value and d.DecodeNil() is returned.
func decodeInterfaceValue(d *Decoder, v reflect.Value) error {
	if v.IsNil() {
		return d.interfaceValue(v)
	}

	inner := v.Elem()
	// hasNilCode is consulted only for non-addressable contents, exactly as
	// the short-circuit evaluation order dictates.
	if !inner.CanAddr() && d.hasNilCode() {
		v.Set(reflect.Zero(v.Type()))
		return d.DecodeNil()
	}

	return d.DecodeValue(inner)
}

// interfaceValue decodes the next value with d.decodeInterfaceCond and
// stores it in v. A nil decoded value leaves v untouched. When v's type is
// errorType and the decoded value is a string, the string is wrapped with
// errors.New before being stored.
func (d *Decoder) interfaceValue(v reflect.Value) error {
	decoded, err := d.decodeInterfaceCond()
	if err != nil {
		return err
	}
	if decoded == nil {
		return nil
	}

	result := reflect.ValueOf(decoded)
	if v.Type() == errorType {
		if msg, isString := decoded.(string); isString {
			result = reflect.ValueOf(errors.New(msg))
		}
	}

	v.Set(result)
	return nil
}

// decodeUnsupportedValue is the decoder used for kinds that cannot be
// decoded. It reads nothing from d and always returns an error naming v's
// type.
func decodeUnsupportedValue(d *Decoder, v reflect.Value) error {
	typ := v.Type()
	return fmt.Errorf("msgpack: Decode(unsupported %s)", typ)
}