]> git.immae.eu Git - github/fretlink/terraform-provider-statuscake.git/blobdiff - vendor/github.com/hashicorp/go-plugin/server.go
update vendor and go.mod
[github/fretlink/terraform-provider-statuscake.git] / vendor / github.com / hashicorp / go-plugin / server.go
index b5c5270a7d83759ef0cf335c9cc9845b48da2ba3..4c230e3ab4cf1374d51dd1da00db3f9b4285a035 100644 (file)
@@ -1,6 +1,9 @@
 package plugin
 
 import (
+       "crypto/tls"
+       "crypto/x509"
+       "encoding/base64"
        "errors"
        "fmt"
        "io/ioutil"
@@ -9,8 +12,14 @@ import (
        "os"
        "os/signal"
        "runtime"
+       "sort"
        "strconv"
+       "strings"
        "sync/atomic"
+
+       "github.com/hashicorp/go-hclog"
+
+       "google.golang.org/grpc"
 )
 
 // CoreProtocolVersion is the ProtocolVersion of the plugin system itself.
@@ -30,6 +39,8 @@ type HandshakeConfig struct {
        // ProtocolVersion is the version that clients must match on to
        // agree they can communicate. This should match the ProtocolVersion
        // set on ClientConfig when using a plugin.
+       // This field is not required if VersionedPlugins are being used in the
+       // Client or Server configurations.
        ProtocolVersion uint
 
        // MagicCookieKey and value are used as a very basic verification
@@ -40,19 +51,125 @@ type HandshakeConfig struct {
        MagicCookieValue string
 }
 
+// PluginSet is a set of plugins provided to be registered in the plugin
+// server.
+type PluginSet map[string]Plugin
+
 // ServeConfig configures what sorts of plugins are served.
 type ServeConfig struct {
        // HandshakeConfig is the configuration that must match clients.
        HandshakeConfig
 
+       // TLSProvider is a function that returns a configured tls.Config.
+       TLSProvider func() (*tls.Config, error)
+
        // Plugins are the plugins that are served.
-       Plugins map[string]Plugin
+       // The implied version of this PluginSet is the Handshake.ProtocolVersion.
+       Plugins PluginSet
+
+       // VersionedPlugins is a map of PluginSets for specific protocol versions.
+       // These can be used to negotiate a compatible version between client and
+       // server. If this is set, Handshake.ProtocolVersion is not required.
+       VersionedPlugins map[int]PluginSet
+
+       // GRPCServer should be non-nil to enable serving the plugins over
+       // gRPC. This is a function to create the server when needed with the
+       // given server options. The server options populated by go-plugin will
+       // be for TLS if set. You may modify the input slice.
+       //
+       // Note that the grpc.Server will automatically be registered with
+       // the gRPC health checking service. This is not optional since go-plugin
+       // relies on this to implement Ping().
+       GRPCServer func([]grpc.ServerOption) *grpc.Server
+
+       // Logger is used to pass a logger into the server. If none is provided the
+       // server will create a default logger.
+       Logger hclog.Logger
+}
+
+// protocolVersion determines the protocol version and plugin set to be used by
+// the server. In the event that there is no suitable version, the last version
+// in the config is returned leaving the client to report the incompatibility.
+func protocolVersion(opts *ServeConfig) (int, Protocol, PluginSet) {
+       protoVersion := int(opts.ProtocolVersion)
+       pluginSet := opts.Plugins
+       protoType := ProtocolNetRPC
+       // Check if the client sent a list of acceptable versions
+       var clientVersions []int
+       if vs := os.Getenv("PLUGIN_PROTOCOL_VERSIONS"); vs != "" {
+               for _, s := range strings.Split(vs, ",") {
+                       v, err := strconv.Atoi(s)
+                       if err != nil {
+                               fmt.Fprintf(os.Stderr, "server sent invalid plugin version %q", s)
+                               continue
+                       }
+                       clientVersions = append(clientVersions, v)
+               }
+       }
+
+       // We want to iterate in reverse order, to ensure we match the newest
+       // compatible plugin version.
+       sort.Sort(sort.Reverse(sort.IntSlice(clientVersions)))
+
+       // set the old un-versioned fields as if they were versioned plugins
+       if opts.VersionedPlugins == nil {
+               opts.VersionedPlugins = make(map[int]PluginSet)
+       }
+
+       if pluginSet != nil {
+               opts.VersionedPlugins[protoVersion] = pluginSet
+       }
+
+       // Sort the version to make sure we match the latest first
+       var versions []int
+       for v := range opts.VersionedPlugins {
+               versions = append(versions, v)
+       }
+
+       sort.Sort(sort.Reverse(sort.IntSlice(versions)))
+
+       // See if we have multiple versions of Plugins to choose from
+       for _, version := range versions {
+               // Record each version, since we guarantee that this returns valid
+               // values even if they are not a protocol match.
+               protoVersion = version
+               pluginSet = opts.VersionedPlugins[version]
+
+               // If we have a configured gRPC server we should select a protocol
+               if opts.GRPCServer != nil {
+                       // All plugins in a set must use the same transport, so check the first
+                       // for the protocol type
+                       for _, p := range pluginSet {
+                               switch p.(type) {
+                               case GRPCPlugin:
+                                       protoType = ProtocolGRPC
+                               default:
+                                       protoType = ProtocolNetRPC
+                               }
+                               break
+                       }
+               }
+
+               for _, clientVersion := range clientVersions {
+                       if clientVersion == protoVersion {
+                               return protoVersion, protoType, pluginSet
+                       }
+               }
+       }
+
+       // Return the lowest version as the fallback.
+       // Since we iterated over all the versions in reverse order above, these
+       // values are from the lowest version number plugins (which may be from
+       // a combination of the Handshake.ProtocolVersion and ServeConfig.Plugins
+       // fields). This allows serving the oldest version of our plugins to a
+       // legacy client that did not send a PLUGIN_PROTOCOL_VERSIONS list.
+       return protoVersion, protoType, pluginSet
 }
 
 // Serve serves the plugins given by ServeConfig.
 //
 // Serve doesn't return until the plugin is done being executed. Any
-// errors will be outputted to the log.
+// errors will be outputted to os.Stderr.
 //
 // This is the method that plugins should call in their main() functions.
 func Serve(opts *ServeConfig) {
@@ -74,9 +191,23 @@ func Serve(opts *ServeConfig) {
                os.Exit(1)
        }
 
+       // negotiate the version and plugins
+       // start with default version in the handshake config
+       protoVersion, protoType, pluginSet := protocolVersion(opts)
+
        // Logging goes to the original stderr
        log.SetOutput(os.Stderr)
 
+       logger := opts.Logger
+       if logger == nil {
+               // internal logger to os.Stderr
+               logger = hclog.New(&hclog.LoggerOptions{
+                       Level:      hclog.Trace,
+                       Output:     os.Stderr,
+                       JSONFormat: true,
+               })
+       }
+
        // Create our new stdout, stderr files. These will override our built-in
        // stdout/stderr so that it works across the stream boundary.
        stdout_r, stdout_w, err := os.Pipe()
@@ -93,30 +224,113 @@ func Serve(opts *ServeConfig) {
        // Register a listener so we can accept a connection
        listener, err := serverListener()
        if err != nil {
-               log.Printf("[ERR] plugin: plugin init: %s", err)
+               logger.Error("plugin init error", "error", err)
                return
        }
-       defer listener.Close()
+
+       // Close the listener on return. We wrap this in a func() on purpose
+       // because the "listener" reference may change to TLS.
+       defer func() {
+               listener.Close()
+       }()
+
+       var tlsConfig *tls.Config
+       if opts.TLSProvider != nil {
+               tlsConfig, err = opts.TLSProvider()
+               if err != nil {
+                       logger.Error("plugin tls init", "error", err)
+                       return
+               }
+       }
+
+       var serverCert string
+       clientCert := os.Getenv("PLUGIN_CLIENT_CERT")
+       // If the client is configured using AutoMTLS, the certificate will be here,
+       // and we need to generate our own in response.
+       if tlsConfig == nil && clientCert != "" {
+               logger.Info("configuring server automatic mTLS")
+               clientCertPool := x509.NewCertPool()
+               if !clientCertPool.AppendCertsFromPEM([]byte(clientCert)) {
+                       logger.Error("client cert provided but failed to parse", "cert", clientCert)
+               }
+
+               certPEM, keyPEM, err := generateCert()
+               if err != nil {
+                       logger.Error("failed to generate client certificate", "error", err)
+                       panic(err)
+               }
+
+               cert, err := tls.X509KeyPair(certPEM, keyPEM)
+               if err != nil {
+                       logger.Error("failed to parse client certificate", "error", err)
+                       panic(err)
+               }
+
+               tlsConfig = &tls.Config{
+                       Certificates: []tls.Certificate{cert},
+                       ClientAuth:   tls.RequireAndVerifyClientCert,
+                       ClientCAs:    clientCertPool,
+                       MinVersion:   tls.VersionTLS12,
+               }
+
+               // We send back the raw leaf cert data for the client rather than the
+               // PEM, since the protocol can't handle newlines.
+               serverCert = base64.RawStdEncoding.EncodeToString(cert.Certificate[0])
+       }
 
        // Create the channel to tell us when we're done
        doneCh := make(chan struct{})
 
-       // Create the RPC server to dispense
-       server := &RPCServer{
-               Plugins: opts.Plugins,
-               Stdout:  stdout_r,
-               Stderr:  stderr_r,
-               DoneCh:  doneCh,
+       // Build the server type
+       var server ServerProtocol
+       switch protoType {
+       case ProtocolNetRPC:
+               // If we have a TLS configuration then we wrap the listener
+               // ourselves and do it at that level.
+               if tlsConfig != nil {
+                       listener = tls.NewListener(listener, tlsConfig)
+               }
+
+               // Create the RPC server to dispense
+               server = &RPCServer{
+                       Plugins: pluginSet,
+                       Stdout:  stdout_r,
+                       Stderr:  stderr_r,
+                       DoneCh:  doneCh,
+               }
+
+       case ProtocolGRPC:
+               // Create the gRPC server
+               server = &GRPCServer{
+                       Plugins: pluginSet,
+                       Server:  opts.GRPCServer,
+                       TLS:     tlsConfig,
+                       Stdout:  stdout_r,
+                       Stderr:  stderr_r,
+                       DoneCh:  doneCh,
+                       logger:  logger,
+               }
+
+       default:
+               panic("unknown server protocol: " + protoType)
+       }
+
+       // Initialize the servers
+       if err := server.Init(); err != nil {
+               logger.Error("protocol init", "error", err)
+               return
        }
 
-       // Output the address and service name to stdout so that core can bring it up.
-       log.Printf("[DEBUG] plugin: plugin address: %s %s\n",
-               listener.Addr().Network(), listener.Addr().String())
-       fmt.Printf("%d|%d|%s|%s\n",
+       logger.Debug("plugin address", "network", listener.Addr().Network(), "address", listener.Addr().String())
+
+       // Output the address and service name to stdout so that the client can bring it up.
+       fmt.Printf("%d|%d|%s|%s|%s|%s\n",
                CoreProtocolVersion,
-               opts.ProtocolVersion,
+               protoVersion,
                listener.Addr().Network(),
-               listener.Addr().String())
+               listener.Addr().String(),
+               protoType,
+               serverCert)
        os.Stdout.Sync()
 
        // Eat the interrupts
@@ -127,9 +341,7 @@ func Serve(opts *ServeConfig) {
                for {
                        <-ch
                        newCount := atomic.AddInt32(&count, 1)
-                       log.Printf(
-                               "[DEBUG] plugin: received interrupt signal (count: %d). Ignoring.",
-                               newCount)
+                       logger.Debug("plugin received interrupt signal, ignoring", "count", newCount)
                }
        }()
 
@@ -137,10 +349,8 @@ func Serve(opts *ServeConfig) {
        os.Stdout = stdout_w
        os.Stderr = stderr_w
 
-       // Serve
-       go server.Accept(listener)
-
-       // Wait for the graceful exit
+       // Accept connections and wait for completion
+       go server.Serve(listener)
        <-doneCh
 }
 
@@ -153,14 +363,34 @@ func serverListener() (net.Listener, error) {
 }
 
 func serverListener_tcp() (net.Listener, error) {
-       minPort, err := strconv.ParseInt(os.Getenv("PLUGIN_MIN_PORT"), 10, 32)
-       if err != nil {
-               return nil, err
+       envMinPort := os.Getenv("PLUGIN_MIN_PORT")
+       envMaxPort := os.Getenv("PLUGIN_MAX_PORT")
+
+       var minPort, maxPort int64
+       var err error
+
+       switch {
+       case len(envMinPort) == 0:
+               minPort = 0
+       default:
+               minPort, err = strconv.ParseInt(envMinPort, 10, 32)
+               if err != nil {
+                       return nil, fmt.Errorf("Couldn't get value from PLUGIN_MIN_PORT: %v", err)
+               }
        }
 
-       maxPort, err := strconv.ParseInt(os.Getenv("PLUGIN_MAX_PORT"), 10, 32)
-       if err != nil {
-               return nil, err
+       switch {
+       case len(envMaxPort) == 0:
+               maxPort = 0
+       default:
+               maxPort, err = strconv.ParseInt(envMaxPort, 10, 32)
+               if err != nil {
+                       return nil, fmt.Errorf("Couldn't get value from PLUGIN_MAX_PORT: %v", err)
+               }
+       }
+
+       if minPort > maxPort {
+               return nil, fmt.Errorf("ENV_MIN_PORT value of %d is greater than PLUGIN_MAX_PORT value of %d", minPort, maxPort)
        }
 
        for port := minPort; port <= maxPort; port++ {