]> git.immae.eu Git - github/fretlink/terraform-provider-statuscake.git/blobdiff - vendor/google.golang.org/grpc/internal/transport/handler_server.go
Upgrade to 0.12
[github/fretlink/terraform-provider-statuscake.git] / vendor / google.golang.org / grpc / internal / transport / handler_server.go
similarity index 75%
rename from vendor/google.golang.org/grpc/transport/handler_server.go
rename to vendor/google.golang.org/grpc/internal/transport/handler_server.go
index 27372b50894c577516a914e4a76f89f13c826e99..73b41ea7e0b0fff4b18b817e9ffc7127ad725dfd 100644 (file)
@@ -24,6 +24,7 @@
 package transport
 
 import (
+       "context"
        "errors"
        "fmt"
        "io"
@@ -33,26 +34,30 @@ import (
        "sync"
        "time"
 
-       "golang.org/x/net/context"
+       "github.com/golang/protobuf/proto"
        "golang.org/x/net/http2"
        "google.golang.org/grpc/codes"
        "google.golang.org/grpc/credentials"
        "google.golang.org/grpc/metadata"
        "google.golang.org/grpc/peer"
+       "google.golang.org/grpc/stats"
        "google.golang.org/grpc/status"
 )
 
 // NewServerHandlerTransport returns a ServerTransport handling gRPC
 // from inside an http.Handler. It requires that the http Server
 // supports HTTP/2.
-func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTransport, error) {
+func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats stats.Handler) (ServerTransport, error) {
        if r.ProtoMajor != 2 {
                return nil, errors.New("gRPC requires HTTP/2")
        }
        if r.Method != "POST" {
                return nil, errors.New("invalid gRPC request method")
        }
-       if !validContentType(r.Header.Get("Content-Type")) {
+       contentType := r.Header.Get("Content-Type")
+       // TODO: do we assume contentType is lowercase? we did before
+       contentSubtype, validContentType := contentSubtype(contentType)
+       if !validContentType {
                return nil, errors.New("invalid gRPC request content-type")
        }
        if _, ok := w.(http.Flusher); !ok {
@@ -63,34 +68,37 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
        }
 
        st := &serverHandlerTransport{
-               rw:       w,
-               req:      r,
-               closedCh: make(chan struct{}),
-               writes:   make(chan func()),
+               rw:             w,
+               req:            r,
+               closedCh:       make(chan struct{}),
+               writes:         make(chan func()),
+               contentType:    contentType,
+               contentSubtype: contentSubtype,
+               stats:          stats,
        }
 
        if v := r.Header.Get("grpc-timeout"); v != "" {
                to, err := decodeTimeout(v)
                if err != nil {
-                       return nil, streamErrorf(codes.Internal, "malformed time-out: %v", err)
+                       return nil, status.Errorf(codes.Internal, "malformed time-out: %v", err)
                }
                st.timeoutSet = true
                st.timeout = to
        }
 
-       var metakv []string
+       metakv := []string{"content-type", contentType}
        if r.Host != "" {
                metakv = append(metakv, ":authority", r.Host)
        }
        for k, vv := range r.Header {
                k = strings.ToLower(k)
-               if isReservedHeader(k) && !isWhitelistedPseudoHeader(k) {
+               if isReservedHeader(k) && !isWhitelistedHeader(k) {
                        continue
                }
                for _, v := range vv {
                        v, err := decodeMetadataHeader(k, v)
                        if err != nil {
-                               return nil, streamErrorf(codes.InvalidArgument, "malformed binary metadata: %v", err)
+                               return nil, status.Errorf(codes.Internal, "malformed binary metadata: %v", err)
                        }
                        metakv = append(metakv, k, v)
                }
@@ -121,6 +129,18 @@ type serverHandlerTransport struct {
        // ServeHTTP (HandleStreams) goroutine. The channel is closed
        // when WriteStatus is called.
        writes chan func()
+
+       // block concurrent WriteStatus calls
+       // e.g. grpc/(*serverStream).SendMsg/RecvMsg
+       writeStatusMu sync.Mutex
+
+       // we just mirror the request content-type
+       contentType string
+       // we store both contentType and contentSubtype so we don't keep recreating them
+       // TODO make sure this is consistent across handler_server and http2_server
+       contentSubtype string
+
+       stats stats.Handler
 }
 
 func (ht *serverHandlerTransport) Close() error {
@@ -167,11 +187,13 @@ func (ht *serverHandlerTransport) do(fn func()) error {
                case <-ht.closedCh:
                        return ErrConnClosing
                }
-
        }
 }
 
 func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) error {
+       ht.writeStatusMu.Lock()
+       defer ht.writeStatusMu.Unlock()
+
        err := ht.do(func() {
                ht.writeCommonHeaders(s)
 
@@ -186,7 +208,15 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro
                        h.Set("Grpc-Message", encodeGrpcMessage(m))
                }
 
-               // TODO: Support Grpc-Status-Details-Bin
+               if p := st.Proto(); p != nil && len(p.Details) > 0 {
+                       stBytes, err := proto.Marshal(p)
+                       if err != nil {
+                               // TODO: return error instead, when callers are able to handle it.
+                               panic(err)
+                       }
+
+                       h.Set("Grpc-Status-Details-Bin", encodeBinHeader(stBytes))
+               }
 
                if md := s.Trailer(); len(md) > 0 {
                        for k, vv := range md {
@@ -202,7 +232,14 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro
                        }
                }
        })
-       close(ht.writes)
+
+       if err == nil { // transport has not been closed
+               if ht.stats != nil {
+                       ht.stats.HandleRPC(s.Context(), &stats.OutTrailer{})
+               }
+               close(ht.writes)
+       }
+       ht.Close()
        return err
 }
 
@@ -216,7 +253,7 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
 
        h := ht.rw.Header()
        h["Date"] = nil // suppress Date to make tests happy; TODO: restore
-       h.Set("Content-Type", "application/grpc")
+       h.Set("Content-Type", ht.contentType)
 
        // Predeclare trailers we'll set later in WriteStatus (after the body).
        // This is a SHOULD in the HTTP RFC, and the way you add (known)
@@ -225,25 +262,24 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
        // and https://golang.org/pkg/net/http/#example_ResponseWriter_trailers
        h.Add("Trailer", "Grpc-Status")
        h.Add("Trailer", "Grpc-Message")
-       // TODO: Support Grpc-Status-Details-Bin
+       h.Add("Trailer", "Grpc-Status-Details-Bin")
 
        if s.sendCompress != "" {
                h.Set("Grpc-Encoding", s.sendCompress)
        }
 }
 
