]>
Commit | Line | Data |
---|---|---|
bae9f6d2 JC |
1 | package getter |
2 | ||
3 | import ( | |
107c1cdb | 4 | "context" |
bae9f6d2 | 5 | "fmt" |
bae9f6d2 JC |
6 | "net/url" |
7 | "os" | |
8 | "path/filepath" | |
9 | "strings" | |
10 | ||
11 | "github.com/aws/aws-sdk-go/aws" | |
12 | "github.com/aws/aws-sdk-go/aws/credentials" | |
13 | "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds" | |
14 | "github.com/aws/aws-sdk-go/aws/ec2metadata" | |
15 | "github.com/aws/aws-sdk-go/aws/session" | |
16 | "github.com/aws/aws-sdk-go/service/s3" | |
17 | ) | |
18 | ||
19 | // S3Getter is a Getter implementation that will download a module from | |
20 | // a S3 bucket. | |
107c1cdb ND |
21 | type S3Getter struct { |
22 | getter | |
23 | } | |
bae9f6d2 JC |
24 | |
25 | func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) { | |
26 | // Parse URL | |
27 | region, bucket, path, _, creds, err := g.parseUrl(u) | |
28 | if err != nil { | |
29 | return 0, err | |
30 | } | |
31 | ||
32 | // Create client config | |
15c0b25d | 33 | config := g.getAWSConfig(region, u, creds) |
bae9f6d2 JC |
34 | sess := session.New(config) |
35 | client := s3.New(sess) | |
36 | ||
37 | // List the object(s) at the given prefix | |
38 | req := &s3.ListObjectsInput{ | |
39 | Bucket: aws.String(bucket), | |
40 | Prefix: aws.String(path), | |
41 | } | |
42 | resp, err := client.ListObjects(req) | |
43 | if err != nil { | |
44 | return 0, err | |
45 | } | |
46 | ||
47 | for _, o := range resp.Contents { | |
48 | // Use file mode on exact match. | |
49 | if *o.Key == path { | |
50 | return ClientModeFile, nil | |
51 | } | |
52 | ||
53 | // Use dir mode if child keys are found. | |
54 | if strings.HasPrefix(*o.Key, path+"/") { | |
55 | return ClientModeDir, nil | |
56 | } | |
57 | } | |
58 | ||
59 | // There was no match, so just return file mode. The download is going | |
60 | // to fail but we will let S3 return the proper error later. | |
61 | return ClientModeFile, nil | |
62 | } | |
63 | ||
64 | func (g *S3Getter) Get(dst string, u *url.URL) error { | |
107c1cdb ND |
65 | ctx := g.Context() |
66 | ||
bae9f6d2 JC |
67 | // Parse URL |
68 | region, bucket, path, _, creds, err := g.parseUrl(u) | |
69 | if err != nil { | |
70 | return err | |
71 | } | |
72 | ||
73 | // Remove destination if it already exists | |
74 | _, err = os.Stat(dst) | |
75 | if err != nil && !os.IsNotExist(err) { | |
76 | return err | |
77 | } | |
78 | ||
79 | if err == nil { | |
80 | // Remove the destination | |
81 | if err := os.RemoveAll(dst); err != nil { | |
82 | return err | |
83 | } | |
84 | } | |
85 | ||
86 | // Create all the parent directories | |
87 | if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { | |
88 | return err | |
89 | } | |
90 | ||
15c0b25d | 91 | config := g.getAWSConfig(region, u, creds) |
bae9f6d2 JC |
92 | sess := session.New(config) |
93 | client := s3.New(sess) | |
94 | ||
95 | // List files in path, keep listing until no more objects are found | |
96 | lastMarker := "" | |
97 | hasMore := true | |
98 | for hasMore { | |
99 | req := &s3.ListObjectsInput{ | |
100 | Bucket: aws.String(bucket), | |
101 | Prefix: aws.String(path), | |
102 | } | |
103 | if lastMarker != "" { | |
104 | req.Marker = aws.String(lastMarker) | |
105 | } | |
106 | ||
107 | resp, err := client.ListObjects(req) | |
108 | if err != nil { | |
109 | return err | |
110 | } | |
111 | ||
112 | hasMore = aws.BoolValue(resp.IsTruncated) | |
113 | ||
114 | // Get each object storing each file relative to the destination path | |
115 | for _, object := range resp.Contents { | |
116 | lastMarker = aws.StringValue(object.Key) | |
117 | objPath := aws.StringValue(object.Key) | |
118 | ||
119 | // If the key ends with a backslash assume it is a directory and ignore | |
120 | if strings.HasSuffix(objPath, "/") { | |
121 | continue | |
122 | } | |
123 | ||
124 | // Get the object destination path | |
125 | objDst, err := filepath.Rel(path, objPath) | |
126 | if err != nil { | |
127 | return err | |
128 | } | |
129 | objDst = filepath.Join(dst, objDst) | |
130 | ||
107c1cdb | 131 | if err := g.getObject(ctx, client, objDst, bucket, objPath, ""); err != nil { |
bae9f6d2 JC |
132 | return err |
133 | } | |
134 | } | |
135 | } | |
136 | ||
137 | return nil | |
138 | } | |
139 | ||
140 | func (g *S3Getter) GetFile(dst string, u *url.URL) error { | |
107c1cdb | 141 | ctx := g.Context() |
bae9f6d2 JC |
142 | region, bucket, path, version, creds, err := g.parseUrl(u) |
143 | if err != nil { | |
144 | return err | |
145 | } | |
146 | ||
15c0b25d | 147 | config := g.getAWSConfig(region, u, creds) |
bae9f6d2 JC |
148 | sess := session.New(config) |
149 | client := s3.New(sess) | |
107c1cdb | 150 | return g.getObject(ctx, client, dst, bucket, path, version) |
bae9f6d2 JC |
151 | } |
152 | ||
107c1cdb | 153 | func (g *S3Getter) getObject(ctx context.Context, client *s3.S3, dst, bucket, key, version string) error { |
bae9f6d2 JC |
154 | req := &s3.GetObjectInput{ |
155 | Bucket: aws.String(bucket), | |
156 | Key: aws.String(key), | |
157 | } | |
158 | if version != "" { | |
159 | req.VersionId = aws.String(version) | |
160 | } | |
161 | ||
162 | resp, err := client.GetObject(req) | |
163 | if err != nil { | |
164 | return err | |
165 | } | |
166 | ||
167 | // Create all the parent directories | |
168 | if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { | |
169 | return err | |
170 | } | |
171 | ||
172 | f, err := os.Create(dst) | |
173 | if err != nil { | |
174 | return err | |
175 | } | |
176 | defer f.Close() | |
177 | ||
107c1cdb | 178 | _, err = Copy(ctx, f, resp.Body) |
bae9f6d2 JC |
179 | return err |
180 | } | |
181 | ||
15c0b25d | 182 | func (g *S3Getter) getAWSConfig(region string, url *url.URL, creds *credentials.Credentials) *aws.Config { |
bae9f6d2 JC |
183 | conf := &aws.Config{} |
184 | if creds == nil { | |
185 | // Grab the metadata URL | |
186 | metadataURL := os.Getenv("AWS_METADATA_URL") | |
187 | if metadataURL == "" { | |
188 | metadataURL = "http://169.254.169.254:80/latest" | |
189 | } | |
190 | ||
191 | creds = credentials.NewChainCredentials( | |
192 | []credentials.Provider{ | |
193 | &credentials.EnvProvider{}, | |
194 | &credentials.SharedCredentialsProvider{Filename: "", Profile: ""}, | |
195 | &ec2rolecreds.EC2RoleProvider{ | |
196 | Client: ec2metadata.New(session.New(&aws.Config{ | |
197 | Endpoint: aws.String(metadataURL), | |
198 | })), | |
199 | }, | |
200 | }) | |
201 | } | |
202 | ||
15c0b25d AP |
203 | if creds != nil { |
204 | conf.Endpoint = &url.Host | |
205 | conf.S3ForcePathStyle = aws.Bool(true) | |
206 | if url.Scheme == "http" { | |
207 | conf.DisableSSL = aws.Bool(true) | |
208 | } | |
209 | } | |
210 | ||
bae9f6d2 JC |
211 | conf.Credentials = creds |
212 | if region != "" { | |
213 | conf.Region = aws.String(region) | |
214 | } | |
215 | ||
216 | return conf | |
217 | } | |
218 | ||
219 | func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, creds *credentials.Credentials, err error) { | |
15c0b25d AP |
220 | // This just check whether we are dealing with S3 or |
221 | // any other S3 compliant service. S3 has a predictable | |
222 | // url as others do not | |
223 | if strings.Contains(u.Host, "amazonaws.com") { | |
224 | // Expected host style: s3.amazonaws.com. They always have 3 parts, | |
225 | // although the first may differ if we're accessing a specific region. | |
226 | hostParts := strings.Split(u.Host, ".") | |
227 | if len(hostParts) != 3 { | |
228 | err = fmt.Errorf("URL is not a valid S3 URL") | |
229 | return | |
230 | } | |
bae9f6d2 | 231 | |
15c0b25d AP |
232 | // Parse the region out of the first part of the host |
233 | region = strings.TrimPrefix(strings.TrimPrefix(hostParts[0], "s3-"), "s3") | |
234 | if region == "" { | |
235 | region = "us-east-1" | |
236 | } | |
bae9f6d2 | 237 | |
15c0b25d AP |
238 | pathParts := strings.SplitN(u.Path, "/", 3) |
239 | if len(pathParts) != 3 { | |
240 | err = fmt.Errorf("URL is not a valid S3 URL") | |
241 | return | |
242 | } | |
243 | ||
244 | bucket = pathParts[1] | |
245 | path = pathParts[2] | |
246 | version = u.Query().Get("version") | |
bae9f6d2 | 247 | |
15c0b25d AP |
248 | } else { |
249 | pathParts := strings.SplitN(u.Path, "/", 3) | |
250 | if len(pathParts) != 3 { | |
251 | err = fmt.Errorf("URL is not a valid S3 complaint URL") | |
252 | return | |
253 | } | |
254 | bucket = pathParts[1] | |
255 | path = pathParts[2] | |
256 | version = u.Query().Get("version") | |
257 | region = u.Query().Get("region") | |
258 | if region == "" { | |
259 | region = "us-east-1" | |
260 | } | |
261 | } | |
bae9f6d2 JC |
262 | |
263 | _, hasAwsId := u.Query()["aws_access_key_id"] | |
264 | _, hasAwsSecret := u.Query()["aws_access_key_secret"] | |
265 | _, hasAwsToken := u.Query()["aws_access_token"] | |
266 | if hasAwsId || hasAwsSecret || hasAwsToken { | |
267 | creds = credentials.NewStaticCredentials( | |
268 | u.Query().Get("aws_access_key_id"), | |
269 | u.Query().Get("aws_access_key_secret"), | |
270 | u.Query().Get("aws_access_token"), | |
271 | ) | |
272 | } | |
273 | ||
274 | return | |
275 | } |