]>
Commit | Line | Data |
---|---|---|
bae9f6d2 JC |
1 | package yamux |
2 | ||
3 | import ( | |
4 | "bytes" | |
5 | "io" | |
6 | "sync" | |
7 | "sync/atomic" | |
8 | "time" | |
9 | ) | |
10 | ||
// streamState tracks the lifecycle of a Stream as SYN/ACK/FIN/RST
// frames are exchanged with the peer.
type streamState int

const (
	streamInit        streamState = iota // created locally, SYN not yet sent
	streamSYNSent                        // SYN sent, awaiting the peer's ACK
	streamSYNReceived                    // peer's SYN received, our ACK not yet sent
	streamEstablished                    // handshake complete, data may flow
	streamLocalClose                     // we sent FIN; remote may still send
	streamRemoteClose                    // remote sent FIN; we may still send
	streamClosed                         // both sides closed (or force-closed)
	streamReset                          // torn down via RST
)
23 | ||
// Stream is used to represent a logical stream
// within a session.
type Stream struct {
	recvWindow uint32 // bytes the peer may still send us; guarded by recvLock
	sendWindow uint32 // bytes we may still send; accessed atomically

	id      uint32   // stream identifier within the session
	session *Session // owning session, used to transmit frames

	state     streamState // current lifecycle state
	stateLock sync.Mutex  // guards state

	recvBuf  *bytes.Buffer // inbound data buffer, allocated lazily by readData
	recvLock sync.Mutex    // guards recvBuf and recvWindow

	controlHdr     header     // scratch header for window-update/FIN frames
	controlErr     chan error // send-completion channel for control frames
	controlHdrLock sync.Mutex // serializes use of controlHdr

	sendHdr  header     // scratch header for data frames
	sendErr  chan error // send-completion channel for data frames
	sendLock sync.Mutex // serializes Write calls

	recvNotifyCh chan struct{} // signaled when data arrives or state changes (wakes Read)
	sendNotifyCh chan struct{} // signaled when window grows or state changes (wakes write)

	readDeadline  atomic.Value // time.Time; zero value means no deadline
	writeDeadline atomic.Value // time.Time; zero value means no deadline
}
53 | ||
54 | // newStream is used to construct a new stream within | |
55 | // a given session for an ID | |
56 | func newStream(session *Session, id uint32, state streamState) *Stream { | |
57 | s := &Stream{ | |
58 | id: id, | |
59 | session: session, | |
60 | state: state, | |
61 | controlHdr: header(make([]byte, headerSize)), | |
62 | controlErr: make(chan error, 1), | |
63 | sendHdr: header(make([]byte, headerSize)), | |
64 | sendErr: make(chan error, 1), | |
65 | recvWindow: initialStreamWindow, | |
66 | sendWindow: initialStreamWindow, | |
67 | recvNotifyCh: make(chan struct{}, 1), | |
68 | sendNotifyCh: make(chan struct{}, 1), | |
69 | } | |
107c1cdb ND |
70 | s.readDeadline.Store(time.Time{}) |
71 | s.writeDeadline.Store(time.Time{}) | |
bae9f6d2 JC |
72 | return s |
73 | } | |
74 | ||
// Session returns the associated stream session
func (s *Stream) Session() *Session {
	return s.session
}
79 | ||
// StreamID returns the ID of this stream
func (s *Stream) StreamID() uint32 {
	return s.id
}
84 | ||
85 | // Read is used to read from the stream | |
86 | func (s *Stream) Read(b []byte) (n int, err error) { | |
87 | defer asyncNotify(s.recvNotifyCh) | |
88 | START: | |
89 | s.stateLock.Lock() | |
90 | switch s.state { | |
91 | case streamLocalClose: | |
92 | fallthrough | |
93 | case streamRemoteClose: | |
94 | fallthrough | |
95 | case streamClosed: | |
96 | s.recvLock.Lock() | |
97 | if s.recvBuf == nil || s.recvBuf.Len() == 0 { | |
98 | s.recvLock.Unlock() | |
99 | s.stateLock.Unlock() | |
100 | return 0, io.EOF | |
101 | } | |
102 | s.recvLock.Unlock() | |
103 | case streamReset: | |
104 | s.stateLock.Unlock() | |
105 | return 0, ErrConnectionReset | |
106 | } | |
107 | s.stateLock.Unlock() | |
108 | ||
109 | // If there is no data available, block | |
110 | s.recvLock.Lock() | |
111 | if s.recvBuf == nil || s.recvBuf.Len() == 0 { | |
112 | s.recvLock.Unlock() | |
113 | goto WAIT | |
114 | } | |
115 | ||
116 | // Read any bytes | |
117 | n, _ = s.recvBuf.Read(b) | |
118 | s.recvLock.Unlock() | |
119 | ||
120 | // Send a window update potentially | |
121 | err = s.sendWindowUpdate() | |
122 | return n, err | |
123 | ||
124 | WAIT: | |
125 | var timeout <-chan time.Time | |
126 | var timer *time.Timer | |
107c1cdb ND |
127 | readDeadline := s.readDeadline.Load().(time.Time) |
128 | if !readDeadline.IsZero() { | |
129 | delay := readDeadline.Sub(time.Now()) | |
bae9f6d2 JC |
130 | timer = time.NewTimer(delay) |
131 | timeout = timer.C | |
132 | } | |
133 | select { | |
134 | case <-s.recvNotifyCh: | |
135 | if timer != nil { | |
136 | timer.Stop() | |
137 | } | |
138 | goto START | |
139 | case <-timeout: | |
140 | return 0, ErrTimeout | |
141 | } | |
142 | } | |
143 | ||
144 | // Write is used to write to the stream | |
145 | func (s *Stream) Write(b []byte) (n int, err error) { | |
146 | s.sendLock.Lock() | |
147 | defer s.sendLock.Unlock() | |
148 | total := 0 | |
149 | for total < len(b) { | |
150 | n, err := s.write(b[total:]) | |
151 | total += n | |
152 | if err != nil { | |
153 | return total, err | |
154 | } | |
155 | } | |
156 | return total, nil | |
157 | } | |
158 | ||
159 | // write is used to write to the stream, may return on | |
160 | // a short write. | |
161 | func (s *Stream) write(b []byte) (n int, err error) { | |
162 | var flags uint16 | |
163 | var max uint32 | |
164 | var body io.Reader | |
165 | START: | |
166 | s.stateLock.Lock() | |
167 | switch s.state { | |
168 | case streamLocalClose: | |
169 | fallthrough | |
170 | case streamClosed: | |
171 | s.stateLock.Unlock() | |
172 | return 0, ErrStreamClosed | |
173 | case streamReset: | |
174 | s.stateLock.Unlock() | |
175 | return 0, ErrConnectionReset | |
176 | } | |
177 | s.stateLock.Unlock() | |
178 | ||
179 | // If there is no data available, block | |
180 | window := atomic.LoadUint32(&s.sendWindow) | |
181 | if window == 0 { | |
182 | goto WAIT | |
183 | } | |
184 | ||
185 | // Determine the flags if any | |
186 | flags = s.sendFlags() | |
187 | ||
188 | // Send up to our send window | |
189 | max = min(window, uint32(len(b))) | |
190 | body = bytes.NewReader(b[:max]) | |
191 | ||
192 | // Send the header | |
193 | s.sendHdr.encode(typeData, flags, s.id, max) | |
107c1cdb | 194 | if err = s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil { |
bae9f6d2 JC |
195 | return 0, err |
196 | } | |
197 | ||
198 | // Reduce our send window | |
199 | atomic.AddUint32(&s.sendWindow, ^uint32(max-1)) | |
200 | ||
201 | // Unlock | |
202 | return int(max), err | |
203 | ||
204 | WAIT: | |
205 | var timeout <-chan time.Time | |
107c1cdb ND |
206 | writeDeadline := s.writeDeadline.Load().(time.Time) |
207 | if !writeDeadline.IsZero() { | |
208 | delay := writeDeadline.Sub(time.Now()) | |
bae9f6d2 JC |
209 | timeout = time.After(delay) |
210 | } | |
211 | select { | |
212 | case <-s.sendNotifyCh: | |
213 | goto START | |
214 | case <-timeout: | |
215 | return 0, ErrTimeout | |
216 | } | |
217 | return 0, nil | |
218 | } | |
219 | ||
220 | // sendFlags determines any flags that are appropriate | |
221 | // based on the current stream state | |
222 | func (s *Stream) sendFlags() uint16 { | |
223 | s.stateLock.Lock() | |
224 | defer s.stateLock.Unlock() | |
225 | var flags uint16 | |
226 | switch s.state { | |
227 | case streamInit: | |
228 | flags |= flagSYN | |
229 | s.state = streamSYNSent | |
230 | case streamSYNReceived: | |
231 | flags |= flagACK | |
232 | s.state = streamEstablished | |
233 | } | |
234 | return flags | |
235 | } | |
236 | ||
// sendWindowUpdate potentially sends a window update enabling
// further writes to take place. Must be invoked with the lock.
func (s *Stream) sendWindowUpdate() error {
	// Serialize use of the shared control header buffer.
	s.controlHdrLock.Lock()
	defer s.controlHdrLock.Unlock()

	// Determine the delta update: how much window we can grant back —
	// the configured max, minus data still buffered locally, minus the
	// window already outstanding.
	max := s.session.config.MaxStreamWindowSize
	var bufLen uint32
	s.recvLock.Lock()
	if s.recvBuf != nil {
		bufLen = uint32(s.recvBuf.Len())
	}
	delta := (max - bufLen) - s.recvWindow

	// Determine the flags if any
	// NOTE(review): sendFlags acquires stateLock while recvLock is held
	// here — any other recvLock/stateLock ordering must agree with this.
	flags := s.sendFlags()

	// Check if we can omit the update: the delta is small (< half the
	// max window) and no handshake flags need to be delivered.
	if delta < (max/2) && flags == 0 {
		s.recvLock.Unlock()
		return nil
	}

	// Update our window
	s.recvWindow += delta
	s.recvLock.Unlock()

	// Send the header (window updates ride on typeWindowUpdate frames)
	s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
	if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
		return err
	}
	return nil
}
272 | ||
273 | // sendClose is used to send a FIN | |
274 | func (s *Stream) sendClose() error { | |
275 | s.controlHdrLock.Lock() | |
276 | defer s.controlHdrLock.Unlock() | |
277 | ||
278 | flags := s.sendFlags() | |
279 | flags |= flagFIN | |
280 | s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0) | |
281 | if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil { | |
282 | return err | |
283 | } | |
284 | return nil | |
285 | } | |
286 | ||
// Close is used to close the stream
func (s *Stream) Close() error {
	closeStream := false
	s.stateLock.Lock()
	switch s.state {
	// Opened means we need to signal a close
	case streamSYNSent:
		fallthrough
	case streamSYNReceived:
		fallthrough
	case streamEstablished:
		// Half-close: we stop sending but the remote may still send.
		s.state = streamLocalClose
		goto SEND_CLOSE

	case streamLocalClose:
		// Already half-closed locally; nothing more to do.
	case streamRemoteClose:
		// Remote already sent FIN; our FIN completes the shutdown.
		s.state = streamClosed
		closeStream = true
		goto SEND_CLOSE

	case streamClosed:
	case streamReset:
	default:
		panic("unhandled state")
	}
	s.stateLock.Unlock()
	return nil
SEND_CLOSE:
	s.stateLock.Unlock()
	// Send the FIN outside the state lock, then wake any blocked
	// readers/writers so they observe the new state.
	s.sendClose()
	s.notifyWaiting()
	if closeStream {
		// Fully closed: remove the stream from the session's tracking.
		s.session.closeStream(s.id)
	}
	return nil
}
323 | ||
// forceClose is used for when the session is exiting
func (s *Stream) forceClose() {
	// Skip the FIN handshake entirely: mark closed and wake any
	// blocked readers/writers so they bail out.
	s.stateLock.Lock()
	s.state = streamClosed
	s.stateLock.Unlock()
	s.notifyWaiting()
}
331 | ||
// processFlags is used to update the state of the stream
// based on set flags, if any. Lock must be held
func (s *Stream) processFlags(flags uint16) error {
	// Close the stream without holding the state lock: the deferred
	// closeStream call runs after the deferred stateLock.Unlock below
	// (defers run LIFO).
	closeStream := false
	defer func() {
		if closeStream {
			s.session.closeStream(s.id)
		}
	}()

	s.stateLock.Lock()
	defer s.stateLock.Unlock()
	if flags&flagACK == flagACK {
		// ACK completes our SYN; inform the session either way.
		if s.state == streamSYNSent {
			s.state = streamEstablished
		}
		s.session.establishStream(s.id)
	}
	if flags&flagFIN == flagFIN {
		switch s.state {
		case streamSYNSent:
			fallthrough
		case streamSYNReceived:
			fallthrough
		case streamEstablished:
			// Remote half-close: Read drains buffered data then EOFs;
			// writes remain permitted.
			s.state = streamRemoteClose
			s.notifyWaiting()
		case streamLocalClose:
			// Both sides have now sent FIN; the stream is fully closed.
			s.state = streamClosed
			closeStream = true
			s.notifyWaiting()
		default:
			s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state)
			return ErrUnexpectedFlag
		}
	}
	if flags&flagRST == flagRST {
		// Hard reset: tear down immediately regardless of prior state.
		s.state = streamReset
		closeStream = true
		s.notifyWaiting()
	}
	return nil
}
376 | ||
// notifyWaiting notifies all the waiting channels
func (s *Stream) notifyWaiting() {
	// Both notify channels are buffered (size 1); asyncNotify presumably
	// performs a non-blocking send — verify in util.
	asyncNotify(s.recvNotifyCh)
	asyncNotify(s.sendNotifyCh)
}
382 | ||
// incrSendWindow updates the size of our send window
func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
	// Apply any SYN/ACK/FIN/RST flags riding on the window update.
	if err := s.processFlags(flags); err != nil {
		return err
	}

	// Increase window, unblock a sender
	atomic.AddUint32(&s.sendWindow, hdr.Length())
	asyncNotify(s.sendNotifyCh)
	return nil
}
394 | ||
395 | // readData is used to handle a data frame | |
396 | func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { | |
397 | if err := s.processFlags(flags); err != nil { | |
398 | return err | |
399 | } | |
400 | ||
401 | // Check that our recv window is not exceeded | |
402 | length := hdr.Length() | |
403 | if length == 0 { | |
404 | return nil | |
405 | } | |
bae9f6d2 JC |
406 | |
407 | // Wrap in a limited reader | |
408 | conn = &io.LimitedReader{R: conn, N: int64(length)} | |
409 | ||
410 | // Copy into buffer | |
411 | s.recvLock.Lock() | |
107c1cdb ND |
412 | |
413 | if length > s.recvWindow { | |
414 | s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length) | |
415 | return ErrRecvWindowExceeded | |
416 | } | |
417 | ||
bae9f6d2 JC |
418 | if s.recvBuf == nil { |
419 | // Allocate the receive buffer just-in-time to fit the full data frame. | |
420 | // This way we can read in the whole packet without further allocations. | |
421 | s.recvBuf = bytes.NewBuffer(make([]byte, 0, length)) | |
422 | } | |
423 | if _, err := io.Copy(s.recvBuf, conn); err != nil { | |
424 | s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err) | |
425 | s.recvLock.Unlock() | |
426 | return err | |
427 | } | |
428 | ||
429 | // Decrement the receive window | |
107c1cdb | 430 | s.recvWindow -= length |
bae9f6d2 JC |
431 | s.recvLock.Unlock() |
432 | ||
433 | // Unblock any readers | |
434 | asyncNotify(s.recvNotifyCh) | |
435 | return nil | |
436 | } | |
437 | ||
438 | // SetDeadline sets the read and write deadlines | |
439 | func (s *Stream) SetDeadline(t time.Time) error { | |
440 | if err := s.SetReadDeadline(t); err != nil { | |
441 | return err | |
442 | } | |
443 | if err := s.SetWriteDeadline(t); err != nil { | |
444 | return err | |
445 | } | |
446 | return nil | |
447 | } | |
448 | ||
// SetReadDeadline sets the deadline for future Read calls.
func (s *Stream) SetReadDeadline(t time.Time) error {
	// Stored atomically; the zero time means no deadline.
	s.readDeadline.Store(t)
	return nil
}
454 | ||
// SetWriteDeadline sets the deadline for future Write calls
func (s *Stream) SetWriteDeadline(t time.Time) error {
	// Stored atomically; the zero time means no deadline.
	s.writeDeadline.Store(t)
	return nil
}
460 | ||
461 | // Shrink is used to compact the amount of buffers utilized | |
462 | // This is useful when using Yamux in a connection pool to reduce | |
463 | // the idle memory utilization. | |
464 | func (s *Stream) Shrink() { | |
465 | s.recvLock.Lock() | |
466 | if s.recvBuf != nil && s.recvBuf.Len() == 0 { | |
467 | s.recvBuf = nil | |
468 | } | |
469 | s.recvLock.Unlock() | |
470 | } |