package getter
import (
- "bytes"
- "crypto/md5"
- "crypto/sha1"
- "crypto/sha256"
- "crypto/sha512"
- "encoding/hex"
+ "context"
"fmt"
- "hash"
- "io"
"io/ioutil"
"os"
"path/filepath"
"strings"
urlhelper "github.com/hashicorp/go-getter/helper/url"
- "github.com/hashicorp/go-safetemp"
+ safetemp "github.com/hashicorp/go-safetemp"
)
// Client is a client for downloading things.
// Using a client directly allows more fine-grained control over how downloading
// is done, as well as customizing the protocols supported.
type Client struct {
+ // Ctx for cancellation
+ Ctx context.Context
+
// Src is the source URL to get.
//
// Dst is the path to save the downloaded thing as. If Dir is set to
//
// WARNING: deprecated. If Mode is set, that will take precedence.
Dir bool
+
+ // ProgressListener allows to track file downloads.
+ // By default a no op progress listener is used.
+ ProgressListener ProgressTracker
+
+ Options []ClientOption
}
// Get downloads the configured source to the destination.
func (c *Client) Get() error {
+ if err := c.Configure(c.Options...); err != nil {
+ return err
+ }
+
// Store this locally since there are cases we swap this
mode := c.Mode
if mode == ClientModeInvalid {
}
}
- // Default decompressor value
- decompressors := c.Decompressors
- if decompressors == nil {
- decompressors = Decompressors
- }
-
- // Detect the URL. This is safe if it is already detected.
- detectors := c.Detectors
- if detectors == nil {
- detectors = Detectors
- }
- src, err := Detect(c.Src, c.Pwd, detectors)
+ src, err := Detect(c.Src, c.Pwd, c.Detectors)
if err != nil {
return err
}
force = u.Scheme
}
- getters := c.Getters
- if getters == nil {
- getters = Getters
- }
-
- g, ok := getters[force]
+ g, ok := c.Getters[force]
if !ok {
return fmt.Errorf(
"download not supported for scheme '%s'", force)
if archiveV == "" {
// We don't appear to... but is it part of the filename?
matchingLen := 0
- for k, _ := range decompressors {
+ for k := range c.Decompressors {
if strings.HasSuffix(u.Path, "."+k) && len(k) > matchingLen {
archiveV = k
matchingLen = len(k)
// real path.
var decompressDst string
var decompressDir bool
- decompressor := decompressors[archiveV]
+ decompressor := c.Decompressors[archiveV]
if decompressor != nil {
// Create a temporary directory to store our archive. We delete
// this at the end of everything.
mode = ClientModeFile
}
- // Determine if we have a checksum
- var checksumHash hash.Hash
- var checksumValue []byte
- if v := q.Get("checksum"); v != "" {
- // Delete the query parameter if we have it.
- q.Del("checksum")
- u.RawQuery = q.Encode()
-
- // Determine the checksum hash type
- checksumType := ""
- idx := strings.Index(v, ":")
- if idx > -1 {
- checksumType = v[:idx]
- }
- switch checksumType {
- case "md5":
- checksumHash = md5.New()
- case "sha1":
- checksumHash = sha1.New()
- case "sha256":
- checksumHash = sha256.New()
- case "sha512":
- checksumHash = sha512.New()
- default:
- return fmt.Errorf(
- "unsupported checksum type: %s", checksumType)
- }
-
- // Get the remainder of the value and parse it into bytes
- b, err := hex.DecodeString(v[idx+1:])
- if err != nil {
- return fmt.Errorf("invalid checksum: %s", err)
- }
-
- // Set our value
- checksumValue = b
+ // Determine checksum if we have one
+ checksum, err := c.extractChecksum(u)
+ if err != nil {
+ return fmt.Errorf("invalid checksum: %s", err)
}
+ // Delete the query parameter if we have it.
+ q.Del("checksum")
+ u.RawQuery = q.Encode()
+
if mode == ClientModeAny {
// Ask the getter which client mode to use
mode, err = g.ClientMode(u)
// If we're not downloading a directory, then just download the file
// and return.
if mode == ClientModeFile {
- err := g.GetFile(dst, u)
- if err != nil {
- return err
+ getFile := true
+ if checksum != nil {
+ if err := checksum.checksum(dst); err == nil {
+ // don't get the file if the checksum of dst is correct
+ getFile = false
+ }
}
-
- if checksumHash != nil {
- if err := checksum(dst, checksumHash, checksumValue); err != nil {
+ if getFile {
+ err := g.GetFile(dst, u)
+ if err != nil {
return err
}
+
+ if checksum != nil {
+ if err := checksum.checksum(dst); err != nil {
+ return err
+ }
+ }
}
if decompressor != nil {
if decompressor == nil {
// If we're getting a directory, then this is an error. You cannot
// checksum a directory. TODO: test
- if checksumHash != nil {
+ if checksum != nil {
return fmt.Errorf(
"checksum cannot be specified for directory download")
}
return err
}
- return copyDir(realDst, subDir, false)
- }
-
- return nil
-}
-
-// checksum is a simple method to compute the checksum of a source file
-// and compare it to the given expected value.
-func checksum(source string, h hash.Hash, v []byte) error {
- f, err := os.Open(source)
- if err != nil {
- return fmt.Errorf("Failed to open file for checksum: %s", err)
- }
- defer f.Close()
-
- if _, err := io.Copy(h, f); err != nil {
- return fmt.Errorf("Failed to hash: %s", err)
- }
-
- if actual := h.Sum(nil); !bytes.Equal(actual, v) {
- return fmt.Errorf(
- "Checksums did not match.\nExpected: %s\nGot: %s",
- hex.EncodeToString(v),
- hex.EncodeToString(actual))
+ return copyDir(c.Ctx, realDst, subDir, false)
}
return nil