Diffstat (limited to 'vendor/google.golang.org/grpc/server.go')
-rw-r--r-- | vendor/google.golang.org/grpc/server.go | 1159 |
1 file changed, 1159 insertions, 0 deletions
diff --git a/vendor/google.golang.org/grpc/server.go b/vendor/google.golang.org/grpc/server.go
new file mode 100644
index 0000000..42733e2
--- /dev/null
+++ b/vendor/google.golang.org/grpc/server.go
@@ -0,0 +1,1159 @@
1 | /* | ||
2 | * | ||
3 | * Copyright 2014 gRPC authors. | ||
4 | * | ||
5 | * Licensed under the Apache License, Version 2.0 (the "License"); | ||
6 | * you may not use this file except in compliance with the License. | ||
7 | * You may obtain a copy of the License at | ||
8 | * | ||
9 | * http://www.apache.org/licenses/LICENSE-2.0 | ||
10 | * | ||
11 | * Unless required by applicable law or agreed to in writing, software | ||
12 | * distributed under the License is distributed on an "AS IS" BASIS, | ||
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
14 | * See the License for the specific language governing permissions and | ||
15 | * limitations under the License. | ||
16 | * | ||
17 | */ | ||
18 | |||
19 | package grpc | ||
20 | |||
21 | import ( | ||
22 | "bytes" | ||
23 | "errors" | ||
24 | "fmt" | ||
25 | "io" | ||
26 | "net" | ||
27 | "net/http" | ||
28 | "reflect" | ||
29 | "runtime" | ||
30 | "strings" | ||
31 | "sync" | ||
32 | "time" | ||
33 | |||
34 | "golang.org/x/net/context" | ||
35 | "golang.org/x/net/http2" | ||
36 | "golang.org/x/net/trace" | ||
37 | "google.golang.org/grpc/codes" | ||
38 | "google.golang.org/grpc/credentials" | ||
39 | "google.golang.org/grpc/grpclog" | ||
40 | "google.golang.org/grpc/internal" | ||
41 | "google.golang.org/grpc/keepalive" | ||
42 | "google.golang.org/grpc/metadata" | ||
43 | "google.golang.org/grpc/stats" | ||
44 | "google.golang.org/grpc/status" | ||
45 | "google.golang.org/grpc/tap" | ||
46 | "google.golang.org/grpc/transport" | ||
47 | ) | ||
48 | |||
49 | const ( | ||
50 | defaultServerMaxReceiveMessageSize = 1024 * 1024 * 4 | ||
51 | defaultServerMaxSendMessageSize = 1024 * 1024 * 4 | ||
52 | ) | ||
53 | |||
54 | type methodHandler func(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor UnaryServerInterceptor) (interface{}, error) | ||
55 | |||
56 | // MethodDesc represents an RPC service's method specification. | ||
57 | type MethodDesc struct { | ||
58 | MethodName string | ||
59 | Handler methodHandler | ||
60 | } | ||
61 | |||
62 | // ServiceDesc represents an RPC service's specification. | ||
63 | type ServiceDesc struct { | ||
64 | ServiceName string | ||
65 | // The pointer to the service interface. Used to check whether the user | ||
66 | // provided implementation satisfies the interface requirements. | ||
67 | HandlerType interface{} | ||
68 | Methods []MethodDesc | ||
69 | Streams []StreamDesc | ||
70 | Metadata interface{} | ||
71 | } | ||
72 | |||
73 | // service consists of the server instance serving this service and | ||
74 | // information about the methods in this service. | ||
75 | type service struct { | ||
76 | server interface{} // the server for service methods | ||
77 | md map[string]*MethodDesc | ||
78 | sd map[string]*StreamDesc | ||
79 | mdata interface{} | ||
80 | } | ||
81 | |||
82 | // Server is a gRPC server to serve RPC requests. | ||
83 | type Server struct { | ||
84 | opts options | ||
85 | |||
86 | mu sync.Mutex // guards following | ||
87 | lis map[net.Listener]bool | ||
88 | conns map[io.Closer]bool | ||
89 | serve bool | ||
90 | drain bool | ||
91 | ctx context.Context | ||
92 | cancel context.CancelFunc | ||
93 | // A CondVar to let GracefulStop() block until all the pending RPCs are finished | ||
94 | // and all the transports go away. | ||
95 | cv *sync.Cond | ||
96 | m map[string]*service // service name -> service info | ||
97 | events trace.EventLog | ||
98 | } | ||
99 | |||
100 | type options struct { | ||
101 | creds credentials.TransportCredentials | ||
102 | codec Codec | ||
103 | cp Compressor | ||
104 | dc Decompressor | ||
105 | unaryInt UnaryServerInterceptor | ||
106 | streamInt StreamServerInterceptor | ||
107 | inTapHandle tap.ServerInHandle | ||
108 | statsHandler stats.Handler | ||
109 | maxConcurrentStreams uint32 | ||
110 | maxReceiveMessageSize int | ||
111 | maxSendMessageSize int | ||
112 | useHandlerImpl bool // use http.Handler-based server | ||
113 | unknownStreamDesc *StreamDesc | ||
114 | keepaliveParams keepalive.ServerParameters | ||
115 | keepalivePolicy keepalive.EnforcementPolicy | ||
116 | initialWindowSize int32 | ||
117 | initialConnWindowSize int32 | ||
118 | } | ||
119 | |||
120 | var defaultServerOptions = options{ | ||
121 | maxReceiveMessageSize: defaultServerMaxReceiveMessageSize, | ||
122 | maxSendMessageSize: defaultServerMaxSendMessageSize, | ||
123 | } | ||
124 | |||
125 | // A ServerOption sets options such as credentials, codec and keepalive parameters, etc. | ||
126 | type ServerOption func(*options) | ||
127 | |||
128 | // InitialWindowSize returns a ServerOption that sets the window size for a stream. | ||
129 | // The lower bound for window size is 64K and any value smaller than that will be ignored. | ||
130 | func InitialWindowSize(s int32) ServerOption { | ||
131 | return func(o *options) { | ||
132 | o.initialWindowSize = s | ||
133 | } | ||
134 | } | ||
135 | |||
136 | // InitialConnWindowSize returns a ServerOption that sets the window size for a connection. | ||
137 | // The lower bound for window size is 64K and any value smaller than that will be ignored. | ||
138 | func InitialConnWindowSize(s int32) ServerOption { | ||
139 | return func(o *options) { | ||
140 | o.initialConnWindowSize = s | ||
141 | } | ||
142 | } | ||
143 | |||
144 | // KeepaliveParams returns a ServerOption that sets keepalive and max-age parameters for the server. | ||
145 | func KeepaliveParams(kp keepalive.ServerParameters) ServerOption { | ||
146 | return func(o *options) { | ||
147 | o.keepaliveParams = kp | ||
148 | } | ||
149 | } | ||
150 | |||
151 | // KeepaliveEnforcementPolicy returns a ServerOption that sets keepalive enforcement policy for the server. | ||
152 | func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption { | ||
153 | return func(o *options) { | ||
154 | o.keepalivePolicy = kep | ||
155 | } | ||
156 | } | ||
157 | |||
158 | // CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling. | ||
159 | func CustomCodec(codec Codec) ServerOption { | ||
160 | return func(o *options) { | ||
161 | o.codec = codec | ||
162 | } | ||
163 | } | ||
164 | |||
165 | // RPCCompressor returns a ServerOption that sets a compressor for outbound messages. | ||
166 | func RPCCompressor(cp Compressor) ServerOption { | ||
167 | return func(o *options) { | ||
168 | o.cp = cp | ||
169 | } | ||
170 | } | ||
171 | |||
172 | // RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages. | ||
173 | func RPCDecompressor(dc Decompressor) ServerOption { | ||
174 | return func(o *options) { | ||
175 | o.dc = dc | ||
176 | } | ||
177 | } | ||
178 | |||
179 | // MaxMsgSize returns a ServerOption to set the max message size in bytes the server can receive. | ||
180 | // If this is not set, gRPC uses the default limit. Deprecated: use MaxRecvMsgSize instead. | ||
181 | func MaxMsgSize(m int) ServerOption { | ||
182 | return MaxRecvMsgSize(m) | ||
183 | } | ||
184 | |||
185 | // MaxRecvMsgSize returns a ServerOption to set the max message size in bytes the server can receive. | ||
186 | // If this is not set, gRPC uses the default 4MB. | ||
187 | func MaxRecvMsgSize(m int) ServerOption { | ||
188 | return func(o *options) { | ||
189 | o.maxReceiveMessageSize = m | ||
190 | } | ||
191 | } | ||
192 | |||
193 | // MaxSendMsgSize returns a ServerOption to set the max message size in bytes the server can send. | ||
194 | // If this is not set, gRPC uses the default 4MB. | ||
195 | func MaxSendMsgSize(m int) ServerOption { | ||
196 | return func(o *options) { | ||
197 | o.maxSendMessageSize = m | ||
198 | } | ||
199 | } | ||
200 | |||
201 | // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number | ||
202 | // of concurrent streams to each ServerTransport. | ||
203 | func MaxConcurrentStreams(n uint32) ServerOption { | ||
204 | return func(o *options) { | ||
205 | o.maxConcurrentStreams = n | ||
206 | } | ||
207 | } | ||
208 | |||
209 | // Creds returns a ServerOption that sets credentials for server connections. | ||
210 | func Creds(c credentials.TransportCredentials) ServerOption { | ||
211 | return func(o *options) { | ||
212 | o.creds = c | ||
213 | } | ||
214 | } | ||
215 | |||
216 | // UnaryInterceptor returns a ServerOption that sets the UnaryServerInterceptor for the | ||
217 | // server. Only one unary interceptor can be installed. The construction of multiple | ||
218 | // interceptors (e.g., chaining) can be implemented at the caller. | ||
219 | func UnaryInterceptor(i UnaryServerInterceptor) ServerOption { | ||
220 | return func(o *options) { | ||
221 | if o.unaryInt != nil { | ||
222 | panic("The unary server interceptor was already set and may not be reset.") | ||
223 | } | ||
224 | o.unaryInt = i | ||
225 | } | ||
226 | } | ||
227 | |||
228 | // StreamInterceptor returns a ServerOption that sets the StreamServerInterceptor for the | ||
229 | // server. Only one stream interceptor can be installed. | ||
230 | func StreamInterceptor(i StreamServerInterceptor) ServerOption { | ||
231 | return func(o *options) { | ||
232 | if o.streamInt != nil { | ||
233 | panic("The stream server interceptor was already set and may not be reset.") | ||
234 | } | ||
235 | o.streamInt = i | ||
236 | } | ||
237 | } | ||
238 | |||
239 | // InTapHandle returns a ServerOption that sets the tap handle for all the server | ||
240 | // transports to be created. Only one can be installed. | ||
241 | func InTapHandle(h tap.ServerInHandle) ServerOption { | ||
242 | return func(o *options) { | ||
243 | if o.inTapHandle != nil { | ||
244 | panic("The tap handle was already set and may not be reset.") | ||
245 | } | ||
246 | o.inTapHandle = h | ||
247 | } | ||
248 | } | ||
249 | |||
250 | // StatsHandler returns a ServerOption that sets the stats handler for the server. | ||
251 | func StatsHandler(h stats.Handler) ServerOption { | ||
252 | return func(o *options) { | ||
253 | o.statsHandler = h | ||
254 | } | ||
255 | } | ||
256 | |||
257 | // UnknownServiceHandler returns a ServerOption that allows for adding a custom | ||
258 | // unknown service handler. The provided method is a bidi-streaming RPC service | ||
259 | // handler that will be invoked instead of returning the "unimplemented" gRPC | ||
260 | // error whenever a request is received for an unregistered service or method. | ||
261 | // The handling function has full access to the Context of the request and the | ||
262 | // stream, and the invocation passes through interceptors. | ||
263 | func UnknownServiceHandler(streamHandler StreamHandler) ServerOption { | ||
264 | return func(o *options) { | ||
265 | o.unknownStreamDesc = &StreamDesc{ | ||
266 | StreamName: "unknown_service_handler", | ||
267 | Handler: streamHandler, | ||
268 | // We need to assume that the users of the streamHandler will want to use both. | ||
269 | ClientStreams: true, | ||
270 | ServerStreams: true, | ||
271 | } | ||
272 | } | ||
273 | } | ||
274 | |||
275 | // NewServer creates a gRPC server which has no service registered and has not | ||
276 | // started to accept requests yet. | ||
277 | func NewServer(opt ...ServerOption) *Server { | ||
278 | opts := defaultServerOptions | ||
279 | for _, o := range opt { | ||
280 | o(&opts) | ||
281 | } | ||
282 | if opts.codec == nil { | ||
283 | // Set the default codec. | ||
284 | opts.codec = protoCodec{} | ||
285 | } | ||
286 | s := &Server{ | ||
287 | lis: make(map[net.Listener]bool), | ||
288 | opts: opts, | ||
289 | conns: make(map[io.Closer]bool), | ||
290 | m: make(map[string]*service), | ||
291 | } | ||
292 | s.cv = sync.NewCond(&s.mu) | ||
293 | s.ctx, s.cancel = context.WithCancel(context.Background()) | ||
294 | if EnableTracing { | ||
295 | _, file, line, _ := runtime.Caller(1) | ||
296 | s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line)) | ||
297 | } | ||
298 | return s | ||
299 | } | ||
300 | |||
301 | // printf records an event in s's event log, unless s has been stopped. | ||
302 | // REQUIRES s.mu is held. | ||
303 | func (s *Server) printf(format string, a ...interface{}) { | ||
304 | if s.events != nil { | ||
305 | s.events.Printf(format, a...) | ||
306 | } | ||
307 | } | ||
308 | |||
309 | // errorf records an error in s's event log, unless s has been stopped. | ||
310 | // REQUIRES s.mu is held. | ||
311 | func (s *Server) errorf(format string, a ...interface{}) { | ||
312 | if s.events != nil { | ||
313 | s.events.Errorf(format, a...) | ||
314 | } | ||
315 | } | ||
316 | |||
317 | // RegisterService registers a service and its implementation to the gRPC | ||
318 | // server. It is called from the IDL-generated code. This must be called before | ||
319 | // invoking Serve. | ||
320 | func (s *Server) RegisterService(sd *ServiceDesc, ss interface{}) { | ||
321 | ht := reflect.TypeOf(sd.HandlerType).Elem() | ||
322 | st := reflect.TypeOf(ss) | ||
323 | if !st.Implements(ht) { | ||
324 | grpclog.Fatalf("grpc: Server.RegisterService found the handler of type %v that does not satisfy %v", st, ht) | ||
325 | } | ||
326 | s.register(sd, ss) | ||
327 | } | ||
328 | |||
329 | func (s *Server) register(sd *ServiceDesc, ss interface{}) { | ||
330 | s.mu.Lock() | ||
331 | defer s.mu.Unlock() | ||
332 | s.printf("RegisterService(%q)", sd.ServiceName) | ||
333 | if s.serve { | ||
334 | grpclog.Fatalf("grpc: Server.RegisterService after Server.Serve for %q", sd.ServiceName) | ||
335 | } | ||
336 | if _, ok := s.m[sd.ServiceName]; ok { | ||
337 | grpclog.Fatalf("grpc: Server.RegisterService found duplicate service registration for %q", sd.ServiceName) | ||
338 | } | ||
339 | srv := &service{ | ||
340 | server: ss, | ||
341 | md: make(map[string]*MethodDesc), | ||
342 | sd: make(map[string]*StreamDesc), | ||
343 | mdata: sd.Metadata, | ||
344 | } | ||
345 | for i := range sd.Methods { | ||
346 | d := &sd.Methods[i] | ||
347 | srv.md[d.MethodName] = d | ||
348 | } | ||
349 | for i := range sd.Streams { | ||
350 | d := &sd.Streams[i] | ||
351 | srv.sd[d.StreamName] = d | ||
352 | } | ||
353 | s.m[sd.ServiceName] = srv | ||
354 | } | ||
355 | |||
356 | // MethodInfo contains information about an RPC, including its method name and type. | ||
357 | type MethodInfo struct { | ||
358 | // Name is the method name only, without the service name or package name. | ||
359 | Name string | ||
360 | // IsClientStream indicates whether the RPC is a client streaming RPC. | ||
361 | IsClientStream bool | ||
362 | // IsServerStream indicates whether the RPC is a server streaming RPC. | ||
363 | IsServerStream bool | ||
364 | } | ||
365 | |||
366 | // ServiceInfo contains unary RPC method info, streaming RPC method info and metadata for a service. | ||
367 | type ServiceInfo struct { | ||
368 | Methods []MethodInfo | ||
369 | // Metadata is the metadata specified in ServiceDesc when registering the service. | ||
370 | Metadata interface{} | ||
371 | } | ||
372 | |||
373 | // GetServiceInfo returns a map from service names to ServiceInfo. | ||
374 | // Service names include the package names, in the form of <package>.<service>. | ||
375 | func (s *Server) GetServiceInfo() map[string]ServiceInfo { | ||
376 | ret := make(map[string]ServiceInfo) | ||
377 | for n, srv := range s.m { | ||
378 | methods := make([]MethodInfo, 0, len(srv.md)+len(srv.sd)) | ||
379 | for m := range srv.md { | ||
380 | methods = append(methods, MethodInfo{ | ||
381 | Name: m, | ||
382 | IsClientStream: false, | ||
383 | IsServerStream: false, | ||
384 | }) | ||
385 | } | ||
386 | for m, d := range srv.sd { | ||
387 | methods = append(methods, MethodInfo{ | ||
388 | Name: m, | ||
389 | IsClientStream: d.ClientStreams, | ||
390 | IsServerStream: d.ServerStreams, | ||
391 | }) | ||
392 | } | ||
393 | |||
394 | ret[n] = ServiceInfo{ | ||
395 | Methods: methods, | ||
396 | Metadata: srv.mdata, | ||
397 | } | ||
398 | } | ||
399 | return ret | ||
400 | } | ||
401 | |||
402 | var ( | ||
403 | // ErrServerStopped indicates that the operation is now illegal because of | ||
404 | // the server being stopped. | ||
405 | ErrServerStopped = errors.New("grpc: the server has been stopped") | ||
406 | ) | ||
407 | |||
408 | func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { | ||
409 | if s.opts.creds == nil { | ||
410 | return rawConn, nil, nil | ||
411 | } | ||
412 | return s.opts.creds.ServerHandshake(rawConn) | ||
413 | } | ||
414 | |||
415 | // Serve accepts incoming connections on the listener lis, creating a new | ||
416 | // ServerTransport and service goroutine for each. The service goroutines | ||
417 | // read gRPC requests and then call the registered handlers to reply to them. | ||
418 | // Serve returns when lis.Accept fails with fatal errors. lis will be closed when | ||
419 | // this method returns. | ||
420 | // Serve always returns a non-nil error. | ||
421 | func (s *Server) Serve(lis net.Listener) error { | ||
422 | s.mu.Lock() | ||
423 | s.printf("serving") | ||
424 | s.serve = true | ||
425 | if s.lis == nil { | ||
426 | s.mu.Unlock() | ||
427 | lis.Close() | ||
428 | return ErrServerStopped | ||
429 | } | ||
430 | s.lis[lis] = true | ||
431 | s.mu.Unlock() | ||
432 | defer func() { | ||
433 | s.mu.Lock() | ||
434 | if s.lis != nil && s.lis[lis] { | ||
435 | lis.Close() | ||
436 | delete(s.lis, lis) | ||
437 | } | ||
438 | s.mu.Unlock() | ||
439 | }() | ||
440 | |||
441 | var tempDelay time.Duration // how long to sleep on accept failure | ||
442 | |||
443 | for { | ||
444 | rawConn, err := lis.Accept() | ||
445 | if err != nil { | ||
446 | if ne, ok := err.(interface { | ||
447 | Temporary() bool | ||
448 | }); ok && ne.Temporary() { | ||
449 | if tempDelay == 0 { | ||
450 | tempDelay = 5 * time.Millisecond | ||
451 | } else { | ||
452 | tempDelay *= 2 | ||
453 | } | ||
454 | if max := 1 * time.Second; tempDelay > max { | ||
455 | tempDelay = max | ||
456 | } | ||
457 | s.mu.Lock() | ||
458 | s.printf("Accept error: %v; retrying in %v", err, tempDelay) | ||
459 | s.mu.Unlock() | ||
460 | timer := time.NewTimer(tempDelay) | ||
461 | select { | ||
462 | case <-timer.C: | ||
463 | case <-s.ctx.Done(): | ||
464 | } | ||
465 | timer.Stop() | ||
466 | continue | ||
467 | } | ||
468 | s.mu.Lock() | ||
469 | s.printf("done serving; Accept = %v", err) | ||
470 | s.mu.Unlock() | ||
471 | return err | ||
472 | } | ||
473 | tempDelay = 0 | ||
474 | // Start a new goroutine to deal with rawConn | ||
475 | // so we don't stall this Accept loop goroutine. | ||
476 | go s.handleRawConn(rawConn) | ||
477 | } | ||
478 | } | ||
479 | |||
480 | // handleRawConn is run in its own goroutine and handles a just-accepted | ||
481 | // connection that has not had any I/O performed on it yet. | ||
482 | func (s *Server) handleRawConn(rawConn net.Conn) { | ||
483 | conn, authInfo, err := s.useTransportAuthenticator(rawConn) | ||
484 | if err != nil { | ||
485 | s.mu.Lock() | ||
486 | s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err) | ||
487 | s.mu.Unlock() | ||
488 | grpclog.Warningf("grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err) | ||
489 | // If ServerHandshake returns ErrConnDispatched, keep rawConn open. | ||
490 | if err != credentials.ErrConnDispatched { | ||
491 | rawConn.Close() | ||
492 | } | ||
493 | return | ||
494 | } | ||
495 | |||
496 | s.mu.Lock() | ||
497 | if s.conns == nil { | ||
498 | s.mu.Unlock() | ||
499 | conn.Close() | ||
500 | return | ||
501 | } | ||
502 | s.mu.Unlock() | ||
503 | |||
504 | if s.opts.useHandlerImpl { | ||
505 | s.serveUsingHandler(conn) | ||
506 | } else { | ||
507 | s.serveHTTP2Transport(conn, authInfo) | ||
508 | } | ||
509 | } | ||
510 | |||
511 | // serveHTTP2Transport sets up an HTTP/2 transport (using the | ||
512 | // gRPC http2 server transport in transport/http2_server.go) and | ||
513 | // serves streams on it. | ||
514 | // This is run in its own goroutine (it does network I/O in | ||
515 | // transport.NewServerTransport). | ||
516 | func (s *Server) serveHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) { | ||
517 | config := &transport.ServerConfig{ | ||
518 | MaxStreams: s.opts.maxConcurrentStreams, | ||
519 | AuthInfo: authInfo, | ||
520 | InTapHandle: s.opts.inTapHandle, | ||
521 | StatsHandler: s.opts.statsHandler, | ||
522 | KeepaliveParams: s.opts.keepaliveParams, | ||
523 | KeepalivePolicy: s.opts.keepalivePolicy, | ||
524 | InitialWindowSize: s.opts.initialWindowSize, | ||
525 | InitialConnWindowSize: s.opts.initialConnWindowSize, | ||
526 | } | ||
527 | st, err := transport.NewServerTransport("http2", c, config) | ||
528 | if err != nil { | ||
529 | s.mu.Lock() | ||
530 | s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err) | ||
531 | s.mu.Unlock() | ||
532 | c.Close() | ||
533 | grpclog.Warningln("grpc: Server.Serve failed to create ServerTransport: ", err) | ||
534 | return | ||
535 | } | ||
536 | if !s.addConn(st) { | ||
537 | st.Close() | ||
538 | return | ||
539 | } | ||
540 | s.serveStreams(st) | ||
541 | } | ||
542 | |||
543 | func (s *Server) serveStreams(st transport.ServerTransport) { | ||
544 | defer s.removeConn(st) | ||
545 | defer st.Close() | ||
546 | var wg sync.WaitGroup | ||
547 | st.HandleStreams(func(stream *transport.Stream) { | ||
548 | wg.Add(1) | ||
549 | go func() { | ||
550 | defer wg.Done() | ||
551 | s.handleStream(st, stream, s.traceInfo(st, stream)) | ||
552 | }() | ||
553 | }, func(ctx context.Context, method string) context.Context { | ||
554 | if !EnableTracing { | ||
555 | return ctx | ||
556 | } | ||
557 | tr := trace.New("grpc.Recv."+methodFamily(method), method) | ||
558 | return trace.NewContext(ctx, tr) | ||
559 | }) | ||
560 | wg.Wait() | ||
561 | } | ||
562 | |||
563 | var _ http.Handler = (*Server)(nil) | ||
564 | |||
565 | // serveUsingHandler is called from handleRawConn when s is configured | ||
566 | // to handle requests via the http.Handler interface. It sets up a | ||
567 | // net/http.Server to handle the just-accepted conn. The http.Server | ||
568 | // is configured to route all incoming requests (all HTTP/2 streams) | ||
569 | // to ServeHTTP, which creates a new ServerTransport for each stream. | ||
570 | // serveUsingHandler blocks until conn closes. | ||
571 | // | ||
572 | // This codepath is only used when Server.TestingUseHandlerImpl has | ||
573 | // been configured. This lets the end2end tests exercise the ServeHTTP | ||
574 | // method as one of the environment types. | ||
575 | // | ||
576 | // conn is the *tls.Conn that's already been authenticated. | ||
577 | func (s *Server) serveUsingHandler(conn net.Conn) { | ||
578 | if !s.addConn(conn) { | ||
579 | conn.Close() | ||
580 | return | ||
581 | } | ||
582 | defer s.removeConn(conn) | ||
583 | h2s := &http2.Server{ | ||
584 | MaxConcurrentStreams: s.opts.maxConcurrentStreams, | ||
585 | } | ||
586 | h2s.ServeConn(conn, &http2.ServeConnOpts{ | ||
587 | Handler: s, | ||
588 | }) | ||
589 | } | ||
590 | |||
591 | // ServeHTTP implements the Go standard library's http.Handler | ||
592 | // interface by responding to the gRPC request r, by looking up | ||
593 | // the requested gRPC method in the gRPC server s. | ||
594 | // | ||
595 | // The provided HTTP request must have arrived on an HTTP/2 | ||
596 | // connection. When using the Go standard library's server, | ||
597 | // practically this means that the Request must also have arrived | ||
598 | // over TLS. | ||
599 | // | ||
600 | // To share one port (such as 443 for https) between gRPC and an | ||
601 | // existing http.Handler, use a root http.Handler such as: | ||
602 | // | ||
603 | // if r.ProtoMajor == 2 && strings.HasPrefix( | ||
604 | // r.Header.Get("Content-Type"), "application/grpc") { | ||
605 | // grpcServer.ServeHTTP(w, r) | ||
606 | // } else { | ||
607 | // yourMux.ServeHTTP(w, r) | ||
608 | // } | ||
609 | // | ||
610 | // Note that ServeHTTP uses Go's HTTP/2 server implementation which is totally | ||
611 | // separate from grpc-go's HTTP/2 server. Performance and features may vary | ||
612 | // between the two paths. ServeHTTP does not support some gRPC features | ||
613 | // available through grpc-go's HTTP/2 server, and it is currently EXPERIMENTAL | ||
614 | // and subject to change. | ||
615 | func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { | ||
616 | st, err := transport.NewServerHandlerTransport(w, r) | ||
617 | if err != nil { | ||
618 | http.Error(w, err.Error(), http.StatusInternalServerError) | ||
619 | return | ||
620 | } | ||
621 | if !s.addConn(st) { | ||
622 | st.Close() | ||
623 | return | ||
624 | } | ||
625 | defer s.removeConn(st) | ||
626 | s.serveStreams(st) | ||
627 | } | ||
628 | |||
629 | // traceInfo returns a traceInfo and associates it with stream, if tracing is enabled. | ||
630 | // If tracing is not enabled, it returns nil. | ||
631 | func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Stream) (trInfo *traceInfo) { | ||
632 | tr, ok := trace.FromContext(stream.Context()) | ||
633 | if !ok { | ||
634 | return nil | ||
635 | } | ||
636 | |||
637 | trInfo = &traceInfo{ | ||
638 | tr: tr, | ||
639 | } | ||
640 | trInfo.firstLine.client = false | ||
641 | trInfo.firstLine.remoteAddr = st.RemoteAddr() | ||
642 | |||
643 | if dl, ok := stream.Context().Deadline(); ok { | ||
644 | trInfo.firstLine.deadline = dl.Sub(time.Now()) | ||
645 | } | ||
646 | return trInfo | ||
647 | } | ||
648 | |||
649 | func (s *Server) addConn(c io.Closer) bool { | ||
650 | s.mu.Lock() | ||
651 | defer s.mu.Unlock() | ||
652 | if s.conns == nil || s.drain { | ||
653 | return false | ||
654 | } | ||
655 | s.conns[c] = true | ||
656 | return true | ||
657 | } | ||
658 | |||
659 | func (s *Server) removeConn(c io.Closer) { | ||
660 | s.mu.Lock() | ||
661 | defer s.mu.Unlock() | ||
662 | if s.conns != nil { | ||
663 | delete(s.conns, c) | ||
664 | s.cv.Broadcast() | ||
665 | } | ||
666 | } | ||
667 | |||
668 | func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options) error { | ||
669 | var ( | ||
670 | cbuf *bytes.Buffer | ||
671 | outPayload *stats.OutPayload | ||
672 | ) | ||
673 | if cp != nil { | ||
674 | cbuf = new(bytes.Buffer) | ||
675 | } | ||
676 | if s.opts.statsHandler != nil { | ||
677 | outPayload = &stats.OutPayload{} | ||
678 | } | ||
679 | p, err := encode(s.opts.codec, msg, cp, cbuf, outPayload) | ||
680 | if err != nil { | ||
681 | grpclog.Errorln("grpc: server failed to encode response: ", err) | ||
682 | return err | ||
683 | } | ||
684 | if len(p) > s.opts.maxSendMessageSize { | ||
685 | return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(p), s.opts.maxSendMessageSize) | ||
686 | } | ||
687 | err = t.Write(stream, p, opts) | ||
688 | if err == nil && outPayload != nil { | ||
689 | outPayload.SentTime = time.Now() | ||
690 | s.opts.statsHandler.HandleRPC(stream.Context(), outPayload) | ||
691 | } | ||
692 | return err | ||
693 | } | ||
694 | |||
695 | func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) { | ||
696 | sh := s.opts.statsHandler | ||
697 | if sh != nil { | ||
698 | begin := &stats.Begin{ | ||
699 | BeginTime: time.Now(), | ||
700 | } | ||
701 | sh.HandleRPC(stream.Context(), begin) | ||
702 | defer func() { | ||
703 | end := &stats.End{ | ||
704 | EndTime: time.Now(), | ||
705 | } | ||
706 | if err != nil && err != io.EOF { | ||
707 | end.Error = toRPCErr(err) | ||
708 | } | ||
709 | sh.HandleRPC(stream.Context(), end) | ||
710 | }() | ||
711 | } | ||
712 | if trInfo != nil { | ||
713 | defer trInfo.tr.Finish() | ||
714 | trInfo.firstLine.client = false | ||
715 | trInfo.tr.LazyLog(&trInfo.firstLine, false) | ||
716 | defer func() { | ||
717 | if err != nil && err != io.EOF { | ||
718 | trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) | ||
719 | trInfo.tr.SetError() | ||
720 | } | ||
721 | }() | ||
722 | } | ||
723 | if s.opts.cp != nil { | ||
724 | // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. | ||
725 | stream.SetSendCompress(s.opts.cp.Type()) | ||
726 | } | ||
727 | p := &parser{r: stream} | ||
728 | pf, req, err := p.recvMsg(s.opts.maxReceiveMessageSize) | ||
729 | if err == io.EOF { | ||
730 | // The entire stream is done (for unary RPC only). | ||
731 | return err | ||
732 | } | ||
733 | if err == io.ErrUnexpectedEOF { | ||
734 | err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) | ||
735 | } | ||
736 | if err != nil { | ||
737 | if st, ok := status.FromError(err); ok { | ||
738 | if e := t.WriteStatus(stream, st); e != nil { | ||
739 | grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e) | ||
740 | } | ||
741 | } else { | ||
742 | switch st := err.(type) { | ||
743 | case transport.ConnectionError: | ||
744 | // Nothing to do here. | ||
745 | case transport.StreamError: | ||
746 | if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil { | ||
747 | grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e) | ||
748 | } | ||
749 | default: | ||
750 | panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", st, st)) | ||
751 | } | ||
752 | } | ||
753 | return err | ||
754 | } | ||
755 | |||
756 | if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil { | ||
757 | if st, ok := status.FromError(err); ok { | ||
758 | if e := t.WriteStatus(stream, st); e != nil { | ||
759 | grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e) | ||
760 | } | ||
761 | return err | ||
762 | } | ||
763 | if e := t.WriteStatus(stream, status.New(codes.Internal, err.Error())); e != nil { | ||
764 | grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e) | ||
765 | } | ||
766 | |||
767 | // TODO: checkRecvPayload always returns an RPC error. Add a return here if necessary. | ||
768 | } | ||
769 | var inPayload *stats.InPayload | ||
770 | if sh != nil { | ||
771 | inPayload = &stats.InPayload{ | ||
772 | RecvTime: time.Now(), | ||
773 | } | ||
774 | } | ||
775 | df := func(v interface{}) error { | ||
776 | if inPayload != nil { | ||
777 | inPayload.WireLength = len(req) | ||
778 | } | ||
779 | if pf == compressionMade { | ||
780 | var err error | ||
781 | req, err = s.opts.dc.Do(bytes.NewReader(req)) | ||
782 | if err != nil { | ||
783 | return Errorf(codes.Internal, err.Error()) | ||
784 | } | ||
785 | } | ||
786 | if len(req) > s.opts.maxReceiveMessageSize { | ||
787 | // TODO: Revisit the error code. Currently kept consistent with the | ||
788 | // Java implementation. | ||
789 | return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(req), s.opts.maxReceiveMessageSize) | ||
790 | } | ||
791 | if err := s.opts.codec.Unmarshal(req, v); err != nil { | ||
792 | return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) | ||
793 | } | ||
794 | if inPayload != nil { | ||
795 | inPayload.Payload = v | ||
796 | inPayload.Data = req | ||
797 | inPayload.Length = len(req) | ||
798 | sh.HandleRPC(stream.Context(), inPayload) | ||
799 | } | ||
800 | if trInfo != nil { | ||
801 | trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true) | ||
802 | } | ||
803 | return nil | ||
804 | } | ||
805 | reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt) | ||
806 | if appErr != nil { | ||
807 | appStatus, ok := status.FromError(appErr) | ||
808 | if !ok { | ||
809 | // Convert appErr if it is not a grpc status error. | ||
810 | appErr = status.Error(convertCode(appErr), appErr.Error()) | ||
811 | appStatus, _ = status.FromError(appErr) | ||
812 | } | ||
813 | if trInfo != nil { | ||
814 | trInfo.tr.LazyLog(stringer(appStatus.Message()), true) | ||
815 | trInfo.tr.SetError() | ||
816 | } | ||
817 | if e := t.WriteStatus(stream, appStatus); e != nil { | ||
818 | grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status: %v", e) | ||
819 | } | ||
820 | return appErr | ||
821 | } | ||
822 | if trInfo != nil { | ||
823 | trInfo.tr.LazyLog(stringer("OK"), false) | ||
824 | } | ||
825 | opts := &transport.Options{ | ||
826 | Last: true, | ||
827 | Delay: false, | ||
828 | } | ||
829 | if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil { | ||
830 | if err == io.EOF { | ||
831 | // The entire stream is done (for unary RPC only). | ||
832 | return err | ||
833 | } | ||
834 | if s, ok := status.FromError(err); ok { | ||
835 | if e := t.WriteStatus(stream, s); e != nil { | ||
836 | grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status: %v", e) | ||
837 | } | ||
838 | } else { | ||
839 | switch st := err.(type) { | ||
840 | case transport.ConnectionError: | ||
841 | // Nothing to do here. | ||
842 | case transport.StreamError: | ||
843 | if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil { | ||
844 | grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e) | ||
845 | } | ||
846 | default: | ||
847 | panic(fmt.Sprintf("grpc: Unexpected error (%T) from sendResponse: %v", st, st)) | ||
848 | } | ||
849 | } | ||
850 | return err | ||
851 | } | ||
852 | if trInfo != nil { | ||
853 | trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true) | ||
854 | } | ||
855 | // TODO: Should we be logging if writing status failed here, like above? | ||
856 | // Should the logging be in WriteStatus? Should we ignore the WriteStatus | ||
857 | // error or allow the stats handler to see it? | ||
858 | return t.WriteStatus(stream, status.New(codes.OK, "")) | ||
859 | } | ||
860 | |||
861 | func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) { | ||
862 | sh := s.opts.statsHandler | ||
863 | if sh != nil { | ||
864 | begin := &stats.Begin{ | ||
865 | BeginTime: time.Now(), | ||
866 | } | ||
867 | sh.HandleRPC(stream.Context(), begin) | ||
868 | defer func() { | ||
869 | end := &stats.End{ | ||
870 | EndTime: time.Now(), | ||
871 | } | ||
872 | if err != nil && err != io.EOF { | ||
873 | end.Error = toRPCErr(err) | ||
874 | } | ||
875 | sh.HandleRPC(stream.Context(), end) | ||
876 | }() | ||
877 | } | ||
878 | if s.opts.cp != nil { | ||
879 | stream.SetSendCompress(s.opts.cp.Type()) | ||
880 | } | ||
881 | ss := &serverStream{ | ||
882 | t: t, | ||
883 | s: stream, | ||
884 | p: &parser{r: stream}, | ||
885 | codec: s.opts.codec, | ||
886 | cp: s.opts.cp, | ||
887 | dc: s.opts.dc, | ||
888 | maxReceiveMessageSize: s.opts.maxReceiveMessageSize, | ||
889 | maxSendMessageSize: s.opts.maxSendMessageSize, | ||
890 | trInfo: trInfo, | ||
891 | statsHandler: sh, | ||
892 | } | ||
893 | if ss.cp != nil { | ||
894 | ss.cbuf = new(bytes.Buffer) | ||
895 | } | ||
896 | if trInfo != nil { | ||
897 | trInfo.tr.LazyLog(&trInfo.firstLine, false) | ||
898 | defer func() { | ||
899 | ss.mu.Lock() | ||
900 | if err != nil && err != io.EOF { | ||
901 | ss.trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) | ||
902 | ss.trInfo.tr.SetError() | ||
903 | } | ||
904 | ss.trInfo.tr.Finish() | ||
905 | ss.trInfo.tr = nil | ||
906 | ss.mu.Unlock() | ||
907 | }() | ||
908 | } | ||
909 | var appErr error | ||
910 | var server interface{} | ||
911 | if srv != nil { | ||
912 | server = srv.server | ||
913 | } | ||
914 | if s.opts.streamInt == nil { | ||
915 | appErr = sd.Handler(server, ss) | ||
916 | } else { | ||
917 | info := &StreamServerInfo{ | ||
918 | FullMethod: stream.Method(), | ||
919 | IsClientStream: sd.ClientStreams, | ||
920 | IsServerStream: sd.ServerStreams, | ||
921 | } | ||
922 | appErr = s.opts.streamInt(server, ss, info, sd.Handler) | ||
923 | } | ||
924 | if appErr != nil { | ||
925 | appStatus, ok := status.FromError(appErr) | ||
926 | if !ok { | ||
927 | switch err := appErr.(type) { | ||
928 | case transport.StreamError: | ||
929 | appStatus = status.New(err.Code, err.Desc) | ||
930 | default: | ||
931 | appStatus = status.New(convertCode(appErr), appErr.Error()) | ||
932 | } | ||
933 | appErr = appStatus.Err() | ||
934 | } | ||
935 | if trInfo != nil { | ||
936 | ss.mu.Lock() | ||
937 | ss.trInfo.tr.LazyLog(stringer(appStatus.Message()), true) | ||
938 | ss.trInfo.tr.SetError() | ||
939 | ss.mu.Unlock() | ||
940 | } | ||
941 | t.WriteStatus(ss.s, appStatus) | ||
942 | // TODO: Should we log an error from WriteStatus here and below? | ||
943 | return appErr | ||
944 | } | ||
945 | if trInfo != nil { | ||
946 | ss.mu.Lock() | ||
947 | ss.trInfo.tr.LazyLog(stringer("OK"), false) | ||
948 | ss.mu.Unlock() | ||
949 | } | ||
950 | return t.WriteStatus(ss.s, status.New(codes.OK, "")) | ||
951 | |||
952 | } | ||
953 | |||
954 | func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) { | ||
955 | sm := stream.Method() | ||
956 | if sm != "" && sm[0] == '/' { | ||
957 | sm = sm[1:] | ||
958 | } | ||
959 | pos := strings.LastIndex(sm, "/") | ||
960 | if pos == -1 { | ||
961 | if trInfo != nil { | ||
962 | trInfo.tr.LazyLog(&fmtStringer{"Malformed method name %q", []interface{}{sm}}, true) | ||
963 | trInfo.tr.SetError() | ||
964 | } | ||
965 | errDesc := fmt.Sprintf("malformed method name: %q", stream.Method()) | ||
966 | if err := t.WriteStatus(stream, status.New(codes.ResourceExhausted, errDesc)); err != nil { | ||
967 | if trInfo != nil { | ||
968 | trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) | ||
969 | trInfo.tr.SetError() | ||
970 | } | ||
971 | grpclog.Warningf("grpc: Server.handleStream failed to write status: %v", err) | ||
972 | } | ||
973 | if trInfo != nil { | ||
974 | trInfo.tr.Finish() | ||
975 | } | ||
976 | return | ||
977 | } | ||
978 | service := sm[:pos] | ||
979 | method := sm[pos+1:] | ||
980 | srv, ok := s.m[service] | ||
981 | if !ok { | ||
982 | if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil { | ||
983 | s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo) | ||
984 | return | ||
985 | } | ||
986 | if trInfo != nil { | ||
987 | trInfo.tr.LazyLog(&fmtStringer{"Unknown service %v", []interface{}{service}}, true) | ||
988 | trInfo.tr.SetError() | ||
989 | } | ||
990 | errDesc := fmt.Sprintf("unknown service %v", service) | ||
991 | if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil { | ||
992 | if trInfo != nil { | ||
993 | trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) | ||
994 | trInfo.tr.SetError() | ||
995 | } | ||
996 | grpclog.Warningf("grpc: Server.handleStream failed to write status: %v", err) | ||
997 | } | ||
998 | if trInfo != nil { | ||
999 | trInfo.tr.Finish() | ||
1000 | } | ||
1001 | return | ||
1002 | } | ||
1003 | // Unary RPC or Streaming RPC? | ||
1004 | if md, ok := srv.md[method]; ok { | ||
1005 | s.processUnaryRPC(t, stream, srv, md, trInfo) | ||
1006 | return | ||
1007 | } | ||
1008 | if sd, ok := srv.sd[method]; ok { | ||
1009 | s.processStreamingRPC(t, stream, srv, sd, trInfo) | ||
1010 | return | ||
1011 | } | ||
1012 | if trInfo != nil { | ||
1013 | trInfo.tr.LazyLog(&fmtStringer{"Unknown method %v", []interface{}{method}}, true) | ||
1014 | trInfo.tr.SetError() | ||
1015 | } | ||
1016 | if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil { | ||
1017 | s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo) | ||
1018 | return | ||
1019 | } | ||
1020 | errDesc := fmt.Sprintf("unknown method %v", method) | ||
1021 | if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil { | ||
1022 | if trInfo != nil { | ||
1023 | trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) | ||
1024 | trInfo.tr.SetError() | ||
1025 | } | ||
1026 | grpclog.Warningf("grpc: Server.handleStream failed to write status: %v", err) | ||
1027 | } | ||
1028 | if trInfo != nil { | ||
1029 | trInfo.tr.Finish() | ||
1030 | } | ||
1031 | } | ||
1032 | |||
1033 | // Stop stops the gRPC server. It immediately closes all open | ||
1034 | // connections and listeners. | ||
1035 | // It cancels all active RPCs on the server side, and the corresponding | ||
1036 | // pending RPCs on the client side will be notified by connection | ||
1037 | // errors. | ||
1038 | func (s *Server) Stop() { | ||
1039 | s.mu.Lock() | ||
1040 | listeners := s.lis | ||
1041 | s.lis = nil | ||
1042 | st := s.conns | ||
1043 | s.conns = nil | ||
1044 | // interrupt GracefulStop if Stop and GracefulStop are called concurrently. | ||
1045 | s.cv.Broadcast() | ||
1046 | s.mu.Unlock() | ||
1047 | |||
1048 | for lis := range listeners { | ||
1049 | lis.Close() | ||
1050 | } | ||
1051 | for c := range st { | ||
1052 | c.Close() | ||
1053 | } | ||
1054 | |||
1055 | s.mu.Lock() | ||
1056 | s.cancel() | ||
1057 | if s.events != nil { | ||
1058 | s.events.Finish() | ||
1059 | s.events = nil | ||
1060 | } | ||
1061 | s.mu.Unlock() | ||
1062 | } | ||
1063 | |||
1064 | // GracefulStop stops the gRPC server gracefully. It stops the server from | ||
1065 | // accepting new connections and RPCs and blocks until all the pending RPCs are | ||
1066 | // finished. | ||
1067 | func (s *Server) GracefulStop() { | ||
1068 | s.mu.Lock() | ||
1069 | defer s.mu.Unlock() | ||
1070 | if s.conns == nil { | ||
1071 | return | ||
1072 | } | ||
1073 | for lis := range s.lis { | ||
1074 | lis.Close() | ||
1075 | } | ||
1076 | s.lis = nil | ||
1077 | s.cancel() | ||
1078 | if !s.drain { | ||
1079 | for c := range s.conns { | ||
1080 | c.(transport.ServerTransport).Drain() | ||
1081 | } | ||
1082 | s.drain = true | ||
1083 | } | ||
1084 | for len(s.conns) != 0 { | ||
1085 | s.cv.Wait() | ||
1086 | } | ||
1087 | s.conns = nil | ||
1088 | if s.events != nil { | ||
1089 | s.events.Finish() | ||
1090 | s.events = nil | ||
1091 | } | ||
1092 | } | ||
1093 | |||
1094 | func init() { | ||
1095 | internal.TestingCloseConns = func(arg interface{}) { | ||
1096 | arg.(*Server).testingCloseConns() | ||
1097 | } | ||
1098 | internal.TestingUseHandlerImpl = func(arg interface{}) { | ||
1099 | arg.(*Server).opts.useHandlerImpl = true | ||
1100 | } | ||
1101 | } | ||
1102 | |||
1103 | // testingCloseConns closes all existing transports but keeps s.lis | ||
1104 | // accepting new connections. | ||
1105 | func (s *Server) testingCloseConns() { | ||
1106 | s.mu.Lock() | ||
1107 | for c := range s.conns { | ||
1108 | c.Close() | ||
1109 | delete(s.conns, c) | ||
1110 | } | ||
1111 | s.mu.Unlock() | ||
1112 | } | ||
1113 | |||
1114 | // SetHeader sets the header metadata. | ||
1115 | // When called multiple times, all the provided metadata will be merged. | ||
1116 | // All the metadata will be sent out when one of the following happens: | ||
1117 | // - grpc.SendHeader() is called; | ||
1118 | // - The first response is sent out; | ||
1119 | // - An RPC status is sent out (error or success). | ||
1120 | func SetHeader(ctx context.Context, md metadata.MD) error { | ||
1121 | if md.Len() == 0 { | ||
1122 | return nil | ||
1123 | } | ||
1124 | stream, ok := transport.StreamFromContext(ctx) | ||
1125 | if !ok { | ||
1126 | return Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) | ||
1127 | } | ||
1128 | return stream.SetHeader(md) | ||
1129 | } | ||
1130 | |||
1131 | // SendHeader sends header metadata. It may be called at most once. | ||
1132 | // The provided md and headers set by SetHeader() will be sent. | ||
1133 | func SendHeader(ctx context.Context, md metadata.MD) error { | ||
1134 | stream, ok := transport.StreamFromContext(ctx) | ||
1135 | if !ok { | ||
1136 | return Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) | ||
1137 | } | ||
1138 | t := stream.ServerTransport() | ||
1139 | if t == nil { | ||
1140 | grpclog.Fatalf("grpc: SendHeader: %v has no ServerTransport to send header metadata.", stream) | ||
1141 | } | ||
1142 | if err := t.WriteHeader(stream, md); err != nil { | ||
1143 | return toRPCErr(err) | ||
1144 | } | ||
1145 | return nil | ||
1146 | } | ||
1147 | |||
1148 | // SetTrailer sets the trailer metadata that will be sent when an RPC returns. | ||
1149 | // When called more than once, all the provided metadata will be merged. | ||
1150 | func SetTrailer(ctx context.Context, md metadata.MD) error { | ||
1151 | if md.Len() == 0 { | ||
1152 | return nil | ||
1153 | } | ||
1154 | stream, ok := transport.StreamFromContext(ctx) | ||
1155 | if !ok { | ||
1156 | return Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) | ||
1157 | } | ||
1158 | return stream.SetTrailer(md) | ||
1159 | } | ||
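
For orientation, here is a minimal sketch of how a server built from this package is typically wired up. The listen address, option values, and the generated registration call are illustrative assumptions, not part of the vendored file; IDL-generated code normally supplies the ServiceDesc and registration helper.

package main

import (
	"log"
	"net"

	"google.golang.org/grpc"
)

func main() {
	lis, err := net.Listen("tcp", ":50051")
	if err != nil {
		log.Fatalf("failed to listen: %v", err)
	}
	// ServerOptions defined in server.go: raise the receive limit above the
	// 4MB default and cap concurrent streams per transport.
	s := grpc.NewServer(
		grpc.MaxRecvMsgSize(8*1024*1024),
		grpc.MaxConcurrentStreams(128),
	)
	// Registration normally goes through IDL-generated code, e.g.
	// pb.RegisterGreeterServer(s, &greeterImpl{}) for a hypothetical package
	// pb, and must happen before Serve is called.
	// To shut down later, call s.GracefulStop() (drains in-flight RPCs) or
	// s.Stop() from another goroutine, e.g. a signal handler.
	if err := s.Serve(lis); err != nil {
		log.Fatalf("failed to serve: %v", err)
	}
}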