-func (ht *serverHandlerTransport) Write(s *Stream, data []byte, opts *Options) error {
+func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
        return ht.do(func() {
                ht.writeCommonHeaders(s)
+               ht.rw.Write(hdr)
                ht.rw.Write(data)
-               if !opts.Delay {
-                       ht.rw.(http.Flusher).Flush()
-               }
+               ht.rw.(http.Flusher).Flush()
        })
 }
 
 func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
-       return ht.do(func() {
+       err := ht.do(func() {
                ht.writeCommonHeaders(s)
                h := ht.rw.Header()
                for k, vv := range md {
@@ -259,17 +295,24 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
                ht.rw.WriteHeader(200)
                ht.rw.(http.Flusher).Flush()
        })
+
+       if err == nil {
+               if ht.stats != nil {
+                       ht.stats.HandleRPC(s.Context(), &stats.OutHeader{})
+               }
+       }
+       return err
 }
 
 func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) {
        // With this transport type there will be exactly 1 stream: this HTTP request.
 
-       var ctx context.Context
+       ctx := ht.req.Context()
        var cancel context.CancelFunc
        if ht.timeoutSet {
-               ctx, cancel = context.WithTimeout(context.Background(), ht.timeout)
+               ctx, cancel = context.WithTimeout(ctx, ht.timeout)
        } else {
-               ctx, cancel = context.WithCancel(context.Background())
+               ctx, cancel = context.WithCancel(ctx)
        }
 
        // requestOver is closed when either the request's context is done
@@ -283,23 +326,24 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
        go func() {
                select {
                case <-requestOver:
-                       return
                case <-ht.closedCh:
                case <-clientGone:
                }
                cancel()
+               ht.Close()
        }()
 
        req := ht.req
 
        s := &Stream{
-               id:           0, // irrelevant
-               requestRead:  func(int) {},
-               cancel:       cancel,
-               buf:          newRecvBuffer(),
-               st:           ht,
-               method:       req.URL.Path,
-               recvCompress: req.Header.Get("grpc-encoding"),
+               id:             0, // irrelevant
+               requestRead:    func(int) {},
+               cancel:         cancel,
+               buf:            newRecvBuffer(),
+               st:             ht,
+               method:         req.URL.Path,
+               recvCompress:   req.Header.Get("grpc-encoding"),
+               contentSubtype: ht.contentSubtype,
        }
        pr := &peer.Peer{
                Addr: ht.RemoteAddr(),
@@ -308,10 +352,18 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
                pr.AuthInfo = credentials.TLSInfo{State: *req.TLS}
        }
        ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
-       ctx = peer.NewContext(ctx, pr)
-       s.ctx = newContextWithStream(ctx, s)
+       s.ctx = peer.NewContext(ctx, pr)
+       if ht.stats != nil {
+               s.ctx = ht.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
+               inHeader := &stats.InHeader{
+                       FullMethod:  s.method,
+                       RemoteAddr:  ht.RemoteAddr(),
+                       Compression: s.recvCompress,
+               }
+               ht.stats.HandleRPC(s.ctx, inHeader)
+       }
        s.trReader = &transportReader{
-               reader:        &recvBufferReader{ctx: s.ctx, recv: s.buf},
+               reader:        &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf},
                windowHandler: func(int) {},
        }
 
@@ -366,6 +418,10 @@ func (ht *serverHandlerTransport) runStream() {
        }
 }
 
+func (ht *serverHandlerTransport) IncrMsgSent() {}
+
+func (ht *serverHandlerTransport) IncrMsgRecv() {}
+
 func (ht *serverHandlerTransport) Drain() {
        panic("Drain() is not implemented")
 }
@@ -376,18 +432,18 @@ func (ht *serverHandlerTransport) Drain() {
 //   * io.EOF
 //   * io.ErrUnexpectedEOF
 //   * of type transport.ConnectionError
-//   * of type transport.StreamError
+//   * an error from the status package
 func mapRecvMsgError(err error) error {
        if err == io.EOF || err == io.ErrUnexpectedEOF {
                return err
        }
        if se, ok := err.(http2.StreamError); ok {
                if code, ok := http2ErrConvTab[se.Code]; ok {
-                       return StreamError{
-                               Code: code,
-                               Desc: se.Error(),
-                       }
+                       return status.Error(code, se.Error())
                }
        }
+       if strings.Contains(err.Error(), "body closed by handler") {
+               return status.Error(codes.Canceled, err.Error())
+       }
        return connectionErrorf(true, err, err.Error())
 }