X-Git-Url: https://git.immae.eu/?a=blobdiff_plain;f=vendor%2Fgithub.com%2Fhashicorp%2Fgo-getter%2Fget_s3.go;h=93eeb0b817f6497d57b4bab097df88a500316b52;hb=107c1cdb09c575aa2f61d97f48d8587eb6bada4c;hp=d3bffeb1737527efe813abbe08ed6811c631496e;hpb=cec3de8a3bcaffd21dedd1bf42da4b490cae7e16;p=github%2Ffretlink%2Fterraform-provider-statuscake.git diff --git a/vendor/github.com/hashicorp/go-getter/get_s3.go b/vendor/github.com/hashicorp/go-getter/get_s3.go index d3bffeb..93eeb0b 100644 --- a/vendor/github.com/hashicorp/go-getter/get_s3.go +++ b/vendor/github.com/hashicorp/go-getter/get_s3.go @@ -1,8 +1,8 @@ package getter import ( + "context" "fmt" - "io" "net/url" "os" "path/filepath" @@ -18,7 +18,9 @@ import ( // S3Getter is a Getter implementation that will download a module from // a S3 bucket. -type S3Getter struct{} +type S3Getter struct { + getter +} func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) { // Parse URL @@ -28,7 +30,7 @@ func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) { } // Create client config - config := g.getAWSConfig(region, creds) + config := g.getAWSConfig(region, u, creds) sess := session.New(config) client := s3.New(sess) @@ -60,6 +62,8 @@ func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) { } func (g *S3Getter) Get(dst string, u *url.URL) error { + ctx := g.Context() + // Parse URL region, bucket, path, _, creds, err := g.parseUrl(u) if err != nil { @@ -84,7 +88,7 @@ func (g *S3Getter) Get(dst string, u *url.URL) error { return err } - config := g.getAWSConfig(region, creds) + config := g.getAWSConfig(region, u, creds) sess := session.New(config) client := s3.New(sess) @@ -124,7 +128,7 @@ func (g *S3Getter) Get(dst string, u *url.URL) error { } objDst = filepath.Join(dst, objDst) - if err := g.getObject(client, objDst, bucket, objPath, ""); err != nil { + if err := g.getObject(ctx, client, objDst, bucket, objPath, ""); err != nil { return err } } @@ -134,18 +138,19 @@ func (g *S3Getter) Get(dst string, u *url.URL) error { } func (g *S3Getter) GetFile(dst string, u *url.URL) error { + ctx := g.Context() region, bucket, path, version, creds, err := g.parseUrl(u) if err != nil { return err } - config := g.getAWSConfig(region, creds) + config := g.getAWSConfig(region, u, creds) sess := session.New(config) client := s3.New(sess) - return g.getObject(client, dst, bucket, path, version) + return g.getObject(ctx, client, dst, bucket, path, version) } -func (g *S3Getter) getObject(client *s3.S3, dst, bucket, key, version string) error { +func (g *S3Getter) getObject(ctx context.Context, client *s3.S3, dst, bucket, key, version string) error { req := &s3.GetObjectInput{ Bucket: aws.String(bucket), Key: aws.String(key), @@ -170,11 +175,11 @@ func (g *S3Getter) getObject(client *s3.S3, dst, bucket, key, version string) er } defer f.Close() - _, err = io.Copy(f, resp.Body) + _, err = Copy(ctx, f, resp.Body) return err } -func (g *S3Getter) getAWSConfig(region string, creds *credentials.Credentials) *aws.Config { +func (g *S3Getter) getAWSConfig(region string, url *url.URL, creds *credentials.Credentials) *aws.Config { conf := &aws.Config{} if creds == nil { // Grab the metadata URL @@ -195,6 +200,14 @@ func (g *S3Getter) getAWSConfig(region string, creds *credentials.Credentials) * }) } + if creds != nil { + conf.Endpoint = &url.Host + conf.S3ForcePathStyle = aws.Bool(true) + if url.Scheme == "http" { + conf.DisableSSL = aws.Bool(true) + } + } + conf.Credentials = creds if region != "" { conf.Region = aws.String(region) @@ -204,29 +217,48 @@ func (g *S3Getter) getAWSConfig(region string, creds *credentials.Credentials) * } func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, creds *credentials.Credentials, err error) { - // Expected host style: s3.amazonaws.com. They always have 3 parts, - // although the first may differ if we're accessing a specific region. - hostParts := strings.Split(u.Host, ".") - if len(hostParts) != 3 { - err = fmt.Errorf("URL is not a valid S3 URL") - return - } + // This just check whether we are dealing with S3 or + // any other S3 compliant service. S3 has a predictable + // url as others do not + if strings.Contains(u.Host, "amazonaws.com") { + // Expected host style: s3.amazonaws.com. They always have 3 parts, + // although the first may differ if we're accessing a specific region. + hostParts := strings.Split(u.Host, ".") + if len(hostParts) != 3 { + err = fmt.Errorf("URL is not a valid S3 URL") + return + } - // Parse the region out of the first part of the host - region = strings.TrimPrefix(strings.TrimPrefix(hostParts[0], "s3-"), "s3") - if region == "" { - region = "us-east-1" - } + // Parse the region out of the first part of the host + region = strings.TrimPrefix(strings.TrimPrefix(hostParts[0], "s3-"), "s3") + if region == "" { + region = "us-east-1" + } - pathParts := strings.SplitN(u.Path, "/", 3) - if len(pathParts) != 3 { - err = fmt.Errorf("URL is not a valid S3 URL") - return - } + pathParts := strings.SplitN(u.Path, "/", 3) + if len(pathParts) != 3 { + err = fmt.Errorf("URL is not a valid S3 URL") + return + } + + bucket = pathParts[1] + path = pathParts[2] + version = u.Query().Get("version") - bucket = pathParts[1] - path = pathParts[2] - version = u.Query().Get("version") + } else { + pathParts := strings.SplitN(u.Path, "/", 3) + if len(pathParts) != 3 { + err = fmt.Errorf("URL is not a valid S3 complaint URL") + return + } + bucket = pathParts[1] + path = pathParts[2] + version = u.Query().Get("version") + region = u.Query().Get("region") + if region == "" { + region = "us-east-1" + } + } _, hasAwsId := u.Query()["aws_access_key_id"] _, hasAwsSecret := u.Query()["aws_access_key_secret"]