package getter
import (
+ "context"
"fmt"
- "io"
"net/url"
"os"
"path/filepath"
// S3Getter is a Getter implementation that will download a module from
// a S3 bucket.
-type S3Getter struct{}
+type S3Getter struct {
+ getter
+}
func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) {
// Parse URL
}
// Create client config
- config := g.getAWSConfig(region, creds)
+ config := g.getAWSConfig(region, u, creds)
sess := session.New(config)
client := s3.New(sess)
}
func (g *S3Getter) Get(dst string, u *url.URL) error {
+ ctx := g.Context()
+
// Parse URL
region, bucket, path, _, creds, err := g.parseUrl(u)
if err != nil {
return err
}
- config := g.getAWSConfig(region, creds)
+ config := g.getAWSConfig(region, u, creds)
sess := session.New(config)
client := s3.New(sess)
}
objDst = filepath.Join(dst, objDst)
- if err := g.getObject(client, objDst, bucket, objPath, ""); err != nil {
+ if err := g.getObject(ctx, client, objDst, bucket, objPath, ""); err != nil {
return err
}
}
}
func (g *S3Getter) GetFile(dst string, u *url.URL) error {
+ ctx := g.Context()
region, bucket, path, version, creds, err := g.parseUrl(u)
if err != nil {
return err
}
- config := g.getAWSConfig(region, creds)
+ config := g.getAWSConfig(region, u, creds)
sess := session.New(config)
client := s3.New(sess)
- return g.getObject(client, dst, bucket, path, version)
+ return g.getObject(ctx, client, dst, bucket, path, version)
}
-func (g *S3Getter) getObject(client *s3.S3, dst, bucket, key, version string) error {
+func (g *S3Getter) getObject(ctx context.Context, client *s3.S3, dst, bucket, key, version string) error {
req := &s3.GetObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
}
defer f.Close()
- _, err = io.Copy(f, resp.Body)
+ _, err = Copy(ctx, f, resp.Body)
return err
}
-func (g *S3Getter) getAWSConfig(region string, creds *credentials.Credentials) *aws.Config {
+func (g *S3Getter) getAWSConfig(region string, url *url.URL, creds *credentials.Credentials) *aws.Config {
conf := &aws.Config{}
if creds == nil {
// Grab the metadata URL
})
}
+ if creds != nil {
+ conf.Endpoint = &url.Host
+ conf.S3ForcePathStyle = aws.Bool(true)
+ if url.Scheme == "http" {
+ conf.DisableSSL = aws.Bool(true)
+ }
+ }
+
conf.Credentials = creds
if region != "" {
conf.Region = aws.String(region)
}
func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, creds *credentials.Credentials, err error) {
- // Expected host style: s3.amazonaws.com. They always have 3 parts,
- // although the first may differ if we're accessing a specific region.
- hostParts := strings.Split(u.Host, ".")
- if len(hostParts) != 3 {
- err = fmt.Errorf("URL is not a valid S3 URL")
- return
- }
+ // This just check whether we are dealing with S3 or
+ // any other S3 compliant service. S3 has a predictable
+ // url as others do not
+ if strings.Contains(u.Host, "amazonaws.com") {
+ // Expected host style: s3.amazonaws.com. They always have 3 parts,
+ // although the first may differ if we're accessing a specific region.
+ hostParts := strings.Split(u.Host, ".")
+ if len(hostParts) != 3 {
+ err = fmt.Errorf("URL is not a valid S3 URL")
+ return
+ }
- // Parse the region out of the first part of the host
- region = strings.TrimPrefix(strings.TrimPrefix(hostParts[0], "s3-"), "s3")
- if region == "" {
- region = "us-east-1"
- }
+ // Parse the region out of the first part of the host
+ region = strings.TrimPrefix(strings.TrimPrefix(hostParts[0], "s3-"), "s3")
+ if region == "" {
+ region = "us-east-1"
+ }
- pathParts := strings.SplitN(u.Path, "/", 3)
- if len(pathParts) != 3 {
- err = fmt.Errorf("URL is not a valid S3 URL")
- return
- }
+ pathParts := strings.SplitN(u.Path, "/", 3)
+ if len(pathParts) != 3 {
+ err = fmt.Errorf("URL is not a valid S3 URL")
+ return
+ }
+
+ bucket = pathParts[1]
+ path = pathParts[2]
+ version = u.Query().Get("version")
- bucket = pathParts[1]
- path = pathParts[2]
- version = u.Query().Get("version")
+ } else {
+ pathParts := strings.SplitN(u.Path, "/", 3)
+ if len(pathParts) != 3 {
+ err = fmt.Errorf("URL is not a valid S3 complaint URL")
+ return
+ }
+ bucket = pathParts[1]
+ path = pathParts[2]
+ version = u.Query().Get("version")
+ region = u.Query().Get("region")
+ if region == "" {
+ region = "us-east-1"
+ }
+ }
_, hasAwsId := u.Query()["aws_access_key_id"]
_, hasAwsSecret := u.Query()["aws_access_key_secret"]