aboutsummaryrefslogtreecommitdiffhomepage
path: root/vendor/github.com/aws/aws-sdk-go/aws/credentials/stscreds/web_identity_provider.go
blob: 20510d9aec8e8a76e6ecdd0ac8bb532b6b17018d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
package stscreds

import (
	"fmt"
	"io/ioutil"
	"strconv"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/client"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/service/sts"
	"github.com/aws/aws-sdk-go/service/sts/stsiface"
)

const (
	// WebIdentityProviderName is the ProviderName reported in the
	// credentials.Value produced by WebIdentityRoleProvider.Retrieve.
	WebIdentityProviderName = "WebIdentityCredentials"

	// ErrCodeWebIdentity is the awserr error code used for failures while
	// setting up or retrieving web identity credentials. Retrieve tags both
	// token-file read failures and STS call failures with this code.
	ErrCodeWebIdentity = "WebIdentityErr"
)

// now reports the current wall-clock time. It is a package-level variable
// rather than a direct time.Now call so tests can substitute a fixed clock
// (e.g. to get a deterministic generated session name).
var now = func() time.Time {
	return time.Now()
}

// WebIdentityRoleProvider is used to retrieve credentials using
// an OIDC token.
//
// On every Retrieve call it reads the token from a file and exchanges it with
// STS (AssumeRoleWithWebIdentity) for temporary credentials for the
// configured role. Construct it with NewWebIdentityRoleProvider, which
// populates the unexported fields.
type WebIdentityRoleProvider struct {
	// Expiry records when the most recently retrieved credentials expire;
	// Retrieve updates it through SetExpiration.
	credentials.Expiry

	// client is the STS API used to call AssumeRoleWithWebIdentity.
	client       stsiface.STSAPI
	// ExpiryWindow is passed as the window argument to SetExpiration together
	// with the STS-reported expiration. NOTE(review): presumably this causes
	// credentials to be treated as expired this long before their real
	// expiration — confirm against credentials.Expiry.
	ExpiryWindow time.Duration

	// tokenFilePath is the path of the file holding the web identity token.
	tokenFilePath   string
	// roleARN is the ARN of the role to assume.
	roleARN         string
	// roleSessionName names the assumed-role session; when empty, Retrieve
	// generates one from the current time.
	roleSessionName string
}

// NewWebIdentityCredentials returns a *credentials.Credentials backed by a
// WebIdentityRoleProvider. The provider uses an STS client created from c to
// assume roleARN with the web identity token stored in the file at path.
// roleSessionName may be empty; see WebIdentityRoleProvider.Retrieve for how
// a session name is chosen in that case.
func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName, path string) *credentials.Credentials {
	provider := NewWebIdentityRoleProvider(sts.New(c), roleARN, roleSessionName, path)
	return credentials.NewCredentials(provider)
}

// NewWebIdentityRoleProvider builds a WebIdentityRoleProvider that calls STS
// through svc. The token is read from the file at path, and roleARN and
// roleSessionName identify the role and session to assume. ExpiryWindow is
// left at zero.
func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *WebIdentityRoleProvider {
	p := new(WebIdentityRoleProvider)
	p.client = svc
	p.roleARN = roleARN
	p.roleSessionName = roleSessionName
	p.tokenFilePath = path
	return p
}

// Retrieve assumes the configured role using the web identity token stored in
// the provider's token file and returns the temporary credentials issued by
// STS.
//
// The token file is read on every call, so a rewritten token file is picked
// up by the next call. If no role session name was configured, the current
// Unix time in nanoseconds is used as the session name. On success, the
// STS-reported expiration is recorded on the embedded credentials.Expiry,
// with ExpiryWindow passed as the window.
//
// Every returned error is an awserr.Error with code ErrCodeWebIdentity. This
// covers an unreadable token file, a failed STS call, and an STS response
// that carries no credentials.
func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) {
	b, err := ioutil.ReadFile(p.tokenFilePath)
	if err != nil {
		errMsg := fmt.Sprintf("unable to read file at %s", p.tokenFilePath)
		return credentials.Value{}, awserr.New(ErrCodeWebIdentity, errMsg, err)
	}

	sessionName := p.roleSessionName
	if len(sessionName) == 0 {
		// session name is used to uniquely identify a session. This simply
		// uses unix time in nanoseconds to uniquely identify sessions.
		sessionName = strconv.FormatInt(now().UnixNano(), 10)
	}

	// aws.String copies roleARN so the request input does not alias the
	// provider's own field.
	resp, err := p.client.AssumeRoleWithWebIdentity(&sts.AssumeRoleWithWebIdentityInput{
		RoleArn:          aws.String(p.roleARN),
		RoleSessionName:  &sessionName,
		WebIdentityToken: aws.String(string(b)),
	})
	if err != nil {
		return credentials.Value{}, awserr.New(ErrCodeWebIdentity, "failed to retrieve credentials", err)
	}
	// Guard against a response without credentials (e.g. from a stub
	// stsiface.STSAPI implementation); dereferencing it below would panic.
	if resp == nil || resp.Credentials == nil {
		return credentials.Value{}, awserr.New(ErrCodeWebIdentity,
			"failed to retrieve credentials, STS response contained no credentials", nil)
	}

	creds := resp.Credentials
	p.SetExpiration(aws.TimeValue(creds.Expiration), p.ExpiryWindow)

	value := credentials.Value{
		AccessKeyID:     aws.StringValue(creds.AccessKeyId),
		SecretAccessKey: aws.StringValue(creds.SecretAccessKey),
		SessionToken:    aws.StringValue(creds.SessionToken),
		ProviderName:    WebIdentityProviderName,
	}
	return value, nil
}