"context"
"crypto/subtle"
"crypto/tls"
+ "crypto/x509"
+ "encoding/base64"
"errors"
"fmt"
"hash"
"io"
"io/ioutil"
- "log"
"net"
"os"
"os/exec"
"sync"
"sync/atomic"
"time"
- "unicode"
hclog "github.com/hashicorp/go-hclog"
)
//
// See NewClient and ClientConfig for using a Client.
type Client struct {
- config *ClientConfig
- exited bool
- doneLogging chan struct{}
- l sync.Mutex
- address net.Addr
- process *os.Process
- client ClientProtocol
- protocol Protocol
- logger hclog.Logger
- doneCtx context.Context
+ config *ClientConfig // client configuration; startup fills in VersionedPlugins, Plugins and (with AutoMTLS) TLSConfig
+ exited bool // set to true, under l, once the plugin process has been seen to exit
+ l sync.Mutex // taken around reads and writes of process, address, exited and processKilled
+ address net.Addr // address the plugin is reachable at
+ process *os.Process // handle on the plugin process; reset to nil once it has been killed
+ client ClientProtocol
+ protocol Protocol // RPC protocol in use; reattach defaults it to ProtocolNetRPC
+ logger hclog.Logger // logger used for all client diagnostics
+ doneCtx context.Context // cancelled when the plugin process exits
+ ctxCancel context.CancelFunc // cancels doneCtx
+ negotiatedVersion int // protocol version agreed during the handshake
+
+ // clientWaitGroup is used to manage the lifecycle of the plugin management
+ // goroutines.
+ clientWaitGroup sync.WaitGroup
+
+ // processKilled is used for testing only, to flag when the process was
+ // forcefully killed.
+ processKilled bool
+}
+
+// NegotiatedVersion returns the protocol version negotiated with the server.
+// This is only valid after Start() is called. The value is read without
+// holding the client lock.
+func (c *Client) NegotiatedVersion() int {
+ return c.negotiatedVersion
}
// ClientConfig is the configuration used to initialize a new
HandshakeConfig
// Plugins are the plugins that can be consumed.
- Plugins map[string]Plugin
+ // The implied version of this PluginSet is the Handshake.ProtocolVersion.
+ Plugins PluginSet
+
+ // VersionedPlugins is a map of PluginSets for specific protocol versions.
+ // These can be used to negotiate a compatible version between client and
+ // server. If this is set, Handshake.ProtocolVersion is not required.
+ VersionedPlugins map[int]PluginSet
// One of the following must be set, but not both.
//
// Logger is the logger that the client will use. If none is provided,
// it will default to hclog's default logger.
Logger hclog.Logger
+
+ // AutoMTLS has the client and server automatically negotiate mTLS for
+ // transport authentication. This ensures that only the original client will
+ // be allowed to connect to the server, and all other connections will be
+ // rejected. The client will also refuse to connect to any server that isn't
+ // the original instance started by the client.
+ //
+ // In this mode of operation, the client generates a one-time use tls
+ // certificate, sends the public x.509 certificate to the new server, and
+ // the server generates a one-time use tls certificate, and sends the public
+ // x.509 certificate back to the client. These are used to authenticate all
+ // rpc connections between the client and server.
+ //
+ // Setting AutoMTLS to true implies that the server must support the
+ // protocol, and correctly negotiate the tls certificates, or a connection
+ // failure will result.
+ //
+ // The client should not set TLSConfig, nor should the server set a
+ // TLSProvider, because AutoMTLS implies that a new certificate and tls
+ // configuration will be generated at startup.
+ //
+ // You cannot Reattach to a server with this option enabled.
+ AutoMTLS bool
}
// ReattachConfig is used to configure a client to reattach to an
}
managedClientsLock.Unlock()
- log.Println("[DEBUG] plugin: waiting for all plugin processes to complete...")
wg.Wait()
}
return c.exited
}
+// killed is used in tests to check if a process failed to exit gracefully, and
+// needed to be killed. It reads processKilled while holding c.l, the same
+// lock that is held when the flag is set after a forced kill.
+func (c *Client) killed() bool {
+ c.l.Lock()
+ defer c.l.Unlock()
+ return c.processKilled
+}
+
// End the executing subprocess (if it is running) and perform any cleanup
// tasks necessary such as capturing any remaining logs and so on.
//
c.l.Lock()
process := c.process
addr := c.address
- doneCh := c.doneLogging
c.l.Unlock()
- // If there is no process, we never started anything. Nothing to kill.
+ // If there is no process, there is nothing to kill.
if process == nil {
return
}
+ defer func() {
+ // Wait for all of the client goroutines to finish.
+ c.clientWaitGroup.Wait()
+
+ // Make sure there is no reference to the old process after it has been
+ // killed.
+ c.l.Lock()
+ c.process = nil
+ c.l.Unlock()
+ }()
+
// We need to check for address here. It is possible that the plugin
// started (process != nil) but has no address (addr == nil) if the
// plugin failed at startup. If we do have an address, we need to close
// kill in a moment anyways.
c.logger.Warn("error closing client during Kill", "err", err)
}
+ } else {
+ c.logger.Error("client", "error", err)
}
}
// doneCh which would be closed if the process exits.
if graceful {
select {
- case <-doneCh:
+ case <-c.doneCtx.Done():
+ c.logger.Debug("plugin exited")
return
- case <-time.After(250 * time.Millisecond):
+ case <-time.After(2 * time.Second):
}
}
// If graceful exiting failed, just kill it
+ c.logger.Warn("plugin failed to exit gracefully")
process.Kill()
- // Wait for the client to finish logging so we have a complete log
- <-doneCh
+ c.l.Lock()
+ c.processKilled = true
+ c.l.Unlock()
}
// Starts the underlying subprocess, communicating with it to negotiate
// If one of cmd or reattach isn't set, then it is an error. We wrap
// this in a {} for scoping reasons, and hopeful that the escape
- // analysis will pop the stock here.
+ // analysis will pop the stack here.
{
cmdSet := c.config.Cmd != nil
attachSet := c.config.Reattach != nil
}
}
- // Create the logging channel for when we kill
- c.doneLogging = make(chan struct{})
- // Create a context for when we kill
- var ctxCancel context.CancelFunc
- c.doneCtx, ctxCancel = context.WithCancel(context.Background())
-
if c.config.Reattach != nil {
- // Verify the process still exists. If not, then it is an error
- p, err := os.FindProcess(c.config.Reattach.Pid)
- if err != nil {
- return nil, err
- }
+ return c.reattach()
+ }
- // Attempt to connect to the addr since on Unix systems FindProcess
- // doesn't actually return an error if it can't find the process.
- conn, err := net.Dial(
- c.config.Reattach.Addr.Network(),
- c.config.Reattach.Addr.String())
- if err != nil {
- p.Kill()
- return nil, ErrProcessNotFound
- }
- conn.Close()
-
- // Goroutine to mark exit status
- go func(pid int) {
- // Wait for the process to die
- pidWait(pid)
-
- // Log so we can see it
- c.logger.Debug("reattached plugin process exited")
-
- // Mark it
- c.l.Lock()
- defer c.l.Unlock()
- c.exited = true
-
- // Close the logging channel since that doesn't work on reattach
- close(c.doneLogging)
-
- // Cancel the context
- ctxCancel()
- }(p.Pid)
-
- // Set the address and process
- c.address = c.config.Reattach.Addr
- c.process = p
- c.protocol = c.config.Reattach.Protocol
- if c.protocol == "" {
- // Default the protocol to net/rpc for backwards compatibility
- c.protocol = ProtocolNetRPC
- }
+ if c.config.VersionedPlugins == nil {
+ c.config.VersionedPlugins = make(map[int]PluginSet)
+ }
- return c.address, nil
+ // handle all plugins as versioned, using the handshake config as the default.
+ version := int(c.config.ProtocolVersion)
+
+ // Make sure we're not overwriting a real version 0. If ProtocolVersion was
+ // non-zero, then we have to just assume the user made sure that
+ // VersionedPlugins doesn't conflict.
+ if _, ok := c.config.VersionedPlugins[version]; !ok && c.config.Plugins != nil {
+ c.config.VersionedPlugins[version] = c.config.Plugins
+ }
+
+ var versionStrings []string
+ for v := range c.config.VersionedPlugins {
+ versionStrings = append(versionStrings, strconv.Itoa(v))
}
env := []string{
fmt.Sprintf("%s=%s", c.config.MagicCookieKey, c.config.MagicCookieValue),
fmt.Sprintf("PLUGIN_MIN_PORT=%d", c.config.MinPort),
fmt.Sprintf("PLUGIN_MAX_PORT=%d", c.config.MaxPort),
+ fmt.Sprintf("PLUGIN_PROTOCOL_VERSIONS=%s", strings.Join(versionStrings, ",")),
}
- stdout_r, stdout_w := io.Pipe()
- stderr_r, stderr_w := io.Pipe()
-
cmd := c.config.Cmd
cmd.Env = append(cmd.Env, os.Environ()...)
cmd.Env = append(cmd.Env, env...)
cmd.Stdin = os.Stdin
- cmd.Stderr = stderr_w
- cmd.Stdout = stdout_w
+
+ cmdStdout, err := cmd.StdoutPipe()
+ if err != nil {
+ return nil, err
+ }
+ cmdStderr, err := cmd.StderrPipe()
+ if err != nil {
+ return nil, err
+ }
if c.config.SecureConfig != nil {
if ok, err := c.config.SecureConfig.Check(cmd.Path); err != nil {
}
}
+ // Setup a temporary certificate for client/server mtls, and send the public
+ // certificate to the plugin.
+ if c.config.AutoMTLS {
+ c.logger.Info("configuring client automatic mTLS")
+ certPEM, keyPEM, err := generateCert()
+ if err != nil {
+ c.logger.Error("failed to generate client certificate", "error", err)
+ return nil, err
+ }
+ cert, err := tls.X509KeyPair(certPEM, keyPEM)
+ if err != nil {
+ c.logger.Error("failed to parse client certificate", "error", err)
+ return nil, err
+ }
+
+ cmd.Env = append(cmd.Env, fmt.Sprintf("PLUGIN_CLIENT_CERT=%s", certPEM))
+
+ c.config.TLSConfig = &tls.Config{
+ Certificates: []tls.Certificate{cert},
+ ServerName: "localhost",
+ }
+ }
+
c.logger.Debug("starting plugin", "path", cmd.Path, "args", cmd.Args)
err = cmd.Start()
if err != nil {
// Set the process
c.process = cmd.Process
+ c.logger.Debug("plugin started", "path", cmd.Path, "pid", c.process.Pid)
// Make sure the command is properly cleaned up if there is an error
defer func() {
}
}()
- // Start goroutine to wait for process to exit
- exitCh := make(chan struct{})
+ // Create a context for when we kill
+ c.doneCtx, c.ctxCancel = context.WithCancel(context.Background())
+
+ c.clientWaitGroup.Add(1)
go func() {
- // Make sure we close the write end of our stderr/stdout so
- // that the readers send EOF properly.
- defer stderr_w.Close()
- defer stdout_w.Close()
+ // ensure the context is cancelled when we're done
+ defer c.ctxCancel()
+
+ defer c.clientWaitGroup.Done()
+
+ // get the cmd info early, since the process information will be removed
+ // in Kill.
+ pid := c.process.Pid
+ path := cmd.Path
// Wait for the command to end.
- cmd.Wait()
+ err := cmd.Wait()
+
+ debugMsgArgs := []interface{}{
+ "path", path,
+ "pid", pid,
+ }
+ if err != nil {
+ debugMsgArgs = append(debugMsgArgs,
+ []interface{}{"error", err.Error()}...)
+ }
// Log and make sure to flush the logs right away
- c.logger.Debug("plugin process exited", "path", cmd.Path)
+ c.logger.Debug("plugin process exited", debugMsgArgs...)
os.Stderr.Sync()
- // Mark that we exited
- close(exitCh)
-
- // Cancel the context, marking that we exited
- ctxCancel()
-
// Set that we exited, which takes a lock
c.l.Lock()
defer c.l.Unlock()
}()
// Start goroutine that logs the stderr
- go c.logStderr(stderr_r)
+ c.clientWaitGroup.Add(1)
+ // logStderr calls Done()
+ go c.logStderr(cmdStderr)
// Start a goroutine that is going to be reading the lines
// out of stdout
- linesCh := make(chan []byte)
+ linesCh := make(chan string)
+ c.clientWaitGroup.Add(1)
go func() {
+ defer c.clientWaitGroup.Done()
defer close(linesCh)
- buf := bufio.NewReader(stdout_r)
- for {
- line, err := buf.ReadBytes('\n')
- if line != nil {
- linesCh <- line
- }
-
- if err == io.EOF {
- return
- }
+ scanner := bufio.NewScanner(cmdStdout)
+ for scanner.Scan() {
+ linesCh <- scanner.Text()
}
}()
// Make sure after we exit we read the lines from stdout forever
- // so they don't block since it is an io.Pipe
+ // so they don't block since it is a pipe.
+ // The scanner goroutine above will close this, but track it with a wait
+ // group for completeness.
+ c.clientWaitGroup.Add(1)
defer func() {
go func() {
- for _ = range linesCh {
+ defer c.clientWaitGroup.Done()
+ for range linesCh {
}
}()
}()
select {
case <-timeout:
err = errors.New("timeout while waiting for plugin to start")
- case <-exitCh:
+ case <-c.doneCtx.Done():
err = errors.New("plugin exited before we could connect")
- case lineBytes := <-linesCh:
+ case line := <-linesCh:
// Trim the line and split by "|" in order to get the parts of
// the output.
- line := strings.TrimSpace(string(lineBytes))
+ line = strings.TrimSpace(line)
parts := strings.SplitN(line, "|", 6)
if len(parts) < 4 {
err = fmt.Errorf(
}
}
- // Parse the protocol version
- var protocol int64
- protocol, err = strconv.ParseInt(parts[1], 10, 0)
+ // Test the API version
+ version, pluginSet, err := c.checkProtoVersion(parts[1])
if err != nil {
- err = fmt.Errorf("Error parsing protocol version: %s", err)
- return
+ return addr, err
}
- // Test the API version
- if uint(protocol) != c.config.ProtocolVersion {
- err = fmt.Errorf("Incompatible API version with plugin. "+
- "Plugin version: %s, Core version: %d", parts[1], c.config.ProtocolVersion)
- return
- }
+ // set the Plugins value to the compatible set, so the version
+ // doesn't need to be passed through to the ClientProtocol
+ // implementation.
+ c.config.Plugins = pluginSet
+ c.negotiatedVersion = version
+ c.logger.Debug("using plugin", "version", version)
switch parts[2] {
case "tcp":
if !found {
err = fmt.Errorf("Unsupported plugin protocol %q. Supported: %v",
c.protocol, c.config.AllowedProtocols)
- return
+ return addr, err
}
+ // See if we have a TLS certificate from the server.
+ // Checking if the length is > 50 rules out catching the unused "extra"
+ // data returned from some older implementations.
+ if len(parts) >= 6 && len(parts[5]) > 50 {
+ err := c.loadServerCert(parts[5])
+ if err != nil {
+ return nil, fmt.Errorf("error parsing server cert: %s", err)
+ }
+ }
}
c.address = addr
return
}
+// loadServerCert is used by AutoMTLS to read an x.509 cert returned by the
+// server, and load it as the RootCA for the client TLSConfig.
+func (c *Client) loadServerCert(cert string) error {
+ certPool := x509.NewCertPool()
+
+ asn1, err := base64.RawStdEncoding.DecodeString(cert)
+ if err != nil {
+ return err
+ }
+
+ x509Cert, err := x509.ParseCertificate([]byte(asn1))
+ if err != nil {
+ return err
+ }
+
+ certPool.AddCert(x509Cert)
+
+ c.config.TLSConfig.RootCAs = certPool
+ return nil
+}
+
// reattach connects this client to the already-running plugin process
// described by c.config.Reattach instead of launching a new one.
//
// It looks the process up by pid and dials the recorded address to confirm
// something is still listening there; if the dial fails the process handle is
// killed and ErrProcessNotFound is returned. On success it starts a goroutine,
// tracked by clientWaitGroup, that waits for the process to exit, marks the
// client as exited and cancels doneCtx. It then records the address, process
// and protocol on the client and returns the address.
+func (c *Client) reattach() (net.Addr, error) {
+ // Verify the process still exists. If not, then it is an error
+ p, err := os.FindProcess(c.config.Reattach.Pid)
+ if err != nil {
+ return nil, err
+ }
+
+ // Attempt to connect to the addr since on Unix systems FindProcess
+ // doesn't actually return an error if it can't find the process.
+ conn, err := net.Dial(
+ c.config.Reattach.Addr.Network(),
+ c.config.Reattach.Addr.String())
+ if err != nil {
+ p.Kill() // result of Kill is not checked
+ return nil, ErrProcessNotFound
+ }
+ conn.Close() // the connection was only a liveness probe
+
+ // Create the context that is cancelled once the process has exited
+ c.doneCtx, c.ctxCancel = context.WithCancel(context.Background())
+
+ c.clientWaitGroup.Add(1)
+ // Goroutine to mark exit status
+ go func(pid int) {
+ defer c.clientWaitGroup.Done()
+
+ // ensure the context is cancelled when we're done
+ defer c.ctxCancel()
+
+ // Wait for the process to die
+ pidWait(pid)
+
+ // Log so we can see it
+ c.logger.Debug("reattached plugin process exited")
+
+ // Mark the client as exited while holding the lock
+ c.l.Lock()
+ defer c.l.Unlock()
+ c.exited = true
+ }(p.Pid)
+
+ // Set the address and process
+ c.address = c.config.Reattach.Addr
+ c.process = p
+ c.protocol = c.config.Reattach.Protocol
+ if c.protocol == "" {
+ // Default the protocol to net/rpc for backwards compatibility
+ c.protocol = ProtocolNetRPC
+ }
+
+ return c.address, nil
+}
+
+// checkProtoVersion returns the negotiated version and PluginSet.
+// This returns an error if the server returned an incompatible protocol
+// version, or an invalid handshake response.
+func (c *Client) checkProtoVersion(protoVersion string) (int, PluginSet, error) {
+ serverVersion, err := strconv.Atoi(protoVersion)
+ if err != nil {
+ return 0, nil, fmt.Errorf("Error parsing protocol version %q: %s", protoVersion, err)
+ }
+
+ // record these for the error message
+ var clientVersions []int
+
+ // all versions, including the legacy ProtocolVersion have been added to
+ // the versions set
+ for version, plugins := range c.config.VersionedPlugins {
+ clientVersions = append(clientVersions, version)
+
+ if serverVersion != version {
+ continue
+ }
+ return version, plugins, nil
+ }
+
+ return 0, nil, fmt.Errorf("Incompatible API version with plugin. "+
+ "Plugin version: %d, Client versions: %d", serverVersion, clientVersions)
+}
+
// ReattachConfig returns the information that must be provided to NewClient
// to reattach to the plugin process that this client started. This is
// useful for plugins that detach from their parent process.
return conn, nil
}
// stdErrBufferSize is the size, in bytes, of the buffered reader logStderr
// wraps around the plugin's stderr. A single stderr line longer than this is
// returned by the reader in several pieces.
+var stdErrBufferSize = 64 * 1024
+
// logStderr copies the plugin's stderr stream r to c.config.Stderr and relays
// each line to a logger named after the plugin binary. It returns when r
// reaches EOF or a read fails, and calls clientWaitGroup.Done on the way out,
// so the caller must have called clientWaitGroup.Add(1) before starting it.
//
// Lines that parse as JSON log entries (parseJSON) are re-emitted at the
// entry's level with its key/value pairs and timestamp. Other lines are logged
// at the level implied by a leading [TRACE], [DEBUG], [INFO], [WARN] or
// [ERROR] tag, defaulting to Debug. Pieces of a line longer than
// stdErrBufferSize are logged at Debug without any parsing.
func (c *Client) logStderr(r io.Reader) {
- bufR := bufio.NewReader(r)
+ defer c.clientWaitGroup.Done()
+ l := c.logger.Named(filepath.Base(c.config.Cmd.Path))
+
+ reader := bufio.NewReaderSize(r, stdErrBufferSize)
+ // continuation is true while we are part-way through a line that the
+ // reader returned in pieces (the previous read reported isPrefix)
+ continuation := false
+
for {
- line, err := bufR.ReadString('\n')
- if line != "" {
- c.config.Stderr.Write([]byte(line))
- line = strings.TrimRightFunc(line, unicode.IsSpace)
+ line, isPrefix, err := reader.ReadLine()
+ switch {
+ case err == io.EOF:
+ return
+ case err != nil:
+ l.Error("reading plugin stderr", "error", err)
+ return
+ }
- l := c.logger.Named(filepath.Base(c.config.Cmd.Path))
+ c.config.Stderr.Write(line)
- entry, err := parseJSON(line)
- // If output is not JSON format, print directly to Debug
- if err != nil {
- l.Debug(line)
- } else {
- out := flattenKVPairs(entry.KVPairs)
-
- l = l.With("timestamp", entry.Timestamp.Format(hclog.TimeFormat))
- switch hclog.LevelFromString(entry.Level) {
- case hclog.Trace:
- l.Trace(entry.Message, out...)
- case hclog.Debug:
- l.Debug(entry.Message, out...)
- case hclog.Info:
- l.Info(entry.Message, out...)
- case hclog.Warn:
- l.Warn(entry.Message, out...)
- case hclog.Error:
- l.Error(entry.Message, out...)
- }
+ // The line was longer than our max token size, so it's likely
+ // incomplete and won't unmarshal.
+ if isPrefix || continuation {
+ l.Debug(string(line))
+
+ // if we're finishing a continued line, add the newline back in
+ if !isPrefix {
+ c.config.Stderr.Write([]byte{'\n'})
}
+
+ continuation = isPrefix
+ continue
}
- if err == io.EOF {
- break
// ReadLine strips the line terminator, so put a newline back on the raw copy.
+ c.config.Stderr.Write([]byte{'\n'})
+
+ entry, err := parseJSON(line)
+ // If output is not JSON format, print directly to Debug
+ if err != nil {
+ // Attempt to infer the desired log level from the commonly used
+ // string prefixes
+ switch line := string(line); {
+ case strings.HasPrefix(line, "[TRACE]"):
+ l.Trace(line)
+ case strings.HasPrefix(line, "[DEBUG]"):
+ l.Debug(line)
+ case strings.HasPrefix(line, "[INFO]"):
+ l.Info(line)
+ case strings.HasPrefix(line, "[WARN]"):
+ l.Warn(line)
+ case strings.HasPrefix(line, "[ERROR]"):
+ l.Error(line)
+ default:
+ l.Debug(line)
+ }
+ } else {
+ out := flattenKVPairs(entry.KVPairs)
+
+ out = append(out, "timestamp", entry.Timestamp.Format(hclog.TimeFormat))
+ switch hclog.LevelFromString(entry.Level) {
+ case hclog.Trace:
+ l.Trace(entry.Message, out...)
+ case hclog.Debug:
+ l.Debug(entry.Message, out...)
+ case hclog.Info:
+ l.Info(entry.Message, out...)
+ case hclog.Warn:
+ l.Warn(entry.Message, out...)
+ case hclog.Error:
+ l.Error(entry.Message, out...)
+ default:
+ // if there was no log level, it's likely this is unexpected
+ // json from something other than hclog, and we should output
+ // it verbatim.
+ l.Debug(string(line))
+ }
}
}
-
- // Flag that we've completed logging for others
- close(c.doneLogging)
}