]> git.immae.eu Git - github/fretlink/terraform-provider-statuscake.git/blob - vendor/github.com/hashicorp/go-getter/get_s3.go
Initial transfer of provider code
[github/fretlink/terraform-provider-statuscake.git] / vendor / github.com / hashicorp / go-getter / get_s3.go
1 package getter
2
3 import (
4 "fmt"
5 "io"
6 "net/url"
7 "os"
8 "path/filepath"
9 "strings"
10
11 "github.com/aws/aws-sdk-go/aws"
12 "github.com/aws/aws-sdk-go/aws/credentials"
13 "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
14 "github.com/aws/aws-sdk-go/aws/ec2metadata"
15 "github.com/aws/aws-sdk-go/aws/session"
16 "github.com/aws/aws-sdk-go/service/s3"
17 )
18
// S3Getter is a Getter implementation that downloads modules and individual
// files from an S3 bucket. It carries no state, so its zero value is ready
// to use.
type S3Getter struct {
}
22
23 func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) {
24 // Parse URL
25 region, bucket, path, _, creds, err := g.parseUrl(u)
26 if err != nil {
27 return 0, err
28 }
29
30 // Create client config
31 config := g.getAWSConfig(region, creds)
32 sess := session.New(config)
33 client := s3.New(sess)
34
35 // List the object(s) at the given prefix
36 req := &s3.ListObjectsInput{
37 Bucket: aws.String(bucket),
38 Prefix: aws.String(path),
39 }
40 resp, err := client.ListObjects(req)
41 if err != nil {
42 return 0, err
43 }
44
45 for _, o := range resp.Contents {
46 // Use file mode on exact match.
47 if *o.Key == path {
48 return ClientModeFile, nil
49 }
50
51 // Use dir mode if child keys are found.
52 if strings.HasPrefix(*o.Key, path+"/") {
53 return ClientModeDir, nil
54 }
55 }
56
57 // There was no match, so just return file mode. The download is going
58 // to fail but we will let S3 return the proper error later.
59 return ClientModeFile, nil
60 }
61
62 func (g *S3Getter) Get(dst string, u *url.URL) error {
63 // Parse URL
64 region, bucket, path, _, creds, err := g.parseUrl(u)
65 if err != nil {
66 return err
67 }
68
69 // Remove destination if it already exists
70 _, err = os.Stat(dst)
71 if err != nil && !os.IsNotExist(err) {
72 return err
73 }
74
75 if err == nil {
76 // Remove the destination
77 if err := os.RemoveAll(dst); err != nil {
78 return err
79 }
80 }
81
82 // Create all the parent directories
83 if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil {
84 return err
85 }
86
87 config := g.getAWSConfig(region, creds)
88 sess := session.New(config)
89 client := s3.New(sess)
90
91 // List files in path, keep listing until no more objects are found
92 lastMarker := ""
93 hasMore := true
94 for hasMore {
95 req := &s3.ListObjectsInput{
96 Bucket: aws.String(bucket),
97 Prefix: aws.String(path),
98 }
99 if lastMarker != "" {
100 req.Marker = aws.String(lastMarker)
101 }
102
103 resp, err := client.ListObjects(req)
104 if err != nil {
105 return err
106 }
107
108 hasMore = aws.BoolValue(resp.IsTruncated)
109
110 // Get each object storing each file relative to the destination path
111 for _, object := range resp.Contents {
112 lastMarker = aws.StringValue(object.Key)
113 objPath := aws.StringValue(object.Key)
114
115 // If the key ends with a backslash assume it is a directory and ignore
116 if strings.HasSuffix(objPath, "/") {
117 continue
118 }
119
120 // Get the object destination path
121 objDst, err := filepath.Rel(path, objPath)
122 if err != nil {
123 return err
124 }
125 objDst = filepath.Join(dst, objDst)
126
127 if err := g.getObject(client, objDst, bucket, objPath, ""); err != nil {
128 return err
129 }
130 }
131 }
132
133 return nil
134 }
135
136 func (g *S3Getter) GetFile(dst string, u *url.URL) error {
137 region, bucket, path, version, creds, err := g.parseUrl(u)
138 if err != nil {
139 return err
140 }
141
142 config := g.getAWSConfig(region, creds)
143 sess := session.New(config)
144 client := s3.New(sess)
145 return g.getObject(client, dst, bucket, path, version)
146 }
147
148 func (g *S3Getter) getObject(client *s3.S3, dst, bucket, key, version string) error {
149 req := &s3.GetObjectInput{
150 Bucket: aws.String(bucket),
151 Key: aws.String(key),
152 }
153 if version != "" {
154 req.VersionId = aws.String(version)
155 }
156
157 resp, err := client.GetObject(req)
158 if err != nil {
159 return err
160 }
161
162 // Create all the parent directories
163 if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil {
164 return err
165 }
166
167 f, err := os.Create(dst)
168 if err != nil {
169 return err
170 }
171 defer f.Close()
172
173 _, err = io.Copy(f, resp.Body)
174 return err
175 }
176
177 func (g *S3Getter) getAWSConfig(region string, creds *credentials.Credentials) *aws.Config {
178 conf := &aws.Config{}
179 if creds == nil {
180 // Grab the metadata URL
181 metadataURL := os.Getenv("AWS_METADATA_URL")
182 if metadataURL == "" {
183 metadataURL = "http://169.254.169.254:80/latest"
184 }
185
186 creds = credentials.NewChainCredentials(
187 []credentials.Provider{
188 &credentials.EnvProvider{},
189 &credentials.SharedCredentialsProvider{Filename: "", Profile: ""},
190 &ec2rolecreds.EC2RoleProvider{
191 Client: ec2metadata.New(session.New(&aws.Config{
192 Endpoint: aws.String(metadataURL),
193 })),
194 },
195 })
196 }
197
198 conf.Credentials = creds
199 if region != "" {
200 conf.Region = aws.String(region)
201 }
202
203 return conf
204 }
205
206 func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, creds *credentials.Credentials, err error) {
207 // Expected host style: s3.amazonaws.com. They always have 3 parts,
208 // although the first may differ if we're accessing a specific region.
209 hostParts := strings.Split(u.Host, ".")
210 if len(hostParts) != 3 {
211 err = fmt.Errorf("URL is not a valid S3 URL")
212 return
213 }
214
215 // Parse the region out of the first part of the host
216 region = strings.TrimPrefix(strings.TrimPrefix(hostParts[0], "s3-"), "s3")
217 if region == "" {
218 region = "us-east-1"
219 }
220
221 pathParts := strings.SplitN(u.Path, "/", 3)
222 if len(pathParts) != 3 {
223 err = fmt.Errorf("URL is not a valid S3 URL")
224 return
225 }
226
227 bucket = pathParts[1]
228 path = pathParts[2]
229 version = u.Query().Get("version")
230
231 _, hasAwsId := u.Query()["aws_access_key_id"]
232 _, hasAwsSecret := u.Query()["aws_access_key_secret"]
233 _, hasAwsToken := u.Query()["aws_access_token"]
234 if hasAwsId || hasAwsSecret || hasAwsToken {
235 creds = credentials.NewStaticCredentials(
236 u.Query().Get("aws_access_key_id"),
237 u.Query().Get("aws_access_key_secret"),
238 u.Query().Get("aws_access_token"),
239 )
240 }
241
242 return
243 }