]> git.immae.eu Git - github/fretlink/terraform-provider-statuscake.git/blobdiff - vendor/github.com/hashicorp/go-plugin/client.go
Upgrade to 0.12
[github/fretlink/terraform-provider-statuscake.git] / vendor / github.com / hashicorp / go-plugin / client.go
index b3e3b78eab7a6ef80eb954c6e094bab92477761c..679e10ad7591468985f3866ef6f9c93188848edd 100644 (file)
@@ -5,12 +5,13 @@ import (
        "context"
        "crypto/subtle"
        "crypto/tls"
+       "crypto/x509"
+       "encoding/base64"
        "errors"
        "fmt"
        "hash"
        "io"
        "io/ioutil"
-       "log"
        "net"
        "os"
        "os/exec"
@@ -20,7 +21,6 @@ import (
        "sync"
        "sync/atomic"
        "time"
-       "unicode"
 
        hclog "github.com/hashicorp/go-hclog"
 )
@@ -71,16 +71,31 @@ var (
 //
 // See NewClient and ClientConfig for using a Client.
 type Client struct {
-       config      *ClientConfig
-       exited      bool
-       doneLogging chan struct{}
-       l           sync.Mutex
-       address     net.Addr
-       process     *os.Process
-       client      ClientProtocol
-       protocol    Protocol
-       logger      hclog.Logger
-       doneCtx     context.Context
+       config            *ClientConfig
+       exited            bool
+       l                 sync.Mutex
+       address           net.Addr
+       process           *os.Process
+       client            ClientProtocol
+       protocol          Protocol
+       logger            hclog.Logger
+       doneCtx           context.Context
+       ctxCancel         context.CancelFunc
+       negotiatedVersion int
+
+       // clientWaitGroup is used to manage the lifecycle of the plugin management
+       // goroutines.
+       clientWaitGroup sync.WaitGroup
+
+       // processKilled is used for testing only, to flag when the process was
+       // forcefully killed.
+       processKilled bool
+}
+
+// NegotiatedVersion returns the protocol version negotiated with the server.
+// This is only valid after Start() is called.
+func (c *Client) NegotiatedVersion() int {
+       return c.negotiatedVersion
 }
 
 // ClientConfig is the configuration used to initialize a new
@@ -91,7 +106,13 @@ type ClientConfig struct {
        HandshakeConfig
 
        // Plugins are the plugins that can be consumed.
-       Plugins map[string]Plugin
+       // The implied version of this PluginSet is the Handshake.ProtocolVersion.
+       Plugins PluginSet
+
+       // VersionedPlugins is a map of PluginSets for specific protocol versions.
+       // These can be used to negotiate a compatible version between client and
+       // server. If this is set, Handshake.ProtocolVersion is not required.
+       VersionedPlugins map[int]PluginSet
 
        // One of the following must be set, but not both.
        //
@@ -158,6 +179,29 @@ type ClientConfig struct {
        // Logger is the logger that the client will used. If none is provided,
        // it will default to hclog's default logger.
        Logger hclog.Logger
+
+       // AutoMTLS has the client and server automatically negotiate mTLS for
+       // transport authentication. This ensures that only the original client will
+       // be allowed to connect to the server, and all other connections will be
+       // rejected. The client will also refuse to connect to any server that isn't
+       // the original instance started by the client.
+       //
+       // In this mode of operation, the client generates a one-time use tls
+       // certificate, sends the public x.509 certificate to the new server, and
+       // the server generates a one-time use tls certificate, and sends the public
+       // x.509 certificate back to the client. These are used to authenticate all
+       // rpc connections between the client and server.
+       //
+       // Setting AutoMTLS to true implies that the server must support the
+       // protocol, and correctly negotiate the tls certificates, or a connection
+       // failure will result.
+       //
+       // The client should not set TLSConfig, nor should the server set a
+       // TLSProvider, because AutoMTLS implies that a new certificate and tls
+       // configuration will be generated at startup.
+       //
+       // You cannot Reattach to a server with this option enabled.
+       AutoMTLS bool
 }
 
 // ReattachConfig is used to configure a client to reattach to an
@@ -234,7 +278,6 @@ func CleanupClients() {
        }
        managedClientsLock.Unlock()
 
-       log.Println("[DEBUG] plugin: waiting for all plugin processes to complete...")
        wg.Wait()
 }
 
@@ -333,6 +376,14 @@ func (c *Client) Exited() bool {
        return c.exited
 }
 
+// killed is used in tests to check if a process failed to exit gracefully, and
+// needed to be killed.
+func (c *Client) killed() bool {
+       c.l.Lock()
+       defer c.l.Unlock()
+       return c.processKilled
+}
+
 // End the executing subprocess (if it is running) and perform any cleanup
 // tasks necessary such as capturing any remaining logs and so on.
 //
@@ -344,14 +395,24 @@ func (c *Client) Kill() {
        c.l.Lock()
        process := c.process
        addr := c.address
-       doneCh := c.doneLogging
        c.l.Unlock()
 
-       // If there is no process, we never started anything. Nothing to kill.
+       // If there is no process, there is nothing to kill.
        if process == nil {
                return
        }
 
+       defer func() {
+               // Wait for the all client goroutines to finish.
+               c.clientWaitGroup.Wait()
+
+               // Make sure there is no reference to the old process after it has been
+               // killed.
+               c.l.Lock()
+               c.process = nil
+               c.l.Unlock()
+       }()
+
        // We need to check for address here. It is possible that the plugin
        // started (process != nil) but has no address (addr == nil) if the
        // plugin failed at startup. If we do have an address, we need to close
@@ -372,6 +433,8 @@ func (c *Client) Kill() {
                                // kill in a moment anyways.
                                c.logger.Warn("error closing client during Kill", "err", err)
                        }
+               } else {
+                       c.logger.Error("client", "error", err)
                }
        }
 
@@ -380,17 +443,20 @@ func (c *Client) Kill() {
        // doneCh which would be closed if the process exits.
        if graceful {
                select {
-               case <-doneCh:
+               case <-c.doneCtx.Done():
+                       c.logger.Debug("plugin exited")
                        return
-               case <-time.After(250 * time.Millisecond):
+               case <-time.After(2 * time.Second):
                }
        }
 
        // If graceful exiting failed, just kill it
+       c.logger.Warn("plugin failed to exit gracefully")
        process.Kill()
 
-       // Wait for the client to finish logging so we have a complete log
-       <-doneCh
+       c.l.Lock()
+       c.processKilled = true
+       c.l.Unlock()
 }
 
 // Starts the underlying subprocess, communicating with it to negotiate
@@ -409,7 +475,7 @@ func (c *Client) Start() (addr net.Addr, err error) {
 
        // If one of cmd or reattach isn't set, then it is an error. We wrap
        // this in a {} for scoping reasons, and hopeful that the escape
-       // analysis will pop the stock here.
+       // analysis will pop the stack here.
        {
                cmdSet := c.config.Cmd != nil
                attachSet := c.config.Reattach != nil
@@ -423,77 +489,49 @@ func (c *Client) Start() (addr net.Addr, err error) {
                }
        }
 
-       // Create the logging channel for when we kill
-       c.doneLogging = make(chan struct{})
-       // Create a context for when we kill
-       var ctxCancel context.CancelFunc
-       c.doneCtx, ctxCancel = context.WithCancel(context.Background())
-
        if c.config.Reattach != nil {
-               // Verify the process still exists. If not, then it is an error
-               p, err := os.FindProcess(c.config.Reattach.Pid)
-               if err != nil {
-                       return nil, err
-               }
+               return c.reattach()
+       }
 
-               // Attempt to connect to the addr since on Unix systems FindProcess
-               // doesn't actually return an error if it can't find the process.
-               conn, err := net.Dial(
-                       c.config.Reattach.Addr.Network(),
-                       c.config.Reattach.Addr.String())
-               if err != nil {
-                       p.Kill()
-                       return nil, ErrProcessNotFound
-               }
-               conn.Close()
-
-               // Goroutine to mark exit status
-               go func(pid int) {
-                       // Wait for the process to die
-                       pidWait(pid)
-
-                       // Log so we can see it
-                       c.logger.Debug("reattached plugin process exited")
-
-                       // Mark it
-                       c.l.Lock()
-                       defer c.l.Unlock()
-                       c.exited = true
-
-                       // Close the logging channel since that doesn't work on reattach
-                       close(c.doneLogging)
-
-                       // Cancel the context
-                       ctxCancel()
-               }(p.Pid)
-
-               // Set the address and process
-               c.address = c.config.Reattach.Addr
-               c.process = p
-               c.protocol = c.config.Reattach.Protocol
-               if c.protocol == "" {
-                       // Default the protocol to net/rpc for backwards compatibility
-                       c.protocol = ProtocolNetRPC
-               }
+       if c.config.VersionedPlugins == nil {
+               c.config.VersionedPlugins = make(map[int]PluginSet)
+       }
 
-               return c.address, nil
+       // handle all plugins as versioned, using the handshake config as the default.
+       version := int(c.config.ProtocolVersion)
+
+       // Make sure we're not overwriting a real version 0. If ProtocolVersion was
+       // non-zero, then we have to just assume the user made sure that
+       // VersionedPlugins doesn't conflict.
+       if _, ok := c.config.VersionedPlugins[version]; !ok && c.config.Plugins != nil {
+               c.config.VersionedPlugins[version] = c.config.Plugins
+       }
+
+       var versionStrings []string
+       for v := range c.config.VersionedPlugins {
+               versionStrings = append(versionStrings, strconv.Itoa(v))
        }
 
        env := []string{
                fmt.Sprintf("%s=%s", c.config.MagicCookieKey, c.config.MagicCookieValue),
                fmt.Sprintf("PLUGIN_MIN_PORT=%d", c.config.MinPort),
                fmt.Sprintf("PLUGIN_MAX_PORT=%d", c.config.MaxPort),
+               fmt.Sprintf("PLUGIN_PROTOCOL_VERSIONS=%s", strings.Join(versionStrings, ",")),
        }
 
-       stdout_r, stdout_w := io.Pipe()
-       stderr_r, stderr_w := io.Pipe()
-
        cmd := c.config.Cmd
        cmd.Env = append(cmd.Env, os.Environ()...)
        cmd.Env = append(cmd.Env, env...)
        cmd.Stdin = os.Stdin
-       cmd.Stderr = stderr_w
-       cmd.Stdout = stdout_w
+
+       cmdStdout, err := cmd.StdoutPipe()
+       if err != nil {
+               return nil, err
+       }
+       cmdStderr, err := cmd.StderrPipe()
+       if err != nil {
+               return nil, err
+       }
 
        if c.config.SecureConfig != nil {
                if ok, err := c.config.SecureConfig.Check(cmd.Path); err != nil {
@@ -503,6 +541,29 @@ func (c *Client) Start() (addr net.Addr, err error) {
                }
        }
 
+       // Setup a temporary certificate for client/server mtls, and send the public
+       // certificate to the plugin.
+       if c.config.AutoMTLS {
+               c.logger.Info("configuring client automatic mTLS")
+               certPEM, keyPEM, err := generateCert()
+               if err != nil {
+                       c.logger.Error("failed to generate client certificate", "error", err)
+                       return nil, err
+               }
+               cert, err := tls.X509KeyPair(certPEM, keyPEM)
+               if err != nil {
+                       c.logger.Error("failed to parse client certificate", "error", err)
+                       return nil, err
+               }
+
+               cmd.Env = append(cmd.Env, fmt.Sprintf("PLUGIN_CLIENT_CERT=%s", certPEM))
+
+               c.config.TLSConfig = &tls.Config{
+                       Certificates: []tls.Certificate{cert},
+                       ServerName:   "localhost",
+               }
+       }
+
        c.logger.Debug("starting plugin", "path", cmd.Path, "args", cmd.Args)
        err = cmd.Start()
        if err != nil {
@@ -511,6 +572,7 @@ func (c *Client) Start() (addr net.Addr, err error) {
 
        // Set the process
        c.process = cmd.Process
+       c.logger.Debug("plugin started", "path", cmd.Path, "pid", c.process.Pid)
 
        // Make sure the command is properly cleaned up if there is an error
        defer func() {
@@ -525,27 +587,37 @@ func (c *Client) Start() (addr net.Addr, err error) {
                }
        }()
 
-       // Start goroutine to wait for process to exit
-       exitCh := make(chan struct{})
+       // Create a context for when we kill
+       c.doneCtx, c.ctxCancel = context.WithCancel(context.Background())
+
+       c.clientWaitGroup.Add(1)
        go func() {
-               // Make sure we close the write end of our stderr/stdout so
-               // that the readers send EOF properly.
-               defer stderr_w.Close()
-               defer stdout_w.Close()
+               // ensure the context is cancelled when we're done
+               defer c.ctxCancel()
+
+               defer c.clientWaitGroup.Done()
+
+               // get the cmd info early, since the process information will be removed
+               // in Kill.
+               pid := c.process.Pid
+               path := cmd.Path
 
                // Wait for the command to end.
-               cmd.Wait()
+               err := cmd.Wait()
+
+               debugMsgArgs := []interface{}{
+                       "path", path,
+                       "pid", pid,
+               }
+               if err != nil {
+                       debugMsgArgs = append(debugMsgArgs,
+                               []interface{}{"error", err.Error()}...)
+               }
 
                // Log and make sure to flush the logs write away
-               c.logger.Debug("plugin process exited", "path", cmd.Path)
+               c.logger.Debug("plugin process exited", debugMsgArgs...)
                os.Stderr.Sync()
 
-               // Mark that we exited
-               close(exitCh)
-
-               // Cancel the context, marking that we exited
-               ctxCancel()
-
                // Set that we exited, which takes a lock
                c.l.Lock()
                defer c.l.Unlock()
@@ -553,32 +625,33 @@ func (c *Client) Start() (addr net.Addr, err error) {
        }()
 
        // Start goroutine that logs the stderr
-       go c.logStderr(stderr_r)
+       c.clientWaitGroup.Add(1)
+       // logStderr calls Done()
+       go c.logStderr(cmdStderr)
 
        // Start a goroutine that is going to be reading the lines
        // out of stdout
-       linesCh := make(chan []byte)
+       linesCh := make(chan string)
+       c.clientWaitGroup.Add(1)
        go func() {
+               defer c.clientWaitGroup.Done()
                defer close(linesCh)
 
-               buf := bufio.NewReader(stdout_r)
-               for {
-                       line, err := buf.ReadBytes('\n')
-                       if line != nil {
-                               linesCh <- line
-                       }
-
-                       if err == io.EOF {
-                               return
-                       }
+               scanner := bufio.NewScanner(cmdStdout)
+               for scanner.Scan() {
+                       linesCh <- scanner.Text()
                }
        }()
 
        // Make sure after we exit we read the lines from stdout forever
-       // so they don't block since it is an io.Pipe
+       // so they don't block since it is a pipe.
+       // The scanner goroutine above will close this, but track it with a wait
+       // group for completeness.
+       c.clientWaitGroup.Add(1)
        defer func() {
                go func() {
-                       for _ = range linesCh {
+                       defer c.clientWaitGroup.Done()
+                       for range linesCh {
                        }
                }()
        }()
@@ -591,12 +664,12 @@ func (c *Client) Start() (addr net.Addr, err error) {
        select {
        case <-timeout:
                err = errors.New("timeout while waiting for plugin to start")
-       case <-exitCh:
+       case <-c.doneCtx.Done():
                err = errors.New("plugin exited before we could connect")
-       case lineBytes := <-linesCh:
+       case line := <-linesCh:
                // Trim the line and split by "|" in order to get the parts of
                // the output.
-               line := strings.TrimSpace(string(lineBytes))
+               line = strings.TrimSpace(line)
                parts := strings.SplitN(line, "|", 6)
                if len(parts) < 4 {
                        err = fmt.Errorf(
@@ -624,20 +697,18 @@ func (c *Client) Start() (addr net.Addr, err error) {
                        }
                }
 
-               // Parse the protocol version
-               var protocol int64
-               protocol, err = strconv.ParseInt(parts[1], 10, 0)
+               // Test the API version
+               version, pluginSet, err := c.checkProtoVersion(parts[1])
                if err != nil {
-                       err = fmt.Errorf("Error parsing protocol version: %s", err)
-                       return
+                       return addr, err
                }
 
-               // Test the API version
-               if uint(protocol) != c.config.ProtocolVersion {
-                       err = fmt.Errorf("Incompatible API version with plugin. "+
-                               "Plugin version: %s, Core version: %d", parts[1], c.config.ProtocolVersion)
-                       return
-               }
+               // set the Plugins value to the compatible set, so the version
+               // doesn't need to be passed through to the ClientProtocol
+               // implementation.
+               c.config.Plugins = pluginSet
+               c.negotiatedVersion = version
+               c.logger.Debug("using plugin", "version", version)
 
                switch parts[2] {
                case "tcp":
@@ -665,15 +736,125 @@ func (c *Client) Start() (addr net.Addr, err error) {
                if !found {
                        err = fmt.Errorf("Unsupported plugin protocol %q. Supported: %v",
                                c.protocol, c.config.AllowedProtocols)
-                       return
+                       return addr, err
                }
 
+               // See if we have a TLS certificate from the server.
+               // Checking if the length is > 50 rules out catching the unused "extra"
+               // data returned from some older implementations.
+               if len(parts) >= 6 && len(parts[5]) > 50 {
+                       err := c.loadServerCert(parts[5])
+                       if err != nil {
+                               return nil, fmt.Errorf("error parsing server cert: %s", err)
+                       }
+               }
        }
 
        c.address = addr
        return
 }
 
+// loadServerCert is used by AutoMTLS to read an x.509 cert returned by the
+// server, and load it as the RootCA for the client TLSConfig.
+func (c *Client) loadServerCert(cert string) error {
+       certPool := x509.NewCertPool()
+
+       asn1, err := base64.RawStdEncoding.DecodeString(cert)
+       if err != nil {
+               return err
+       }
+
+       x509Cert, err := x509.ParseCertificate([]byte(asn1))
+       if err != nil {
+               return err
+       }
+
+       certPool.AddCert(x509Cert)
+
+       c.config.TLSConfig.RootCAs = certPool
+       return nil
+}
+
+func (c *Client) reattach() (net.Addr, error) {
+       // Verify the process still exists. If not, then it is an error
+       p, err := os.FindProcess(c.config.Reattach.Pid)
+       if err != nil {
+               return nil, err
+       }
+
+       // Attempt to connect to the addr since on Unix systems FindProcess
+       // doesn't actually return an error if it can't find the process.
+       conn, err := net.Dial(
+               c.config.Reattach.Addr.Network(),
+               c.config.Reattach.Addr.String())
+       if err != nil {
+               p.Kill()
+               return nil, ErrProcessNotFound
+       }
+       conn.Close()
+
+       // Create a context for when we kill
+       c.doneCtx, c.ctxCancel = context.WithCancel(context.Background())
+
+       c.clientWaitGroup.Add(1)
+       // Goroutine to mark exit status
+       go func(pid int) {
+               defer c.clientWaitGroup.Done()
+
+               // ensure the context is cancelled when we're done
+               defer c.ctxCancel()
+
+               // Wait for the process to die
+               pidWait(pid)
+
+               // Log so we can see it
+               c.logger.Debug("reattached plugin process exited")
+
+               // Mark it
+               c.l.Lock()
+               defer c.l.Unlock()
+               c.exited = true
+       }(p.Pid)
+
+       // Set the address and process
+       c.address = c.config.Reattach.Addr
+       c.process = p
+       c.protocol = c.config.Reattach.Protocol
+       if c.protocol == "" {
+               // Default the protocol to net/rpc for backwards compatibility
+               c.protocol = ProtocolNetRPC
+       }
+
+       return c.address, nil
+}
+
+// checkProtoVersion returns the negotiated version and PluginSet.
+// This returns an error if the server returned an incompatible protocol
+// version, or an invalid handshake response.
+func (c *Client) checkProtoVersion(protoVersion string) (int, PluginSet, error) {
+       serverVersion, err := strconv.Atoi(protoVersion)
+       if err != nil {
+               return 0, nil, fmt.Errorf("Error parsing protocol version %q: %s", protoVersion, err)
+       }
+
+       // record these for the error message
+       var clientVersions []int
+
+       // all versions, including the legacy ProtocolVersion have been added to
+       // the versions set
+       for version, plugins := range c.config.VersionedPlugins {
+               clientVersions = append(clientVersions, version)
+
+               if serverVersion != version {
+                       continue
+               }
+               return version, plugins, nil
+       }
+
+       return 0, nil, fmt.Errorf("Incompatible API version with plugin. "+
+               "Plugin version: %d, Client versions: %d", serverVersion, clientVersions)
+}
+
 // ReattachConfig returns the information that must be provided to NewClient
 // to reattach to the plugin process that this client started. This is
 // useful for plugins that detach from their parent process.
@@ -751,44 +932,84 @@ func (c *Client) dialer(_ string, timeout time.Duration) (net.Conn, error) {
        return conn, nil
 }
 
+var stdErrBufferSize = 64 * 1024
+
 func (c *Client) logStderr(r io.Reader) {
-       bufR := bufio.NewReader(r)
+       defer c.clientWaitGroup.Done()
+       l := c.logger.Named(filepath.Base(c.config.Cmd.Path))
+
+       reader := bufio.NewReaderSize(r, stdErrBufferSize)
+       // continuation indicates the previous line was a prefix
+       continuation := false
+
        for {
-               line, err := bufR.ReadString('\n')
-               if line != "" {
-                       c.config.Stderr.Write([]byte(line))
-                       line = strings.TrimRightFunc(line, unicode.IsSpace)
+               line, isPrefix, err := reader.ReadLine()
+               switch {
+               case err == io.EOF:
+                       return
+               case err != nil:
+                       l.Error("reading plugin stderr", "error", err)
+                       return
+               }
 
-                       l := c.logger.Named(filepath.Base(c.config.Cmd.Path))
+               c.config.Stderr.Write(line)
 
-                       entry, err := parseJSON(line)
-                       // If output is not JSON format, print directly to Debug
-                       if err != nil {
-                               l.Debug(line)
-                       } else {
-                               out := flattenKVPairs(entry.KVPairs)
-
-                               l = l.With("timestamp", entry.Timestamp.Format(hclog.TimeFormat))
-                               switch hclog.LevelFromString(entry.Level) {
-                               case hclog.Trace:
-                                       l.Trace(entry.Message, out...)
-                               case hclog.Debug:
-                                       l.Debug(entry.Message, out...)
-                               case hclog.Info:
-                                       l.Info(entry.Message, out...)
-                               case hclog.Warn:
-                                       l.Warn(entry.Message, out...)
-                               case hclog.Error:
-                                       l.Error(entry.Message, out...)
-                               }
+               // The line was longer than our max token size, so it's likely
+               // incomplete and won't unmarshal.
+               if isPrefix || continuation {
+                       l.Debug(string(line))
+
+                       // if we're finishing a continued line, add the newline back in
+                       if !isPrefix {
+                               c.config.Stderr.Write([]byte{'\n'})
                        }
+
+                       continuation = isPrefix
+                       continue
                }
 
-               if err == io.EOF {
-                       break
+               c.config.Stderr.Write([]byte{'\n'})
+
+               entry, err := parseJSON(line)
+               // If output is not JSON format, print directly to Debug
+               if err != nil {
+                       // Attempt to infer the desired log level from the commonly used
+                       // string prefixes
+                       switch line := string(line); {
+                       case strings.HasPrefix(line, "[TRACE]"):
+                               l.Trace(line)
+                       case strings.HasPrefix(line, "[DEBUG]"):
+                               l.Debug(line)
+                       case strings.HasPrefix(line, "[INFO]"):
+                               l.Info(line)
+                       case strings.HasPrefix(line, "[WARN]"):
+                               l.Warn(line)
+                       case strings.HasPrefix(line, "[ERROR]"):
+                               l.Error(line)
+                       default:
+                               l.Debug(line)
+                       }
+               } else {
+                       out := flattenKVPairs(entry.KVPairs)
+
+                       out = append(out, "timestamp", entry.Timestamp.Format(hclog.TimeFormat))
+                       switch hclog.LevelFromString(entry.Level) {
+                       case hclog.Trace:
+                               l.Trace(entry.Message, out...)
+                       case hclog.Debug:
+                               l.Debug(entry.Message, out...)
+                       case hclog.Info:
+                               l.Info(entry.Message, out...)
+                       case hclog.Warn:
+                               l.Warn(entry.Message, out...)
+                       case hclog.Error:
+                               l.Error(entry.Message, out...)
+                       default:
+                               // if there was no log level, it's likely this is unexpected
+                               // json from something other than hclog, and we should output
+                               // it verbatim.
+                               l.Debug(string(line))
+                       }
                }
        }
-
-       // Flag that we've completed logging for others
-       close(c.doneLogging)
 }