package transport
import (
+ "context"
"errors"
"fmt"
"io"
"sync"
"time"
- "golang.org/x/net/context"
+ "github.com/golang/protobuf/proto"
"golang.org/x/net/http2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
+ "google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
)
// NewServerHandlerTransport returns a ServerTransport handling gRPC
// from inside an http.Handler. It requires that the http Server
// supports HTTP/2.
-func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTransport, error) {
+func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats stats.Handler) (ServerTransport, error) {
if r.ProtoMajor != 2 {
return nil, errors.New("gRPC requires HTTP/2")
}
if r.Method != "POST" {
return nil, errors.New("invalid gRPC request method")
}
- if !validContentType(r.Header.Get("Content-Type")) {
+ contentType := r.Header.Get("Content-Type")
+ // TODO: do we assume contentType is lowercase? we did before
+ contentSubtype, validContentType := contentSubtype(contentType)
+ if !validContentType {
return nil, errors.New("invalid gRPC request content-type")
}
if _, ok := w.(http.Flusher); !ok {
}
st := &serverHandlerTransport{
- rw: w,
- req: r,
- closedCh: make(chan struct{}),
- writes: make(chan func()),
+ rw: w,
+ req: r,
+ closedCh: make(chan struct{}),
+ writes: make(chan func()),
+ contentType: contentType,
+ contentSubtype: contentSubtype,
+ stats: stats,
}
if v := r.Header.Get("grpc-timeout"); v != "" {
to, err := decodeTimeout(v)
if err != nil {
- return nil, streamErrorf(codes.Internal, "malformed time-out: %v", err)
+ return nil, status.Errorf(codes.Internal, "malformed time-out: %v", err)
}
st.timeoutSet = true
st.timeout = to
}
- var metakv []string
+ metakv := []string{"content-type", contentType}
if r.Host != "" {
metakv = append(metakv, ":authority", r.Host)
}
for k, vv := range r.Header {
k = strings.ToLower(k)
- if isReservedHeader(k) && !isWhitelistedPseudoHeader(k) {
+ if isReservedHeader(k) && !isWhitelistedHeader(k) {
continue
}
for _, v := range vv {
v, err := decodeMetadataHeader(k, v)
if err != nil {
- return nil, streamErrorf(codes.InvalidArgument, "malformed binary metadata: %v", err)
+ return nil, status.Errorf(codes.Internal, "malformed binary metadata: %v", err)
}
metakv = append(metakv, k, v)
}
// ServeHTTP (HandleStreams) goroutine. The channel is closed
// when WriteStatus is called.
writes chan func()
+
+ // block concurrent WriteStatus calls
+ // e.g. grpc/(*serverStream).SendMsg/RecvMsg
+ writeStatusMu sync.Mutex
+
+ // we just mirror the request content-type
+ contentType string
+ // we store both contentType and contentSubtype so we don't keep recreating them
+ // TODO make sure this is consistent across handler_server and http2_server
+ contentSubtype string
+
+ stats stats.Handler
}
func (ht *serverHandlerTransport) Close() error {
case <-ht.closedCh:
return ErrConnClosing
}
-
}
}
func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) error {
+ ht.writeStatusMu.Lock()
+ defer ht.writeStatusMu.Unlock()
+
err := ht.do(func() {
ht.writeCommonHeaders(s)
h.Set("Grpc-Message", encodeGrpcMessage(m))
}
- // TODO: Support Grpc-Status-Details-Bin
+ if p := st.Proto(); p != nil && len(p.Details) > 0 {
+ stBytes, err := proto.Marshal(p)
+ if err != nil {
+ // TODO: return error instead, when callers are able to handle it.
+ panic(err)
+ }
+
+ h.Set("Grpc-Status-Details-Bin", encodeBinHeader(stBytes))
+ }
if md := s.Trailer(); len(md) > 0 {
for k, vv := range md {
}
}
})
- close(ht.writes)
+
+ if err == nil { // transport has not been closed
+ if ht.stats != nil {
+ ht.stats.HandleRPC(s.Context(), &stats.OutTrailer{})
+ }
+ close(ht.writes)
+ }
+ ht.Close()
return err
}
h := ht.rw.Header()
h["Date"] = nil // suppress Date to make tests happy; TODO: restore
- h.Set("Content-Type", "application/grpc")
+ h.Set("Content-Type", ht.contentType)
// Predeclare trailers we'll set later in WriteStatus (after the body).
// This is a SHOULD in the HTTP RFC, and the way you add (known)
// and https://golang.org/pkg/net/http/#example_ResponseWriter_trailers
h.Add("Trailer", "Grpc-Status")
h.Add("Trailer", "Grpc-Message")
- // TODO: Support Grpc-Status-Details-Bin
+ h.Add("Trailer", "Grpc-Status-Details-Bin")
if s.sendCompress != "" {
h.Set("Grpc-Encoding", s.sendCompress)
}
}
-func (ht *serverHandlerTransport) Write(s *Stream, data []byte, opts *Options) error {
+func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
return ht.do(func() {
ht.writeCommonHeaders(s)
+ ht.rw.Write(hdr)
ht.rw.Write(data)
- if !opts.Delay {
- ht.rw.(http.Flusher).Flush()
- }
+ ht.rw.(http.Flusher).Flush()
})
}
func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
- return ht.do(func() {
+ err := ht.do(func() {
ht.writeCommonHeaders(s)
h := ht.rw.Header()
for k, vv := range md {
ht.rw.WriteHeader(200)
ht.rw.(http.Flusher).Flush()
})
+
+ if err == nil {
+ if ht.stats != nil {
+ ht.stats.HandleRPC(s.Context(), &stats.OutHeader{})
+ }
+ }
+ return err
}
func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) {
// With this transport type there will be exactly 1 stream: this HTTP request.
- var ctx context.Context
+ ctx := ht.req.Context()
var cancel context.CancelFunc
if ht.timeoutSet {
- ctx, cancel = context.WithTimeout(context.Background(), ht.timeout)
+ ctx, cancel = context.WithTimeout(ctx, ht.timeout)
} else {
- ctx, cancel = context.WithCancel(context.Background())
+ ctx, cancel = context.WithCancel(ctx)
}
// requestOver is closed when either the request's context is done
go func() {
select {
case <-requestOver:
- return
case <-ht.closedCh:
case <-clientGone:
}
cancel()
+ ht.Close()
}()
req := ht.req
s := &Stream{
- id: 0, // irrelevant
- requestRead: func(int) {},
- cancel: cancel,
- buf: newRecvBuffer(),
- st: ht,
- method: req.URL.Path,
- recvCompress: req.Header.Get("grpc-encoding"),
+ id: 0, // irrelevant
+ requestRead: func(int) {},
+ cancel: cancel,
+ buf: newRecvBuffer(),
+ st: ht,
+ method: req.URL.Path,
+ recvCompress: req.Header.Get("grpc-encoding"),
+ contentSubtype: ht.contentSubtype,
}
pr := &peer.Peer{
Addr: ht.RemoteAddr(),
pr.AuthInfo = credentials.TLSInfo{State: *req.TLS}
}
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
- ctx = peer.NewContext(ctx, pr)
- s.ctx = newContextWithStream(ctx, s)
+ s.ctx = peer.NewContext(ctx, pr)
+ if ht.stats != nil {
+ s.ctx = ht.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
+ inHeader := &stats.InHeader{
+ FullMethod: s.method,
+ RemoteAddr: ht.RemoteAddr(),
+ Compression: s.recvCompress,
+ }
+ ht.stats.HandleRPC(s.ctx, inHeader)
+ }
s.trReader = &transportReader{
- reader: &recvBufferReader{ctx: s.ctx, recv: s.buf},
+ reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf},
windowHandler: func(int) {},
}
}
}
+func (ht *serverHandlerTransport) IncrMsgSent() {}
+
+func (ht *serverHandlerTransport) IncrMsgRecv() {}
+
func (ht *serverHandlerTransport) Drain() {
panic("Drain() is not implemented")
}
// * io.EOF
// * io.ErrUnexpectedEOF
// * of type transport.ConnectionError
-// * of type transport.StreamError
+// * an error from the status package
func mapRecvMsgError(err error) error {
if err == io.EOF || err == io.ErrUnexpectedEOF {
return err
}
if se, ok := err.(http2.StreamError); ok {
if code, ok := http2ErrConvTab[se.Code]; ok {
- return StreamError{
- Code: code,
- Desc: se.Error(),
- }
+ return status.Error(code, se.Error())
}
}
+ if strings.Contains(err.Error(), "body closed by handler") {
+ return status.Error(codes.Canceled, err.Error())
+ }
return connectionErrorf(true, err, err.Error())
}