]>
Commit | Line | Data |
---|---|---|
bae9f6d2 JC |
1 | package yamux |
2 | ||
3 | import ( | |
4 | "bufio" | |
5 | "fmt" | |
6 | "io" | |
7 | "io/ioutil" | |
8 | "log" | |
9 | "math" | |
10 | "net" | |
11 | "strings" | |
12 | "sync" | |
13 | "sync/atomic" | |
14 | "time" | |
15 | ) | |
16 | ||
17 | // Session is used to wrap a reliable ordered connection and to | |
18 | // multiplex it into multiple streams. | |
19 | type Session struct { | |
20 | // remoteGoAway indicates the remote side does | |
21 | // not want futher connections. Must be first for alignment. | |
22 | remoteGoAway int32 | |
23 | ||
24 | // localGoAway indicates that we should stop | |
25 | // accepting futher connections. Must be first for alignment. | |
26 | localGoAway int32 | |
27 | ||
28 | // nextStreamID is the next stream we should | |
29 | // send. This depends if we are a client/server. | |
30 | nextStreamID uint32 | |
31 | ||
32 | // config holds our configuration | |
33 | config *Config | |
34 | ||
35 | // logger is used for our logs | |
36 | logger *log.Logger | |
37 | ||
38 | // conn is the underlying connection | |
39 | conn io.ReadWriteCloser | |
40 | ||
41 | // bufRead is a buffered reader | |
42 | bufRead *bufio.Reader | |
43 | ||
44 | // pings is used to track inflight pings | |
45 | pings map[uint32]chan struct{} | |
46 | pingID uint32 | |
47 | pingLock sync.Mutex | |
48 | ||
49 | // streams maps a stream id to a stream, and inflight has an entry | |
50 | // for any outgoing stream that has not yet been established. Both are | |
51 | // protected by streamLock. | |
52 | streams map[uint32]*Stream | |
53 | inflight map[uint32]struct{} | |
54 | streamLock sync.Mutex | |
55 | ||
56 | // synCh acts like a semaphore. It is sized to the AcceptBacklog which | |
57 | // is assumed to be symmetric between the client and server. This allows | |
58 | // the client to avoid exceeding the backlog and instead blocks the open. | |
59 | synCh chan struct{} | |
60 | ||
61 | // acceptCh is used to pass ready streams to the client | |
62 | acceptCh chan *Stream | |
63 | ||
64 | // sendCh is used to mark a stream as ready to send, | |
65 | // or to send a header out directly. | |
66 | sendCh chan sendReady | |
67 | ||
68 | // recvDoneCh is closed when recv() exits to avoid a race | |
69 | // between stream registration and stream shutdown | |
70 | recvDoneCh chan struct{} | |
71 | ||
72 | // shutdown is used to safely close a session | |
73 | shutdown bool | |
74 | shutdownErr error | |
75 | shutdownCh chan struct{} | |
76 | shutdownLock sync.Mutex | |
77 | } | |
78 | ||
// sendReady is one unit of work for the send goroutine: an optional header
// written directly to the connection, an optional body streamed after it,
// and an optional channel on which the outcome is reported.
type sendReady struct {
	Hdr  []byte     // raw header bytes written first when non-nil
	Body io.Reader  // copied to the connection after Hdr when non-nil
	Err  chan error // receives the write result (via asyncSendErr)
}
86 | ||
87 | // newSession is used to construct a new session | |
88 | func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { | |
89 | s := &Session{ | |
90 | config: config, | |
91 | logger: log.New(config.LogOutput, "", log.LstdFlags), | |
92 | conn: conn, | |
93 | bufRead: bufio.NewReader(conn), | |
94 | pings: make(map[uint32]chan struct{}), | |
95 | streams: make(map[uint32]*Stream), | |
96 | inflight: make(map[uint32]struct{}), | |
97 | synCh: make(chan struct{}, config.AcceptBacklog), | |
98 | acceptCh: make(chan *Stream, config.AcceptBacklog), | |
99 | sendCh: make(chan sendReady, 64), | |
100 | recvDoneCh: make(chan struct{}), | |
101 | shutdownCh: make(chan struct{}), | |
102 | } | |
103 | if client { | |
104 | s.nextStreamID = 1 | |
105 | } else { | |
106 | s.nextStreamID = 2 | |
107 | } | |
108 | go s.recv() | |
109 | go s.send() | |
110 | if config.EnableKeepAlive { | |
111 | go s.keepalive() | |
112 | } | |
113 | return s | |
114 | } | |
115 | ||
116 | // IsClosed does a safe check to see if we have shutdown | |
117 | func (s *Session) IsClosed() bool { | |
118 | select { | |
119 | case <-s.shutdownCh: | |
120 | return true | |
121 | default: | |
122 | return false | |
123 | } | |
124 | } | |
125 | ||
126 | // NumStreams returns the number of currently open streams | |
127 | func (s *Session) NumStreams() int { | |
128 | s.streamLock.Lock() | |
129 | num := len(s.streams) | |
130 | s.streamLock.Unlock() | |
131 | return num | |
132 | } | |
133 | ||
134 | // Open is used to create a new stream as a net.Conn | |
135 | func (s *Session) Open() (net.Conn, error) { | |
136 | conn, err := s.OpenStream() | |
137 | if err != nil { | |
138 | return nil, err | |
139 | } | |
140 | return conn, nil | |
141 | } | |
142 | ||
143 | // OpenStream is used to create a new stream | |
144 | func (s *Session) OpenStream() (*Stream, error) { | |
145 | if s.IsClosed() { | |
146 | return nil, ErrSessionShutdown | |
147 | } | |
148 | if atomic.LoadInt32(&s.remoteGoAway) == 1 { | |
149 | return nil, ErrRemoteGoAway | |
150 | } | |
151 | ||
152 | // Block if we have too many inflight SYNs | |
153 | select { | |
154 | case s.synCh <- struct{}{}: | |
155 | case <-s.shutdownCh: | |
156 | return nil, ErrSessionShutdown | |
157 | } | |
158 | ||
159 | GET_ID: | |
160 | // Get an ID, and check for stream exhaustion | |
161 | id := atomic.LoadUint32(&s.nextStreamID) | |
162 | if id >= math.MaxUint32-1 { | |
163 | return nil, ErrStreamsExhausted | |
164 | } | |
165 | if !atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) { | |
166 | goto GET_ID | |
167 | } | |
168 | ||
169 | // Register the stream | |
170 | stream := newStream(s, id, streamInit) | |
171 | s.streamLock.Lock() | |
172 | s.streams[id] = stream | |
173 | s.inflight[id] = struct{}{} | |
174 | s.streamLock.Unlock() | |
175 | ||
176 | // Send the window update to create | |
177 | if err := stream.sendWindowUpdate(); err != nil { | |
178 | select { | |
179 | case <-s.synCh: | |
180 | default: | |
181 | s.logger.Printf("[ERR] yamux: aborted stream open without inflight syn semaphore") | |
182 | } | |
183 | return nil, err | |
184 | } | |
185 | return stream, nil | |
186 | } | |
187 | ||
188 | // Accept is used to block until the next available stream | |
189 | // is ready to be accepted. | |
190 | func (s *Session) Accept() (net.Conn, error) { | |
191 | conn, err := s.AcceptStream() | |
192 | if err != nil { | |
193 | return nil, err | |
194 | } | |
195 | return conn, err | |
196 | } | |
197 | ||
198 | // AcceptStream is used to block until the next available stream | |
199 | // is ready to be accepted. | |
200 | func (s *Session) AcceptStream() (*Stream, error) { | |
201 | select { | |
202 | case stream := <-s.acceptCh: | |
203 | if err := stream.sendWindowUpdate(); err != nil { | |
204 | return nil, err | |
205 | } | |
206 | return stream, nil | |
207 | case <-s.shutdownCh: | |
208 | return nil, s.shutdownErr | |
209 | } | |
210 | } | |
211 | ||
212 | // Close is used to close the session and all streams. | |
213 | // Attempts to send a GoAway before closing the connection. | |
214 | func (s *Session) Close() error { | |
215 | s.shutdownLock.Lock() | |
216 | defer s.shutdownLock.Unlock() | |
217 | ||
218 | if s.shutdown { | |
219 | return nil | |
220 | } | |
221 | s.shutdown = true | |
222 | if s.shutdownErr == nil { | |
223 | s.shutdownErr = ErrSessionShutdown | |
224 | } | |
225 | close(s.shutdownCh) | |
226 | s.conn.Close() | |
227 | <-s.recvDoneCh | |
228 | ||
229 | s.streamLock.Lock() | |
230 | defer s.streamLock.Unlock() | |
231 | for _, stream := range s.streams { | |
232 | stream.forceClose() | |
233 | } | |
234 | return nil | |
235 | } | |
236 | ||
237 | // exitErr is used to handle an error that is causing the | |
238 | // session to terminate. | |
239 | func (s *Session) exitErr(err error) { | |
240 | s.shutdownLock.Lock() | |
241 | if s.shutdownErr == nil { | |
242 | s.shutdownErr = err | |
243 | } | |
244 | s.shutdownLock.Unlock() | |
245 | s.Close() | |
246 | } | |
247 | ||
248 | // GoAway can be used to prevent accepting further | |
249 | // connections. It does not close the underlying conn. | |
250 | func (s *Session) GoAway() error { | |
251 | return s.waitForSend(s.goAway(goAwayNormal), nil) | |
252 | } | |
253 | ||
254 | // goAway is used to send a goAway message | |
255 | func (s *Session) goAway(reason uint32) header { | |
256 | atomic.SwapInt32(&s.localGoAway, 1) | |
257 | hdr := header(make([]byte, headerSize)) | |
258 | hdr.encode(typeGoAway, 0, 0, reason) | |
259 | return hdr | |
260 | } | |
261 | ||
262 | // Ping is used to measure the RTT response time | |
263 | func (s *Session) Ping() (time.Duration, error) { | |
264 | // Get a channel for the ping | |
265 | ch := make(chan struct{}) | |
266 | ||
267 | // Get a new ping id, mark as pending | |
268 | s.pingLock.Lock() | |
269 | id := s.pingID | |
270 | s.pingID++ | |
271 | s.pings[id] = ch | |
272 | s.pingLock.Unlock() | |
273 | ||
274 | // Send the ping request | |
275 | hdr := header(make([]byte, headerSize)) | |
276 | hdr.encode(typePing, flagSYN, 0, id) | |
277 | if err := s.waitForSend(hdr, nil); err != nil { | |
278 | return 0, err | |
279 | } | |
280 | ||
281 | // Wait for a response | |
282 | start := time.Now() | |
283 | select { | |
284 | case <-ch: | |
285 | case <-time.After(s.config.ConnectionWriteTimeout): | |
286 | s.pingLock.Lock() | |
287 | delete(s.pings, id) // Ignore it if a response comes later. | |
288 | s.pingLock.Unlock() | |
289 | return 0, ErrTimeout | |
290 | case <-s.shutdownCh: | |
291 | return 0, ErrSessionShutdown | |
292 | } | |
293 | ||
294 | // Compute the RTT | |
295 | return time.Now().Sub(start), nil | |
296 | } | |
297 | ||
298 | // keepalive is a long running goroutine that periodically does | |
299 | // a ping to keep the connection alive. | |
300 | func (s *Session) keepalive() { | |
301 | for { | |
302 | select { | |
303 | case <-time.After(s.config.KeepAliveInterval): | |
304 | _, err := s.Ping() | |
305 | if err != nil { | |
306 | s.logger.Printf("[ERR] yamux: keepalive failed: %v", err) | |
307 | s.exitErr(ErrKeepAliveTimeout) | |
308 | return | |
309 | } | |
310 | case <-s.shutdownCh: | |
311 | return | |
312 | } | |
313 | } | |
314 | } | |
315 | ||
316 | // waitForSendErr waits to send a header, checking for a potential shutdown | |
317 | func (s *Session) waitForSend(hdr header, body io.Reader) error { | |
318 | errCh := make(chan error, 1) | |
319 | return s.waitForSendErr(hdr, body, errCh) | |
320 | } | |
321 | ||
322 | // waitForSendErr waits to send a header with optional data, checking for a | |
323 | // potential shutdown. Since there's the expectation that sends can happen | |
324 | // in a timely manner, we enforce the connection write timeout here. | |
325 | func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error { | |
326 | timer := time.NewTimer(s.config.ConnectionWriteTimeout) | |
327 | defer timer.Stop() | |
328 | ||
329 | ready := sendReady{Hdr: hdr, Body: body, Err: errCh} | |
330 | select { | |
331 | case s.sendCh <- ready: | |
332 | case <-s.shutdownCh: | |
333 | return ErrSessionShutdown | |
334 | case <-timer.C: | |
335 | return ErrConnectionWriteTimeout | |
336 | } | |
337 | ||
338 | select { | |
339 | case err := <-errCh: | |
340 | return err | |
341 | case <-s.shutdownCh: | |
342 | return ErrSessionShutdown | |
343 | case <-timer.C: | |
344 | return ErrConnectionWriteTimeout | |
345 | } | |
346 | } | |
347 | ||
348 | // sendNoWait does a send without waiting. Since there's the expectation that | |
349 | // the send happens right here, we enforce the connection write timeout if we | |
350 | // can't queue the header to be sent. | |
351 | func (s *Session) sendNoWait(hdr header) error { | |
352 | timer := time.NewTimer(s.config.ConnectionWriteTimeout) | |
353 | defer timer.Stop() | |
354 | ||
355 | select { | |
356 | case s.sendCh <- sendReady{Hdr: hdr}: | |
357 | return nil | |
358 | case <-s.shutdownCh: | |
359 | return ErrSessionShutdown | |
360 | case <-timer.C: | |
361 | return ErrConnectionWriteTimeout | |
362 | } | |
363 | } | |
364 | ||
365 | // send is a long running goroutine that sends data | |
366 | func (s *Session) send() { | |
367 | for { | |
368 | select { | |
369 | case ready := <-s.sendCh: | |
370 | // Send a header if ready | |
371 | if ready.Hdr != nil { | |
372 | sent := 0 | |
373 | for sent < len(ready.Hdr) { | |
374 | n, err := s.conn.Write(ready.Hdr[sent:]) | |
375 | if err != nil { | |
376 | s.logger.Printf("[ERR] yamux: Failed to write header: %v", err) | |
377 | asyncSendErr(ready.Err, err) | |
378 | s.exitErr(err) | |
379 | return | |
380 | } | |
381 | sent += n | |
382 | } | |
383 | } | |
384 | ||
385 | // Send data from a body if given | |
386 | if ready.Body != nil { | |
387 | _, err := io.Copy(s.conn, ready.Body) | |
388 | if err != nil { | |
389 | s.logger.Printf("[ERR] yamux: Failed to write body: %v", err) | |
390 | asyncSendErr(ready.Err, err) | |
391 | s.exitErr(err) | |
392 | return | |
393 | } | |
394 | } | |
395 | ||
396 | // No error, successful send | |
397 | asyncSendErr(ready.Err, nil) | |
398 | case <-s.shutdownCh: | |
399 | return | |
400 | } | |
401 | } | |
402 | } | |
403 | ||
404 | // recv is a long running goroutine that accepts new data | |
405 | func (s *Session) recv() { | |
406 | if err := s.recvLoop(); err != nil { | |
407 | s.exitErr(err) | |
408 | } | |
409 | } | |
410 | ||
411 | // recvLoop continues to receive data until a fatal error is encountered | |
412 | func (s *Session) recvLoop() error { | |
413 | defer close(s.recvDoneCh) | |
414 | hdr := header(make([]byte, headerSize)) | |
415 | var handler func(header) error | |
416 | for { | |
417 | // Read the header | |
418 | if _, err := io.ReadFull(s.bufRead, hdr); err != nil { | |
419 | if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") { | |
420 | s.logger.Printf("[ERR] yamux: Failed to read header: %v", err) | |
421 | } | |
422 | return err | |
423 | } | |
424 | ||
425 | // Verify the version | |
426 | if hdr.Version() != protoVersion { | |
427 | s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version()) | |
428 | return ErrInvalidVersion | |
429 | } | |
430 | ||
431 | // Switch on the type | |
432 | switch hdr.MsgType() { | |
433 | case typeData: | |
434 | handler = s.handleStreamMessage | |
435 | case typeWindowUpdate: | |
436 | handler = s.handleStreamMessage | |
437 | case typeGoAway: | |
438 | handler = s.handleGoAway | |
439 | case typePing: | |
440 | handler = s.handlePing | |
441 | default: | |
442 | return ErrInvalidMsgType | |
443 | } | |
444 | ||
445 | // Invoke the handler | |
446 | if err := handler(hdr); err != nil { | |
447 | return err | |
448 | } | |
449 | } | |
450 | } | |
451 | ||
452 | // handleStreamMessage handles either a data or window update frame | |
453 | func (s *Session) handleStreamMessage(hdr header) error { | |
454 | // Check for a new stream creation | |
455 | id := hdr.StreamID() | |
456 | flags := hdr.Flags() | |
457 | if flags&flagSYN == flagSYN { | |
458 | if err := s.incomingStream(id); err != nil { | |
459 | return err | |
460 | } | |
461 | } | |
462 | ||
463 | // Get the stream | |
464 | s.streamLock.Lock() | |
465 | stream := s.streams[id] | |
466 | s.streamLock.Unlock() | |
467 | ||
468 | // If we do not have a stream, likely we sent a RST | |
469 | if stream == nil { | |
470 | // Drain any data on the wire | |
471 | if hdr.MsgType() == typeData && hdr.Length() > 0 { | |
472 | s.logger.Printf("[WARN] yamux: Discarding data for stream: %d", id) | |
473 | if _, err := io.CopyN(ioutil.Discard, s.bufRead, int64(hdr.Length())); err != nil { | |
474 | s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err) | |
475 | return nil | |
476 | } | |
477 | } else { | |
478 | s.logger.Printf("[WARN] yamux: frame for missing stream: %v", hdr) | |
479 | } | |
480 | return nil | |
481 | } | |
482 | ||
483 | // Check if this is a window update | |
484 | if hdr.MsgType() == typeWindowUpdate { | |
485 | if err := stream.incrSendWindow(hdr, flags); err != nil { | |
486 | if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { | |
487 | s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) | |
488 | } | |
489 | return err | |
490 | } | |
491 | return nil | |
492 | } | |
493 | ||
494 | // Read the new data | |
495 | if err := stream.readData(hdr, flags, s.bufRead); err != nil { | |
496 | if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { | |
497 | s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) | |
498 | } | |
499 | return err | |
500 | } | |
501 | return nil | |
502 | } | |
503 | ||
504 | // handlePing is invokde for a typePing frame | |
505 | func (s *Session) handlePing(hdr header) error { | |
506 | flags := hdr.Flags() | |
507 | pingID := hdr.Length() | |
508 | ||
509 | // Check if this is a query, respond back in a separate context so we | |
510 | // don't interfere with the receiving thread blocking for the write. | |
511 | if flags&flagSYN == flagSYN { | |
512 | go func() { | |
513 | hdr := header(make([]byte, headerSize)) | |
514 | hdr.encode(typePing, flagACK, 0, pingID) | |
515 | if err := s.sendNoWait(hdr); err != nil { | |
516 | s.logger.Printf("[WARN] yamux: failed to send ping reply: %v", err) | |
517 | } | |
518 | }() | |
519 | return nil | |
520 | } | |
521 | ||
522 | // Handle a response | |
523 | s.pingLock.Lock() | |
524 | ch := s.pings[pingID] | |
525 | if ch != nil { | |
526 | delete(s.pings, pingID) | |
527 | close(ch) | |
528 | } | |
529 | s.pingLock.Unlock() | |
530 | return nil | |
531 | } | |
532 | ||
533 | // handleGoAway is invokde for a typeGoAway frame | |
534 | func (s *Session) handleGoAway(hdr header) error { | |
535 | code := hdr.Length() | |
536 | switch code { | |
537 | case goAwayNormal: | |
538 | atomic.SwapInt32(&s.remoteGoAway, 1) | |
539 | case goAwayProtoErr: | |
540 | s.logger.Printf("[ERR] yamux: received protocol error go away") | |
541 | return fmt.Errorf("yamux protocol error") | |
542 | case goAwayInternalErr: | |
543 | s.logger.Printf("[ERR] yamux: received internal error go away") | |
544 | return fmt.Errorf("remote yamux internal error") | |
545 | default: | |
546 | s.logger.Printf("[ERR] yamux: received unexpected go away") | |
547 | return fmt.Errorf("unexpected go away received") | |
548 | } | |
549 | return nil | |
550 | } | |
551 | ||
552 | // incomingStream is used to create a new incoming stream | |
553 | func (s *Session) incomingStream(id uint32) error { | |
554 | // Reject immediately if we are doing a go away | |
555 | if atomic.LoadInt32(&s.localGoAway) == 1 { | |
556 | hdr := header(make([]byte, headerSize)) | |
557 | hdr.encode(typeWindowUpdate, flagRST, id, 0) | |
558 | return s.sendNoWait(hdr) | |
559 | } | |
560 | ||
561 | // Allocate a new stream | |
562 | stream := newStream(s, id, streamSYNReceived) | |
563 | ||
564 | s.streamLock.Lock() | |
565 | defer s.streamLock.Unlock() | |
566 | ||
567 | // Check if stream already exists | |
568 | if _, ok := s.streams[id]; ok { | |
569 | s.logger.Printf("[ERR] yamux: duplicate stream declared") | |
570 | if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { | |
571 | s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) | |
572 | } | |
573 | return ErrDuplicateStream | |
574 | } | |
575 | ||
576 | // Register the stream | |
577 | s.streams[id] = stream | |
578 | ||
579 | // Check if we've exceeded the backlog | |
580 | select { | |
581 | case s.acceptCh <- stream: | |
582 | return nil | |
583 | default: | |
584 | // Backlog exceeded! RST the stream | |
585 | s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset") | |
586 | delete(s.streams, id) | |
587 | stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0) | |
588 | return s.sendNoWait(stream.sendHdr) | |
589 | } | |
590 | } | |
591 | ||
592 | // closeStream is used to close a stream once both sides have | |
593 | // issued a close. If there was an in-flight SYN and the stream | |
594 | // was not yet established, then this will give the credit back. | |
595 | func (s *Session) closeStream(id uint32) { | |
596 | s.streamLock.Lock() | |
597 | if _, ok := s.inflight[id]; ok { | |
598 | select { | |
599 | case <-s.synCh: | |
600 | default: | |
601 | s.logger.Printf("[ERR] yamux: SYN tracking out of sync") | |
602 | } | |
603 | } | |
604 | delete(s.streams, id) | |
605 | s.streamLock.Unlock() | |
606 | } | |
607 | ||
608 | // establishStream is used to mark a stream that was in the | |
609 | // SYN Sent state as established. | |
610 | func (s *Session) establishStream(id uint32) { | |
611 | s.streamLock.Lock() | |
612 | if _, ok := s.inflight[id]; ok { | |
613 | delete(s.inflight, id) | |
614 | } else { | |
615 | s.logger.Printf("[ERR] yamux: established stream without inflight SYN (no tracking entry)") | |
616 | } | |
617 | select { | |
618 | case <-s.synCh: | |
619 | default: | |
620 | s.logger.Printf("[ERR] yamux: established stream without inflight SYN (didn't have semaphore)") | |
621 | } | |
622 | s.streamLock.Unlock() | |
623 | } |