]>
Commit | Line | Data |
---|---|---|
bae9f6d2 JC |
1 | package yamux |
2 | ||
3 | import ( | |
4 | "bytes" | |
5 | "io" | |
6 | "sync" | |
7 | "sync/atomic" | |
8 | "time" | |
9 | ) | |
10 | ||
// streamState is the position of a Stream in its open/close lifecycle.
// The numeric values are assigned by iota in declaration order and are
// printed with %d in log messages, so the order below is significant.
type streamState int

const (
	// streamInit: no frame has been sent yet. sendFlags attaches SYN and
	// moves the stream to streamSYNSent.
	streamInit streamState = iota
	// streamSYNSent: SYN has been sent; an incoming ACK (processFlags)
	// moves the stream to streamEstablished.
	streamSYNSent
	// streamSYNReceived: sendFlags attaches ACK on the next outgoing frame
	// and moves the stream to streamEstablished. Presumably set when the
	// peer opened the stream — confirm against the session code.
	streamSYNReceived
	// streamEstablished: both sides may read and write.
	streamEstablished
	// streamLocalClose: Close was called locally and a FIN sent. write
	// fails with ErrStreamClosed; Read drains remaining buffered data and
	// then returns io.EOF. A FIN from the peer moves to streamClosed.
	streamLocalClose
	// streamRemoteClose: a FIN was received from the peer. Read drains
	// remaining buffered data and then returns io.EOF; a local Close moves
	// the stream to streamClosed.
	streamRemoteClose
	// streamClosed: both directions are closed (or forceClose was called).
	streamClosed
	// streamReset: an RST was received; Read and write return
	// ErrConnectionReset.
	streamReset
)
23 | ||
// Stream is used to represent a logical stream
// within a session.
type Stream struct {
	// recvWindow is how many more bytes the peer may send on this stream.
	// readData decrements it and sendWindowUpdate grows it.
	// Accessed with sync/atomic.
	recvWindow uint32
	// sendWindow is how many bytes write may send before it must wait for
	// the peer. write decrements it and incrSendWindow grows it.
	// Accessed with sync/atomic.
	sendWindow uint32

	id      uint32
	session *Session

	// state is the lifecycle state; guarded by stateLock.
	state     streamState
	stateLock sync.Mutex

	// recvBuf holds received but not yet read data. It is allocated lazily
	// by readData and may be dropped by Shrink when empty. Guarded by
	// recvLock.
	recvBuf  *bytes.Buffer
	recvLock sync.Mutex

	// controlHdr is the scratch header used for window-update/FIN frames,
	// and controlErr is the error channel passed with it to
	// session.waitForSendErr. Both are guarded by controlHdrLock.
	controlHdr     header
	controlErr     chan error
	controlHdrLock sync.Mutex

	// sendHdr and sendErr play the same role for data frames. sendLock is
	// held by Write for the whole call, serializing writers.
	sendHdr  header
	sendErr  chan error
	sendLock sync.Mutex

	// recvNotifyCh and sendNotifyCh are 1-buffered wakeup channels for
	// blocked readers and writers respectively (see newStream).
	recvNotifyCh chan struct{}
	sendNotifyCh chan struct{}

	// readDeadline and writeDeadline bound how long Read and write wait for
	// data or window space; the zero time means no deadline.
	// NOTE(review): these are assigned without synchronization while Read
	// and write may read them concurrently — a data race.
	readDeadline  time.Time
	writeDeadline time.Time
}
53 | ||
54 | // newStream is used to construct a new stream within | |
55 | // a given session for an ID | |
56 | func newStream(session *Session, id uint32, state streamState) *Stream { | |
57 | s := &Stream{ | |
58 | id: id, | |
59 | session: session, | |
60 | state: state, | |
61 | controlHdr: header(make([]byte, headerSize)), | |
62 | controlErr: make(chan error, 1), | |
63 | sendHdr: header(make([]byte, headerSize)), | |
64 | sendErr: make(chan error, 1), | |
65 | recvWindow: initialStreamWindow, | |
66 | sendWindow: initialStreamWindow, | |
67 | recvNotifyCh: make(chan struct{}, 1), | |
68 | sendNotifyCh: make(chan struct{}, 1), | |
69 | } | |
70 | return s | |
71 | } | |
72 | ||
73 | // Session returns the associated stream session | |
74 | func (s *Stream) Session() *Session { | |
75 | return s.session | |
76 | } | |
77 | ||
78 | // StreamID returns the ID of this stream | |
79 | func (s *Stream) StreamID() uint32 { | |
80 | return s.id | |
81 | } | |
82 | ||
83 | // Read is used to read from the stream | |
84 | func (s *Stream) Read(b []byte) (n int, err error) { | |
85 | defer asyncNotify(s.recvNotifyCh) | |
86 | START: | |
87 | s.stateLock.Lock() | |
88 | switch s.state { | |
89 | case streamLocalClose: | |
90 | fallthrough | |
91 | case streamRemoteClose: | |
92 | fallthrough | |
93 | case streamClosed: | |
94 | s.recvLock.Lock() | |
95 | if s.recvBuf == nil || s.recvBuf.Len() == 0 { | |
96 | s.recvLock.Unlock() | |
97 | s.stateLock.Unlock() | |
98 | return 0, io.EOF | |
99 | } | |
100 | s.recvLock.Unlock() | |
101 | case streamReset: | |
102 | s.stateLock.Unlock() | |
103 | return 0, ErrConnectionReset | |
104 | } | |
105 | s.stateLock.Unlock() | |
106 | ||
107 | // If there is no data available, block | |
108 | s.recvLock.Lock() | |
109 | if s.recvBuf == nil || s.recvBuf.Len() == 0 { | |
110 | s.recvLock.Unlock() | |
111 | goto WAIT | |
112 | } | |
113 | ||
114 | // Read any bytes | |
115 | n, _ = s.recvBuf.Read(b) | |
116 | s.recvLock.Unlock() | |
117 | ||
118 | // Send a window update potentially | |
119 | err = s.sendWindowUpdate() | |
120 | return n, err | |
121 | ||
122 | WAIT: | |
123 | var timeout <-chan time.Time | |
124 | var timer *time.Timer | |
125 | if !s.readDeadline.IsZero() { | |
126 | delay := s.readDeadline.Sub(time.Now()) | |
127 | timer = time.NewTimer(delay) | |
128 | timeout = timer.C | |
129 | } | |
130 | select { | |
131 | case <-s.recvNotifyCh: | |
132 | if timer != nil { | |
133 | timer.Stop() | |
134 | } | |
135 | goto START | |
136 | case <-timeout: | |
137 | return 0, ErrTimeout | |
138 | } | |
139 | } | |
140 | ||
141 | // Write is used to write to the stream | |
142 | func (s *Stream) Write(b []byte) (n int, err error) { | |
143 | s.sendLock.Lock() | |
144 | defer s.sendLock.Unlock() | |
145 | total := 0 | |
146 | for total < len(b) { | |
147 | n, err := s.write(b[total:]) | |
148 | total += n | |
149 | if err != nil { | |
150 | return total, err | |
151 | } | |
152 | } | |
153 | return total, nil | |
154 | } | |
155 | ||
156 | // write is used to write to the stream, may return on | |
157 | // a short write. | |
158 | func (s *Stream) write(b []byte) (n int, err error) { | |
159 | var flags uint16 | |
160 | var max uint32 | |
161 | var body io.Reader | |
162 | START: | |
163 | s.stateLock.Lock() | |
164 | switch s.state { | |
165 | case streamLocalClose: | |
166 | fallthrough | |
167 | case streamClosed: | |
168 | s.stateLock.Unlock() | |
169 | return 0, ErrStreamClosed | |
170 | case streamReset: | |
171 | s.stateLock.Unlock() | |
172 | return 0, ErrConnectionReset | |
173 | } | |
174 | s.stateLock.Unlock() | |
175 | ||
176 | // If there is no data available, block | |
177 | window := atomic.LoadUint32(&s.sendWindow) | |
178 | if window == 0 { | |
179 | goto WAIT | |
180 | } | |
181 | ||
182 | // Determine the flags if any | |
183 | flags = s.sendFlags() | |
184 | ||
185 | // Send up to our send window | |
186 | max = min(window, uint32(len(b))) | |
187 | body = bytes.NewReader(b[:max]) | |
188 | ||
189 | // Send the header | |
190 | s.sendHdr.encode(typeData, flags, s.id, max) | |
191 | if err := s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil { | |
192 | return 0, err | |
193 | } | |
194 | ||
195 | // Reduce our send window | |
196 | atomic.AddUint32(&s.sendWindow, ^uint32(max-1)) | |
197 | ||
198 | // Unlock | |
199 | return int(max), err | |
200 | ||
201 | WAIT: | |
202 | var timeout <-chan time.Time | |
203 | if !s.writeDeadline.IsZero() { | |
204 | delay := s.writeDeadline.Sub(time.Now()) | |
205 | timeout = time.After(delay) | |
206 | } | |
207 | select { | |
208 | case <-s.sendNotifyCh: | |
209 | goto START | |
210 | case <-timeout: | |
211 | return 0, ErrTimeout | |
212 | } | |
213 | return 0, nil | |
214 | } | |
215 | ||
216 | // sendFlags determines any flags that are appropriate | |
217 | // based on the current stream state | |
218 | func (s *Stream) sendFlags() uint16 { | |
219 | s.stateLock.Lock() | |
220 | defer s.stateLock.Unlock() | |
221 | var flags uint16 | |
222 | switch s.state { | |
223 | case streamInit: | |
224 | flags |= flagSYN | |
225 | s.state = streamSYNSent | |
226 | case streamSYNReceived: | |
227 | flags |= flagACK | |
228 | s.state = streamEstablished | |
229 | } | |
230 | return flags | |
231 | } | |
232 | ||
233 | // sendWindowUpdate potentially sends a window update enabling | |
234 | // further writes to take place. Must be invoked with the lock. | |
235 | func (s *Stream) sendWindowUpdate() error { | |
236 | s.controlHdrLock.Lock() | |
237 | defer s.controlHdrLock.Unlock() | |
238 | ||
239 | // Determine the delta update | |
240 | max := s.session.config.MaxStreamWindowSize | |
241 | delta := max - atomic.LoadUint32(&s.recvWindow) | |
242 | ||
243 | // Determine the flags if any | |
244 | flags := s.sendFlags() | |
245 | ||
246 | // Check if we can omit the update | |
247 | if delta < (max/2) && flags == 0 { | |
248 | return nil | |
249 | } | |
250 | ||
251 | // Update our window | |
252 | atomic.AddUint32(&s.recvWindow, delta) | |
253 | ||
254 | // Send the header | |
255 | s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta) | |
256 | if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil { | |
257 | return err | |
258 | } | |
259 | return nil | |
260 | } | |
261 | ||
262 | // sendClose is used to send a FIN | |
263 | func (s *Stream) sendClose() error { | |
264 | s.controlHdrLock.Lock() | |
265 | defer s.controlHdrLock.Unlock() | |
266 | ||
267 | flags := s.sendFlags() | |
268 | flags |= flagFIN | |
269 | s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0) | |
270 | if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil { | |
271 | return err | |
272 | } | |
273 | return nil | |
274 | } | |
275 | ||
276 | // Close is used to close the stream | |
277 | func (s *Stream) Close() error { | |
278 | closeStream := false | |
279 | s.stateLock.Lock() | |
280 | switch s.state { | |
281 | // Opened means we need to signal a close | |
282 | case streamSYNSent: | |
283 | fallthrough | |
284 | case streamSYNReceived: | |
285 | fallthrough | |
286 | case streamEstablished: | |
287 | s.state = streamLocalClose | |
288 | goto SEND_CLOSE | |
289 | ||
290 | case streamLocalClose: | |
291 | case streamRemoteClose: | |
292 | s.state = streamClosed | |
293 | closeStream = true | |
294 | goto SEND_CLOSE | |
295 | ||
296 | case streamClosed: | |
297 | case streamReset: | |
298 | default: | |
299 | panic("unhandled state") | |
300 | } | |
301 | s.stateLock.Unlock() | |
302 | return nil | |
303 | SEND_CLOSE: | |
304 | s.stateLock.Unlock() | |
305 | s.sendClose() | |
306 | s.notifyWaiting() | |
307 | if closeStream { | |
308 | s.session.closeStream(s.id) | |
309 | } | |
310 | return nil | |
311 | } | |
312 | ||
313 | // forceClose is used for when the session is exiting | |
314 | func (s *Stream) forceClose() { | |
315 | s.stateLock.Lock() | |
316 | s.state = streamClosed | |
317 | s.stateLock.Unlock() | |
318 | s.notifyWaiting() | |
319 | } | |
320 | ||
321 | // processFlags is used to update the state of the stream | |
322 | // based on set flags, if any. Lock must be held | |
323 | func (s *Stream) processFlags(flags uint16) error { | |
324 | // Close the stream without holding the state lock | |
325 | closeStream := false | |
326 | defer func() { | |
327 | if closeStream { | |
328 | s.session.closeStream(s.id) | |
329 | } | |
330 | }() | |
331 | ||
332 | s.stateLock.Lock() | |
333 | defer s.stateLock.Unlock() | |
334 | if flags&flagACK == flagACK { | |
335 | if s.state == streamSYNSent { | |
336 | s.state = streamEstablished | |
337 | } | |
338 | s.session.establishStream(s.id) | |
339 | } | |
340 | if flags&flagFIN == flagFIN { | |
341 | switch s.state { | |
342 | case streamSYNSent: | |
343 | fallthrough | |
344 | case streamSYNReceived: | |
345 | fallthrough | |
346 | case streamEstablished: | |
347 | s.state = streamRemoteClose | |
348 | s.notifyWaiting() | |
349 | case streamLocalClose: | |
350 | s.state = streamClosed | |
351 | closeStream = true | |
352 | s.notifyWaiting() | |
353 | default: | |
354 | s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state) | |
355 | return ErrUnexpectedFlag | |
356 | } | |
357 | } | |
358 | if flags&flagRST == flagRST { | |
359 | s.state = streamReset | |
360 | closeStream = true | |
361 | s.notifyWaiting() | |
362 | } | |
363 | return nil | |
364 | } | |
365 | ||
366 | // notifyWaiting notifies all the waiting channels | |
367 | func (s *Stream) notifyWaiting() { | |
368 | asyncNotify(s.recvNotifyCh) | |
369 | asyncNotify(s.sendNotifyCh) | |
370 | } | |
371 | ||
372 | // incrSendWindow updates the size of our send window | |
373 | func (s *Stream) incrSendWindow(hdr header, flags uint16) error { | |
374 | if err := s.processFlags(flags); err != nil { | |
375 | return err | |
376 | } | |
377 | ||
378 | // Increase window, unblock a sender | |
379 | atomic.AddUint32(&s.sendWindow, hdr.Length()) | |
380 | asyncNotify(s.sendNotifyCh) | |
381 | return nil | |
382 | } | |
383 | ||
384 | // readData is used to handle a data frame | |
385 | func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { | |
386 | if err := s.processFlags(flags); err != nil { | |
387 | return err | |
388 | } | |
389 | ||
390 | // Check that our recv window is not exceeded | |
391 | length := hdr.Length() | |
392 | if length == 0 { | |
393 | return nil | |
394 | } | |
395 | if remain := atomic.LoadUint32(&s.recvWindow); length > remain { | |
396 | s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, remain, length) | |
397 | return ErrRecvWindowExceeded | |
398 | } | |
399 | ||
400 | // Wrap in a limited reader | |
401 | conn = &io.LimitedReader{R: conn, N: int64(length)} | |
402 | ||
403 | // Copy into buffer | |
404 | s.recvLock.Lock() | |
405 | if s.recvBuf == nil { | |
406 | // Allocate the receive buffer just-in-time to fit the full data frame. | |
407 | // This way we can read in the whole packet without further allocations. | |
408 | s.recvBuf = bytes.NewBuffer(make([]byte, 0, length)) | |
409 | } | |
410 | if _, err := io.Copy(s.recvBuf, conn); err != nil { | |
411 | s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err) | |
412 | s.recvLock.Unlock() | |
413 | return err | |
414 | } | |
415 | ||
416 | // Decrement the receive window | |
417 | atomic.AddUint32(&s.recvWindow, ^uint32(length-1)) | |
418 | s.recvLock.Unlock() | |
419 | ||
420 | // Unblock any readers | |
421 | asyncNotify(s.recvNotifyCh) | |
422 | return nil | |
423 | } | |
424 | ||
425 | // SetDeadline sets the read and write deadlines | |
426 | func (s *Stream) SetDeadline(t time.Time) error { | |
427 | if err := s.SetReadDeadline(t); err != nil { | |
428 | return err | |
429 | } | |
430 | if err := s.SetWriteDeadline(t); err != nil { | |
431 | return err | |
432 | } | |
433 | return nil | |
434 | } | |
435 | ||
436 | // SetReadDeadline sets the deadline for future Read calls. | |
437 | func (s *Stream) SetReadDeadline(t time.Time) error { | |
438 | s.readDeadline = t | |
439 | return nil | |
440 | } | |
441 | ||
442 | // SetWriteDeadline sets the deadline for future Write calls | |
443 | func (s *Stream) SetWriteDeadline(t time.Time) error { | |
444 | s.writeDeadline = t | |
445 | return nil | |
446 | } | |
447 | ||
448 | // Shrink is used to compact the amount of buffers utilized | |
449 | // This is useful when using Yamux in a connection pool to reduce | |
450 | // the idle memory utilization. | |
451 | func (s *Stream) Shrink() { | |
452 | s.recvLock.Lock() | |
453 | if s.recvBuf != nil && s.recvBuf.Len() == 0 { | |
454 | s.recvBuf = nil | |
455 | } | |
456 | s.recvLock.Unlock() | |
457 | } |