]> git.immae.eu Git - github/fretlink/terraform-provider-statuscake.git/blobdiff - vendor/github.com/hashicorp/go-getter/get_s3.go
Upgrade to 0.12
[github/fretlink/terraform-provider-statuscake.git] / vendor / github.com / hashicorp / go-getter / get_s3.go
index d3bffeb1737527efe813abbe08ed6811c631496e..93eeb0b817f6497d57b4bab097df88a500316b52 100644 (file)
@@ -1,8 +1,8 @@
 package getter
 
 import (
+       "context"
        "fmt"
-       "io"
        "net/url"
        "os"
        "path/filepath"
@@ -18,7 +18,9 @@ import (
 
 // S3Getter is a Getter implementation that will download a module from
 // a S3 bucket.
-type S3Getter struct{}
+type S3Getter struct {
+       getter
+}
 
 func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) {
        // Parse URL
@@ -28,7 +30,7 @@ func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) {
        }
 
        // Create client config
-       config := g.getAWSConfig(region, creds)
+       config := g.getAWSConfig(region, u, creds)
        sess := session.New(config)
        client := s3.New(sess)
 
@@ -60,6 +62,8 @@ func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) {
 }
 
 func (g *S3Getter) Get(dst string, u *url.URL) error {
+       ctx := g.Context()
+
        // Parse URL
        region, bucket, path, _, creds, err := g.parseUrl(u)
        if err != nil {
@@ -84,7 +88,7 @@ func (g *S3Getter) Get(dst string, u *url.URL) error {
                return err
        }
 
-       config := g.getAWSConfig(region, creds)
+       config := g.getAWSConfig(region, u, creds)
        sess := session.New(config)
        client := s3.New(sess)
 
@@ -124,7 +128,7 @@ func (g *S3Getter) Get(dst string, u *url.URL) error {
                        }
                        objDst = filepath.Join(dst, objDst)
 
-                       if err := g.getObject(client, objDst, bucket, objPath, ""); err != nil {
+                       if err := g.getObject(ctx, client, objDst, bucket, objPath, ""); err != nil {
                                return err
                        }
                }
@@ -134,18 +138,19 @@ func (g *S3Getter) Get(dst string, u *url.URL) error {
 }
 
 func (g *S3Getter) GetFile(dst string, u *url.URL) error {
+       ctx := g.Context()
        region, bucket, path, version, creds, err := g.parseUrl(u)
        if err != nil {
                return err
        }
 
-       config := g.getAWSConfig(region, creds)
+       config := g.getAWSConfig(region, u, creds)
        sess := session.New(config)
        client := s3.New(sess)
-       return g.getObject(client, dst, bucket, path, version)
+       return g.getObject(ctx, client, dst, bucket, path, version)
 }
 
-func (g *S3Getter) getObject(client *s3.S3, dst, bucket, key, version string) error {
+func (g *S3Getter) getObject(ctx context.Context, client *s3.S3, dst, bucket, key, version string) error {
        req := &s3.GetObjectInput{
                Bucket: aws.String(bucket),
                Key:    aws.String(key),
@@ -170,11 +175,11 @@ func (g *S3Getter) getObject(client *s3.S3, dst, bucket, key, version string) er
        }
        defer f.Close()
 
-       _, err = io.Copy(f, resp.Body)
+       _, err = Copy(ctx, f, resp.Body)
        return err
 }
 
-func (g *S3Getter) getAWSConfig(region string, creds *credentials.Credentials) *aws.Config {
+func (g *S3Getter) getAWSConfig(region string, url *url.URL, creds *credentials.Credentials) *aws.Config {
        conf := &aws.Config{}
        if creds == nil {
                // Grab the metadata URL
@@ -195,6 +200,14 @@ func (g *S3Getter) getAWSConfig(region string, creds *credentials.Credentials) *
                        })
        }
 
+       if creds != nil {
+               conf.Endpoint = &url.Host
+               conf.S3ForcePathStyle = aws.Bool(true)
+               if url.Scheme == "http" {
+                       conf.DisableSSL = aws.Bool(true)
+               }
+       }
+
        conf.Credentials = creds
        if region != "" {
                conf.Region = aws.String(region)
@@ -204,29 +217,48 @@ func (g *S3Getter) getAWSConfig(region string, creds *credentials.Credentials) *
 }
 
 func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, creds *credentials.Credentials, err error) {
-       // Expected host style: s3.amazonaws.com. They always have 3 parts,
-       // although the first may differ if we're accessing a specific region.
-       hostParts := strings.Split(u.Host, ".")
-       if len(hostParts) != 3 {
-               err = fmt.Errorf("URL is not a valid S3 URL")
-               return
-       }
+       // This just check whether we are dealing with S3 or
+       // any other S3 compliant service. S3 has a predictable
+       // url as others do not
+       if strings.Contains(u.Host, "amazonaws.com") {
+               // Expected host style: s3.amazonaws.com. They always have 3 parts,
+               // although the first may differ if we're accessing a specific region.
+               hostParts := strings.Split(u.Host, ".")
+               if len(hostParts) != 3 {
+                       err = fmt.Errorf("URL is not a valid S3 URL")
+                       return
+               }
 
-       // Parse the region out of the first part of the host
-       region = strings.TrimPrefix(strings.TrimPrefix(hostParts[0], "s3-"), "s3")
-       if region == "" {
-               region = "us-east-1"
-       }
+               // Parse the region out of the first part of the host
+               region = strings.TrimPrefix(strings.TrimPrefix(hostParts[0], "s3-"), "s3")
+               if region == "" {
+                       region = "us-east-1"
+               }
 
-       pathParts := strings.SplitN(u.Path, "/", 3)
-       if len(pathParts) != 3 {
-               err = fmt.Errorf("URL is not a valid S3 URL")
-               return
-       }
+               pathParts := strings.SplitN(u.Path, "/", 3)
+               if len(pathParts) != 3 {
+                       err = fmt.Errorf("URL is not a valid S3 URL")
+                       return
+               }
+
+               bucket = pathParts[1]
+               path = pathParts[2]
+               version = u.Query().Get("version")
 
-       bucket = pathParts[1]
-       path = pathParts[2]
-       version = u.Query().Get("version")
+       } else {
+               pathParts := strings.SplitN(u.Path, "/", 3)
+               if len(pathParts) != 3 {
+                       err = fmt.Errorf("URL is not a valid S3 complaint URL")
+                       return
+               }
+               bucket = pathParts[1]
+               path = pathParts[2]
+               version = u.Query().Get("version")
+               region = u.Query().Get("region")
+               if region == "" {
+                       region = "us-east-1"
+               }
+       }
 
        _, hasAwsId := u.Query()["aws_access_key_id"]
        _, hasAwsSecret := u.Query()["aws_access_key_secret"]