17 // Session is used to wrap a reliable ordered connection and to
18 // multiplex it into multiple streams.
20 // remoteGoAway indicates the remote side does
21 // not want further connections. Must be first for alignment.
24 // localGoAway indicates that we should stop
25 // accepting further connections. Must be first for alignment.
28 // nextStreamID is the next stream we should
29 // send. This depends on whether we are a client or a server.
32 // config holds our configuration
35 // logger is used for our logs
38 // conn is the underlying connection
39 conn io.ReadWriteCloser
41 // bufRead is a buffered reader
44 // pings is used to track inflight pings
45 pings map[uint32]chan struct{}
49 // streams maps a stream id to a stream, and inflight has an entry
50 // for any outgoing stream that has not yet been established. Both are
51 // protected by streamLock.
52 streams map[uint32]*Stream
53 inflight map[uint32]struct{}
56 // synCh acts like a semaphore. It is sized to the AcceptBacklog which
57 // is assumed to be symmetric between the client and server. This allows
58 // the client to avoid exceeding the backlog and instead blocks the open.
61 // acceptCh is used to pass ready streams to the client
64 // sendCh is used to mark a stream as ready to send,
65 // or to send a header out directly.
68 // recvDoneCh is closed when recv() exits to avoid a race
69 // between stream registration and stream shutdown
70 recvDoneCh chan struct{}
72 // shutdown is used to safely close a session
75 shutdownCh chan struct{}
76 shutdownLock sync.Mutex
79 // sendReady is used to either mark a stream as ready
80 // or to directly send a header
81 type sendReady struct {
87 // newSession is used to construct a new session
88 func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
91 logger: log.New(config.LogOutput, "", log.LstdFlags),
93 bufRead: bufio.NewReader(conn),
94 pings: make(map[uint32]chan struct{}),
95 streams: make(map[uint32]*Stream),
96 inflight: make(map[uint32]struct{}),
97 synCh: make(chan struct{}, config.AcceptBacklog),
98 acceptCh: make(chan *Stream, config.AcceptBacklog),
99 sendCh: make(chan sendReady, 64),
100 recvDoneCh: make(chan struct{}),
101 shutdownCh: make(chan struct{}),
110 if config.EnableKeepAlive {
116 // IsClosed does a safe check to see if we have shutdown
117 func (s *Session) IsClosed() bool {
126 // NumStreams returns the number of currently open streams
127 func (s *Session) NumStreams() int {
129 num := len(s.streams)
130 s.streamLock.Unlock()
134 // Open is used to create a new stream as a net.Conn
135 func (s *Session) Open() (net.Conn, error) {
136 conn, err := s.OpenStream()
143 // OpenStream is used to create a new stream
144 func (s *Session) OpenStream() (*Stream, error) {
146 return nil, ErrSessionShutdown
148 if atomic.LoadInt32(&s.remoteGoAway) == 1 {
149 return nil, ErrRemoteGoAway
152 // Block if we have too many inflight SYNs
154 case s.synCh <- struct{}{}:
156 return nil, ErrSessionShutdown
160 // Get an ID, and check for stream exhaustion
161 id := atomic.LoadUint32(&s.nextStreamID)
162 if id >= math.MaxUint32-1 {
163 return nil, ErrStreamsExhausted
165 if !atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) {
169 // Register the stream
170 stream := newStream(s, id, streamInit)
172 s.streams[id] = stream
173 s.inflight[id] = struct{}{}
174 s.streamLock.Unlock()
176 // Send the window update to create
177 if err := stream.sendWindowUpdate(); err != nil {
181 s.logger.Printf("[ERR] yamux: aborted stream open without inflight syn semaphore")
188 // Accept is used to block until the next available stream
189 // is ready to be accepted.
190 func (s *Session) Accept() (net.Conn, error) {
191 conn, err := s.AcceptStream()
198 // AcceptStream is used to block until the next available stream
199 // is ready to be accepted.
200 func (s *Session) AcceptStream() (*Stream, error) {
202 case stream := <-s.acceptCh:
203 if err := stream.sendWindowUpdate(); err != nil {
208 return nil, s.shutdownErr
212 // Close is used to close the session and all streams.
213 // Attempts to send a GoAway before closing the connection.
214 func (s *Session) Close() error {
215 s.shutdownLock.Lock()
216 defer s.shutdownLock.Unlock()
222 if s.shutdownErr == nil {
223 s.shutdownErr = ErrSessionShutdown
230 defer s.streamLock.Unlock()
231 for _, stream := range s.streams {
237 // exitErr is used to handle an error that is causing the
238 // session to terminate.
239 func (s *Session) exitErr(err error) {
240 s.shutdownLock.Lock()
241 if s.shutdownErr == nil {
244 s.shutdownLock.Unlock()
248 // GoAway can be used to prevent accepting further
249 // connections. It does not close the underlying conn.
250 func (s *Session) GoAway() error {
251 return s.waitForSend(s.goAway(goAwayNormal), nil)
254 // goAway is used to send a goAway message
255 func (s *Session) goAway(reason uint32) header {
256 atomic.SwapInt32(&s.localGoAway, 1)
257 hdr := header(make([]byte, headerSize))
258 hdr.encode(typeGoAway, 0, 0, reason)
262 // Ping is used to measure the RTT response time
263 func (s *Session) Ping() (time.Duration, error) {
264 // Get a channel for the ping
265 ch := make(chan struct{})
267 // Get a new ping id, mark as pending
274 // Send the ping request
275 hdr := header(make([]byte, headerSize))
276 hdr.encode(typePing, flagSYN, 0, id)
277 if err := s.waitForSend(hdr, nil); err != nil {
281 // Wait for a response
285 case <-time.After(s.config.ConnectionWriteTimeout):
287 delete(s.pings, id) // Ignore it if a response comes later.
291 return 0, ErrSessionShutdown
295 return time.Now().Sub(start), nil
298 // keepalive is a long running goroutine that periodically does
299 // a ping to keep the connection alive.
300 func (s *Session) keepalive() {
303 case <-time.After(s.config.KeepAliveInterval):
306 s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
307 s.exitErr(ErrKeepAliveTimeout)
316 // waitForSend waits to send a header, checking for a potential shutdown
317 func (s *Session) waitForSend(hdr header, body io.Reader) error {
318 errCh := make(chan error, 1)
319 return s.waitForSendErr(hdr, body, errCh)
322 // waitForSendErr waits to send a header with optional data, checking for a
323 // potential shutdown. Since there's the expectation that sends can happen
324 // in a timely manner, we enforce the connection write timeout here.
325 func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error {
326 timer := time.NewTimer(s.config.ConnectionWriteTimeout)
329 ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
331 case s.sendCh <- ready:
333 return ErrSessionShutdown
335 return ErrConnectionWriteTimeout
342 return ErrSessionShutdown
344 return ErrConnectionWriteTimeout
348 // sendNoWait does a send without waiting. Since there's the expectation that
349 // the send happens right here, we enforce the connection write timeout if we
350 // can't queue the header to be sent.
351 func (s *Session) sendNoWait(hdr header) error {
352 timer := time.NewTimer(s.config.ConnectionWriteTimeout)
356 case s.sendCh <- sendReady{Hdr: hdr}:
359 return ErrSessionShutdown
361 return ErrConnectionWriteTimeout
365 // send is a long running goroutine that sends data
366 func (s *Session) send() {
369 case ready := <-s.sendCh:
370 // Send a header if ready
371 if ready.Hdr != nil {
373 for sent < len(ready.Hdr) {
374 n, err := s.conn.Write(ready.Hdr[sent:])
376 s.logger.Printf("[ERR] yamux: Failed to write header: %v", err)
377 asyncSendErr(ready.Err, err)
385 // Send data from a body if given
386 if ready.Body != nil {
387 _, err := io.Copy(s.conn, ready.Body)
389 s.logger.Printf("[ERR] yamux: Failed to write body: %v", err)
390 asyncSendErr(ready.Err, err)
396 // No error, successful send
397 asyncSendErr(ready.Err, nil)
404 // recv is a long running goroutine that accepts new data
405 func (s *Session) recv() {
406 if err := s.recvLoop(); err != nil {
411 // recvLoop continues to receive data until a fatal error is encountered
412 func (s *Session) recvLoop() error {
413 defer close(s.recvDoneCh)
414 hdr := header(make([]byte, headerSize))
415 var handler func(header) error
418 if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
419 if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
420 s.logger.Printf("[ERR] yamux: Failed to read header: %v", err)
425 // Verify the version
426 if hdr.Version() != protoVersion {
427 s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version())
428 return ErrInvalidVersion
431 // Switch on the type
432 switch hdr.MsgType() {
434 handler = s.handleStreamMessage
435 case typeWindowUpdate:
436 handler = s.handleStreamMessage
438 handler = s.handleGoAway
440 handler = s.handlePing
442 return ErrInvalidMsgType
445 // Invoke the handler
446 if err := handler(hdr); err != nil {
452 // handleStreamMessage handles either a data or window update frame
453 func (s *Session) handleStreamMessage(hdr header) error {
454 // Check for a new stream creation
457 if flags&flagSYN == flagSYN {
458 if err := s.incomingStream(id); err != nil {
465 stream := s.streams[id]
466 s.streamLock.Unlock()
468 // If we do not have a stream, likely we sent a RST
470 // Drain any data on the wire
471 if hdr.MsgType() == typeData && hdr.Length() > 0 {
472 s.logger.Printf("[WARN] yamux: Discarding data for stream: %d", id)
473 if _, err := io.CopyN(ioutil.Discard, s.bufRead, int64(hdr.Length())); err != nil {
474 s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err)
478 s.logger.Printf("[WARN] yamux: frame for missing stream: %v", hdr)
483 // Check if this is a window update
484 if hdr.MsgType() == typeWindowUpdate {
485 if err := stream.incrSendWindow(hdr, flags); err != nil {
486 if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
487 s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
495 if err := stream.readData(hdr, flags, s.bufRead); err != nil {
496 if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
497 s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
504 // handlePing is invoked for a typePing frame
505 func (s *Session) handlePing(hdr header) error {
507 pingID := hdr.Length()
509 // Check if this is a query, respond back in a separate context so we
510 // don't interfere with the receiving thread blocking for the write.
511 if flags&flagSYN == flagSYN {
513 hdr := header(make([]byte, headerSize))
514 hdr.encode(typePing, flagACK, 0, pingID)
515 if err := s.sendNoWait(hdr); err != nil {
516 s.logger.Printf("[WARN] yamux: failed to send ping reply: %v", err)
524 ch := s.pings[pingID]
526 delete(s.pings, pingID)
533 // handleGoAway is invoked for a typeGoAway frame
534 func (s *Session) handleGoAway(hdr header) error {
538 atomic.SwapInt32(&s.remoteGoAway, 1)
540 s.logger.Printf("[ERR] yamux: received protocol error go away")
541 return fmt.Errorf("yamux protocol error")
542 case goAwayInternalErr:
543 s.logger.Printf("[ERR] yamux: received internal error go away")
544 return fmt.Errorf("remote yamux internal error")
546 s.logger.Printf("[ERR] yamux: received unexpected go away")
547 return fmt.Errorf("unexpected go away received")
552 // incomingStream is used to create a new incoming stream
553 func (s *Session) incomingStream(id uint32) error {
554 // Reject immediately if we are doing a go away
555 if atomic.LoadInt32(&s.localGoAway) == 1 {
556 hdr := header(make([]byte, headerSize))
557 hdr.encode(typeWindowUpdate, flagRST, id, 0)
558 return s.sendNoWait(hdr)
561 // Allocate a new stream
562 stream := newStream(s, id, streamSYNReceived)
565 defer s.streamLock.Unlock()
567 // Check if stream already exists
568 if _, ok := s.streams[id]; ok {
569 s.logger.Printf("[ERR] yamux: duplicate stream declared")
570 if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
571 s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
573 return ErrDuplicateStream
576 // Register the stream
577 s.streams[id] = stream
579 // Check if we've exceeded the backlog
581 case s.acceptCh <- stream:
584 // Backlog exceeded! RST the stream
585 s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset")
586 delete(s.streams, id)
587 stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0)
588 return s.sendNoWait(stream.sendHdr)
592 // closeStream is used to close a stream once both sides have
593 // issued a close. If there was an in-flight SYN and the stream
594 // was not yet established, then this will give the credit back.
595 func (s *Session) closeStream(id uint32) {
597 if _, ok := s.inflight[id]; ok {
601 s.logger.Printf("[ERR] yamux: SYN tracking out of sync")
604 delete(s.streams, id)
605 s.streamLock.Unlock()
608 // establishStream is used to mark a stream that was in the
609 // SYN Sent state as established.
610 func (s *Session) establishStream(id uint32) {
612 if _, ok := s.inflight[id]; ok {
613 delete(s.inflight, id)
615 s.logger.Printf("[ERR] yamux: established stream without inflight SYN (no tracking entry)")
620 s.logger.Printf("[ERR] yamux: established stream without inflight SYN (didn't have semaphore)")
622 s.streamLock.Unlock()