git.immae.eu Git - github/fretlink/terraform-provider-statuscake.git/blob - vendor/github.com/hashicorp/go-getter/get_s3.go
Upgrade to 0.12
[github/fretlink/terraform-provider-statuscake.git] / vendor / github.com / hashicorp / go-getter / get_s3.go
1 package getter
2
3 import (
4 "context"
5 "fmt"
6 "net/url"
7 "os"
8 "path/filepath"
9 "strings"
10
11 "github.com/aws/aws-sdk-go/aws"
12 "github.com/aws/aws-sdk-go/aws/credentials"
13 "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
14 "github.com/aws/aws-sdk-go/aws/ec2metadata"
15 "github.com/aws/aws-sdk-go/aws/session"
16 "github.com/aws/aws-sdk-go/service/s3"
17 )
18
19 // S3Getter is a Getter implementation that will download a module from
20 // a S3 bucket.
21 type S3Getter struct {
22 getter
23 }
24
25 func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) {
26 // Parse URL
27 region, bucket, path, _, creds, err := g.parseUrl(u)
28 if err != nil {
29 return 0, err
30 }
31
32 // Create client config
33 config := g.getAWSConfig(region, u, creds)
34 sess := session.New(config)
35 client := s3.New(sess)
36
37 // List the object(s) at the given prefix
38 req := &s3.ListObjectsInput{
39 Bucket: aws.String(bucket),
40 Prefix: aws.String(path),
41 }
42 resp, err := client.ListObjects(req)
43 if err != nil {
44 return 0, err
45 }
46
47 for _, o := range resp.Contents {
48 // Use file mode on exact match.
49 if *o.Key == path {
50 return ClientModeFile, nil
51 }
52
53 // Use dir mode if child keys are found.
54 if strings.HasPrefix(*o.Key, path+"/") {
55 return ClientModeDir, nil
56 }
57 }
58
59 // There was no match, so just return file mode. The download is going
60 // to fail but we will let S3 return the proper error later.
61 return ClientModeFile, nil
62 }
63
64 func (g *S3Getter) Get(dst string, u *url.URL) error {
65 ctx := g.Context()
66
67 // Parse URL
68 region, bucket, path, _, creds, err := g.parseUrl(u)
69 if err != nil {
70 return err
71 }
72
73 // Remove destination if it already exists
74 _, err = os.Stat(dst)
75 if err != nil && !os.IsNotExist(err) {
76 return err
77 }
78
79 if err == nil {
80 // Remove the destination
81 if err := os.RemoveAll(dst); err != nil {
82 return err
83 }
84 }
85
86 // Create all the parent directories
87 if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil {
88 return err
89 }
90
91 config := g.getAWSConfig(region, u, creds)
92 sess := session.New(config)
93 client := s3.New(sess)
94
95 // List files in path, keep listing until no more objects are found
96 lastMarker := ""
97 hasMore := true
98 for hasMore {
99 req := &s3.ListObjectsInput{
100 Bucket: aws.String(bucket),
101 Prefix: aws.String(path),
102 }
103 if lastMarker != "" {
104 req.Marker = aws.String(lastMarker)
105 }
106
107 resp, err := client.ListObjects(req)
108 if err != nil {
109 return err
110 }
111
112 hasMore = aws.BoolValue(resp.IsTruncated)
113
114 // Get each object storing each file relative to the destination path
115 for _, object := range resp.Contents {
116 lastMarker = aws.StringValue(object.Key)
117 objPath := aws.StringValue(object.Key)
118
119 // If the key ends with a backslash assume it is a directory and ignore
120 if strings.HasSuffix(objPath, "/") {
121 continue
122 }
123
124 // Get the object destination path
125 objDst, err := filepath.Rel(path, objPath)
126 if err != nil {
127 return err
128 }
129 objDst = filepath.Join(dst, objDst)
130
131 if err := g.getObject(ctx, client, objDst, bucket, objPath, ""); err != nil {
132 return err
133 }
134 }
135 }
136
137 return nil
138 }
139
140 func (g *S3Getter) GetFile(dst string, u *url.URL) error {
141 ctx := g.Context()
142 region, bucket, path, version, creds, err := g.parseUrl(u)
143 if err != nil {
144 return err
145 }
146
147 config := g.getAWSConfig(region, u, creds)
148 sess := session.New(config)
149 client := s3.New(sess)
150 return g.getObject(ctx, client, dst, bucket, path, version)
151 }
152
153 func (g *S3Getter) getObject(ctx context.Context, client *s3.S3, dst, bucket, key, version string) error {
154 req := &s3.GetObjectInput{
155 Bucket: aws.String(bucket),
156 Key: aws.String(key),
157 }
158 if version != "" {
159 req.VersionId = aws.String(version)
160 }
161
162 resp, err := client.GetObject(req)
163 if err != nil {
164 return err
165 }
166
167 // Create all the parent directories
168 if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil {
169 return err
170 }
171
172 f, err := os.Create(dst)
173 if err != nil {
174 return err
175 }
176 defer f.Close()
177
178 _, err = Copy(ctx, f, resp.Body)
179 return err
180 }
181
182 func (g *S3Getter) getAWSConfig(region string, url *url.URL, creds *credentials.Credentials) *aws.Config {
183 conf := &aws.Config{}
184 if creds == nil {
185 // Grab the metadata URL
186 metadataURL := os.Getenv("AWS_METADATA_URL")
187 if metadataURL == "" {
188 metadataURL = "http://169.254.169.254:80/latest"
189 }
190
191 creds = credentials.NewChainCredentials(
192 []credentials.Provider{
193 &credentials.EnvProvider{},
194 &credentials.SharedCredentialsProvider{Filename: "", Profile: ""},
195 &ec2rolecreds.EC2RoleProvider{
196 Client: ec2metadata.New(session.New(&aws.Config{
197 Endpoint: aws.String(metadataURL),
198 })),
199 },
200 })
201 }
202
203 if creds != nil {
204 conf.Endpoint = &url.Host
205 conf.S3ForcePathStyle = aws.Bool(true)
206 if url.Scheme == "http" {
207 conf.DisableSSL = aws.Bool(true)
208 }
209 }
210
211 conf.Credentials = creds
212 if region != "" {
213 conf.Region = aws.String(region)
214 }
215
216 return conf
217 }
218
219 func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, creds *credentials.Credentials, err error) {
220 // This just check whether we are dealing with S3 or
221 // any other S3 compliant service. S3 has a predictable
222 // url as others do not
223 if strings.Contains(u.Host, "amazonaws.com") {
224 // Expected host style: s3.amazonaws.com. They always have 3 parts,
225 // although the first may differ if we're accessing a specific region.
226 hostParts := strings.Split(u.Host, ".")
227 if len(hostParts) != 3 {
228 err = fmt.Errorf("URL is not a valid S3 URL")
229 return
230 }
231
232 // Parse the region out of the first part of the host
233 region = strings.TrimPrefix(strings.TrimPrefix(hostParts[0], "s3-"), "s3")
234 if region == "" {
235 region = "us-east-1"
236 }
237
238 pathParts := strings.SplitN(u.Path, "/", 3)
239 if len(pathParts) != 3 {
240 err = fmt.Errorf("URL is not a valid S3 URL")
241 return
242 }
243
244 bucket = pathParts[1]
245 path = pathParts[2]
246 version = u.Query().Get("version")
247
248 } else {
249 pathParts := strings.SplitN(u.Path, "/", 3)
250 if len(pathParts) != 3 {
251 err = fmt.Errorf("URL is not a valid S3 complaint URL")
252 return
253 }
254 bucket = pathParts[1]
255 path = pathParts[2]
256 version = u.Query().Get("version")
257 region = u.Query().Get("region")
258 if region == "" {
259 region = "us-east-1"
260 }
261 }
262
263 _, hasAwsId := u.Query()["aws_access_key_id"]
264 _, hasAwsSecret := u.Query()["aws_access_key_secret"]
265 _, hasAwsToken := u.Query()["aws_access_token"]
266 if hasAwsId || hasAwsSecret || hasAwsToken {
267 creds = credentials.NewStaticCredentials(
268 u.Query().Get("aws_access_key_id"),
269 u.Query().Get("aws_access_key_secret"),
270 u.Query().Get("aws_access_token"),
271 )
272 }
273
274 return
275 }