From 107c1cdb09c575aa2f61d97f48d8587eb6bada4c Mon Sep 17 00:00:00 2001 From: Nathan Dench Date: Fri, 24 May 2019 15:16:44 +1000 Subject: Upgrade to 0.12 --- vendor/github.com/hashicorp/go-plugin/server.go | 165 ++++++++++++++++++++---- 1 file changed, 140 insertions(+), 25 deletions(-) (limited to 'vendor/github.com/hashicorp/go-plugin/server.go') diff --git a/vendor/github.com/hashicorp/go-plugin/server.go b/vendor/github.com/hashicorp/go-plugin/server.go index 1e808b9..fc9f05a 100644 --- a/vendor/github.com/hashicorp/go-plugin/server.go +++ b/vendor/github.com/hashicorp/go-plugin/server.go @@ -2,6 +2,7 @@ package plugin import ( "crypto/tls" + "crypto/x509" "encoding/base64" "errors" "fmt" @@ -11,7 +12,9 @@ import ( "os" "os/signal" "runtime" + "sort" "strconv" + "strings" "sync/atomic" "github.com/hashicorp/go-hclog" @@ -36,6 +39,8 @@ type HandshakeConfig struct { // ProtocolVersion is the version that clients must match on to // agree they can communicate. This should match the ProtocolVersion // set on ClientConfig when using a plugin. + // This field is not required if VersionedPlugins are being used in the + // Client or Server configurations. ProtocolVersion uint // MagicCookieKey and value are used as a very basic verification @@ -46,6 +51,10 @@ type HandshakeConfig struct { MagicCookieValue string } +// PluginSet is a set of plugins provided to be registered in the plugin +// server. +type PluginSet map[string]Plugin + // ServeConfig configures what sorts of plugins are served. type ServeConfig struct { // HandshakeConfig is the configuration that must match clients. @@ -55,7 +64,13 @@ type ServeConfig struct { TLSProvider func() (*tls.Config, error) // Plugins are the plugins that are served. - Plugins map[string]Plugin + // The implied version of this PluginSet is the Handshake.ProtocolVersion. + Plugins PluginSet + + // VersionedPlugins is a map of PluginSets for specific protocol versions. + // These can be used to negotiate a compatible version between client and + // server. If this is set, Handshake.ProtocolVersion is not required. + VersionedPlugins map[int]PluginSet // GRPCServer should be non-nil to enable serving the plugins over // gRPC. This is a function to create the server when needed with the @@ -72,14 +87,83 @@ type ServeConfig struct { Logger hclog.Logger } -// Protocol returns the protocol that this server should speak. -func (c *ServeConfig) Protocol() Protocol { - result := ProtocolNetRPC - if c.GRPCServer != nil { - result = ProtocolGRPC +// protocolVersion determines the protocol version and plugin set to be used by +// the server. In the event that there is no suitable version, the last version +// in the config is returned leaving the client to report the incompatibility. +func protocolVersion(opts *ServeConfig) (int, Protocol, PluginSet) { + protoVersion := int(opts.ProtocolVersion) + pluginSet := opts.Plugins + protoType := ProtocolNetRPC + // Check if the client sent a list of acceptable versions + var clientVersions []int + if vs := os.Getenv("PLUGIN_PROTOCOL_VERSIONS"); vs != "" { + for _, s := range strings.Split(vs, ",") { + v, err := strconv.Atoi(s) + if err != nil { + fmt.Fprintf(os.Stderr, "server sent invalid plugin version %q", s) + continue + } + clientVersions = append(clientVersions, v) + } + } + + // We want to iterate in reverse order, to ensure we match the newest + // compatible plugin version. + sort.Sort(sort.Reverse(sort.IntSlice(clientVersions))) + + // set the old un-versioned fields as if they were versioned plugins + if opts.VersionedPlugins == nil { + opts.VersionedPlugins = make(map[int]PluginSet) + } + + if pluginSet != nil { + opts.VersionedPlugins[protoVersion] = pluginSet } - return result + // Sort the version to make sure we match the latest first + var versions []int + for v := range opts.VersionedPlugins { + versions = append(versions, v) + } + + sort.Sort(sort.Reverse(sort.IntSlice(versions))) + + // See if we have multiple versions of Plugins to choose from + for _, version := range versions { + // Record each version, since we guarantee that this returns valid + // values even if they are not a protocol match. + protoVersion = version + pluginSet = opts.VersionedPlugins[version] + + // If we have a configured gRPC server we should select a protocol + if opts.GRPCServer != nil { + // All plugins in a set must use the same transport, so check the first + // for the protocol type + for _, p := range pluginSet { + switch p.(type) { + case GRPCPlugin: + protoType = ProtocolGRPC + default: + protoType = ProtocolNetRPC + } + break + } + } + + for _, clientVersion := range clientVersions { + if clientVersion == protoVersion { + return protoVersion, protoType, pluginSet + } + } + } + + // Return the lowest version as the fallback. + // Since we iterated over all the versions in reverse order above, these + // values are from the lowest version number plugins (which may be from + // a combination of the Handshake.ProtocolVersion and ServeConfig.Plugins + // fields). This allows serving the oldest version of our plugins to a + // legacy client that did not send a PLUGIN_PROTOCOL_VERSIONS list. + return protoVersion, protoType, pluginSet } // Serve serves the plugins given by ServeConfig. @@ -107,6 +191,10 @@ func Serve(opts *ServeConfig) { os.Exit(1) } + // negotiate the version and plugins + // start with default version in the handshake config + protoVersion, protoType, pluginSet := protocolVersion(opts) + // Logging goes to the original stderr log.SetOutput(os.Stderr) @@ -155,12 +243,47 @@ func Serve(opts *ServeConfig) { } } + var serverCert string + clientCert := os.Getenv("PLUGIN_CLIENT_CERT") + // If the client is configured using AutoMTLS, the certificate will be here, + // and we need to generate our own in response. + if tlsConfig == nil && clientCert != "" { + logger.Info("configuring server automatic mTLS") + clientCertPool := x509.NewCertPool() + if !clientCertPool.AppendCertsFromPEM([]byte(clientCert)) { + logger.Error("client cert provided but failed to parse", "cert", clientCert) + } + + certPEM, keyPEM, err := generateCert() + if err != nil { + logger.Error("failed to generate client certificate", "error", err) + panic(err) + } + + cert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + logger.Error("failed to parse client certificate", "error", err) + panic(err) + } + + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: clientCertPool, + MinVersion: tls.VersionTLS12, + } + + // We send back the raw leaf cert data for the client rather than the + // PEM, since the protocol can't handle newlines. + serverCert = base64.RawStdEncoding.EncodeToString(cert.Certificate[0]) + } + // Create the channel to tell us when we're done doneCh := make(chan struct{}) // Build the server type var server ServerProtocol - switch opts.Protocol() { + switch protoType { case ProtocolNetRPC: // If we have a TLS configuration then we wrap the listener // ourselves and do it at that level. @@ -170,7 +293,7 @@ func Serve(opts *ServeConfig) { // Create the RPC server to dispense server = &RPCServer{ - Plugins: opts.Plugins, + Plugins: pluginSet, Stdout: stdout_r, Stderr: stderr_r, DoneCh: doneCh, @@ -179,16 +302,17 @@ func Serve(opts *ServeConfig) { case ProtocolGRPC: // Create the gRPC server server = &GRPCServer{ - Plugins: opts.Plugins, + Plugins: pluginSet, Server: opts.GRPCServer, TLS: tlsConfig, Stdout: stdout_r, Stderr: stderr_r, DoneCh: doneCh, + logger: logger, } default: - panic("unknown server protocol: " + opts.Protocol()) + panic("unknown server protocol: " + protoType) } // Initialize the servers @@ -197,25 +321,16 @@ func Serve(opts *ServeConfig) { return } - // Build the extra configuration - extra := "" - if v := server.Config(); v != "" { - extra = base64.StdEncoding.EncodeToString([]byte(v)) - } - if extra != "" { - extra = "|" + extra - } - logger.Debug("plugin address", "network", listener.Addr().Network(), "address", listener.Addr().String()) - // Output the address and service name to stdout so that core can bring it up. - fmt.Printf("%d|%d|%s|%s|%s%s\n", + // Output the address and service name to stdout so that the client can bring it up. + fmt.Printf("%d|%d|%s|%s|%s|%s\n", CoreProtocolVersion, - opts.ProtocolVersion, + protoVersion, listener.Addr().Network(), listener.Addr().String(), - opts.Protocol(), - extra) + protoType, + serverCert) os.Stdout.Sync() // Eat the interrupts -- cgit v1.2.3