diff options
Diffstat (limited to 'vendor/github.com/hashicorp/yamux/stream.go')
-rw-r--r-- | vendor/github.com/hashicorp/yamux/stream.go | 457 |
1 files changed, 457 insertions, 0 deletions
diff --git a/vendor/github.com/hashicorp/yamux/stream.go b/vendor/github.com/hashicorp/yamux/stream.go new file mode 100644 index 0000000..d216e28 --- /dev/null +++ b/vendor/github.com/hashicorp/yamux/stream.go | |||
@@ -0,0 +1,457 @@ | |||
1 | package yamux | ||
2 | |||
3 | import ( | ||
4 | "bytes" | ||
5 | "io" | ||
6 | "sync" | ||
7 | "sync/atomic" | ||
8 | "time" | ||
9 | ) | ||
10 | |||
11 | type streamState int | ||
12 | |||
13 | const ( | ||
14 | streamInit streamState = iota | ||
15 | streamSYNSent | ||
16 | streamSYNReceived | ||
17 | streamEstablished | ||
18 | streamLocalClose | ||
19 | streamRemoteClose | ||
20 | streamClosed | ||
21 | streamReset | ||
22 | ) | ||
23 | |||
24 | // Stream is used to represent a logical stream | ||
25 | // within a session. | ||
26 | type Stream struct { | ||
27 | recvWindow uint32 | ||
28 | sendWindow uint32 | ||
29 | |||
30 | id uint32 | ||
31 | session *Session | ||
32 | |||
33 | state streamState | ||
34 | stateLock sync.Mutex | ||
35 | |||
36 | recvBuf *bytes.Buffer | ||
37 | recvLock sync.Mutex | ||
38 | |||
39 | controlHdr header | ||
40 | controlErr chan error | ||
41 | controlHdrLock sync.Mutex | ||
42 | |||
43 | sendHdr header | ||
44 | sendErr chan error | ||
45 | sendLock sync.Mutex | ||
46 | |||
47 | recvNotifyCh chan struct{} | ||
48 | sendNotifyCh chan struct{} | ||
49 | |||
50 | readDeadline time.Time | ||
51 | writeDeadline time.Time | ||
52 | } | ||
53 | |||
54 | // newStream is used to construct a new stream within | ||
55 | // a given session for an ID | ||
56 | func newStream(session *Session, id uint32, state streamState) *Stream { | ||
57 | s := &Stream{ | ||
58 | id: id, | ||
59 | session: session, | ||
60 | state: state, | ||
61 | controlHdr: header(make([]byte, headerSize)), | ||
62 | controlErr: make(chan error, 1), | ||
63 | sendHdr: header(make([]byte, headerSize)), | ||
64 | sendErr: make(chan error, 1), | ||
65 | recvWindow: initialStreamWindow, | ||
66 | sendWindow: initialStreamWindow, | ||
67 | recvNotifyCh: make(chan struct{}, 1), | ||
68 | sendNotifyCh: make(chan struct{}, 1), | ||
69 | } | ||
70 | return s | ||
71 | } | ||
72 | |||
73 | // Session returns the associated stream session | ||
74 | func (s *Stream) Session() *Session { | ||
75 | return s.session | ||
76 | } | ||
77 | |||
78 | // StreamID returns the ID of this stream | ||
79 | func (s *Stream) StreamID() uint32 { | ||
80 | return s.id | ||
81 | } | ||
82 | |||
83 | // Read is used to read from the stream | ||
84 | func (s *Stream) Read(b []byte) (n int, err error) { | ||
85 | defer asyncNotify(s.recvNotifyCh) | ||
86 | START: | ||
87 | s.stateLock.Lock() | ||
88 | switch s.state { | ||
89 | case streamLocalClose: | ||
90 | fallthrough | ||
91 | case streamRemoteClose: | ||
92 | fallthrough | ||
93 | case streamClosed: | ||
94 | s.recvLock.Lock() | ||
95 | if s.recvBuf == nil || s.recvBuf.Len() == 0 { | ||
96 | s.recvLock.Unlock() | ||
97 | s.stateLock.Unlock() | ||
98 | return 0, io.EOF | ||
99 | } | ||
100 | s.recvLock.Unlock() | ||
101 | case streamReset: | ||
102 | s.stateLock.Unlock() | ||
103 | return 0, ErrConnectionReset | ||
104 | } | ||
105 | s.stateLock.Unlock() | ||
106 | |||
107 | // If there is no data available, block | ||
108 | s.recvLock.Lock() | ||
109 | if s.recvBuf == nil || s.recvBuf.Len() == 0 { | ||
110 | s.recvLock.Unlock() | ||
111 | goto WAIT | ||
112 | } | ||
113 | |||
114 | // Read any bytes | ||
115 | n, _ = s.recvBuf.Read(b) | ||
116 | s.recvLock.Unlock() | ||
117 | |||
118 | // Send a window update potentially | ||
119 | err = s.sendWindowUpdate() | ||
120 | return n, err | ||
121 | |||
122 | WAIT: | ||
123 | var timeout <-chan time.Time | ||
124 | var timer *time.Timer | ||
125 | if !s.readDeadline.IsZero() { | ||
126 | delay := s.readDeadline.Sub(time.Now()) | ||
127 | timer = time.NewTimer(delay) | ||
128 | timeout = timer.C | ||
129 | } | ||
130 | select { | ||
131 | case <-s.recvNotifyCh: | ||
132 | if timer != nil { | ||
133 | timer.Stop() | ||
134 | } | ||
135 | goto START | ||
136 | case <-timeout: | ||
137 | return 0, ErrTimeout | ||
138 | } | ||
139 | } | ||
140 | |||
141 | // Write is used to write to the stream | ||
142 | func (s *Stream) Write(b []byte) (n int, err error) { | ||
143 | s.sendLock.Lock() | ||
144 | defer s.sendLock.Unlock() | ||
145 | total := 0 | ||
146 | for total < len(b) { | ||
147 | n, err := s.write(b[total:]) | ||
148 | total += n | ||
149 | if err != nil { | ||
150 | return total, err | ||
151 | } | ||
152 | } | ||
153 | return total, nil | ||
154 | } | ||
155 | |||
156 | // write is used to write to the stream, may return on | ||
157 | // a short write. | ||
158 | func (s *Stream) write(b []byte) (n int, err error) { | ||
159 | var flags uint16 | ||
160 | var max uint32 | ||
161 | var body io.Reader | ||
162 | START: | ||
163 | s.stateLock.Lock() | ||
164 | switch s.state { | ||
165 | case streamLocalClose: | ||
166 | fallthrough | ||
167 | case streamClosed: | ||
168 | s.stateLock.Unlock() | ||
169 | return 0, ErrStreamClosed | ||
170 | case streamReset: | ||
171 | s.stateLock.Unlock() | ||
172 | return 0, ErrConnectionReset | ||
173 | } | ||
174 | s.stateLock.Unlock() | ||
175 | |||
176 | // If there is no data available, block | ||
177 | window := atomic.LoadUint32(&s.sendWindow) | ||
178 | if window == 0 { | ||
179 | goto WAIT | ||
180 | } | ||
181 | |||
182 | // Determine the flags if any | ||
183 | flags = s.sendFlags() | ||
184 | |||
185 | // Send up to our send window | ||
186 | max = min(window, uint32(len(b))) | ||
187 | body = bytes.NewReader(b[:max]) | ||
188 | |||
189 | // Send the header | ||
190 | s.sendHdr.encode(typeData, flags, s.id, max) | ||
191 | if err := s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil { | ||
192 | return 0, err | ||
193 | } | ||
194 | |||
195 | // Reduce our send window | ||
196 | atomic.AddUint32(&s.sendWindow, ^uint32(max-1)) | ||
197 | |||
198 | // Unlock | ||
199 | return int(max), err | ||
200 | |||
201 | WAIT: | ||
202 | var timeout <-chan time.Time | ||
203 | if !s.writeDeadline.IsZero() { | ||
204 | delay := s.writeDeadline.Sub(time.Now()) | ||
205 | timeout = time.After(delay) | ||
206 | } | ||
207 | select { | ||
208 | case <-s.sendNotifyCh: | ||
209 | goto START | ||
210 | case <-timeout: | ||
211 | return 0, ErrTimeout | ||
212 | } | ||
213 | return 0, nil | ||
214 | } | ||
215 | |||
216 | // sendFlags determines any flags that are appropriate | ||
217 | // based on the current stream state | ||
218 | func (s *Stream) sendFlags() uint16 { | ||
219 | s.stateLock.Lock() | ||
220 | defer s.stateLock.Unlock() | ||
221 | var flags uint16 | ||
222 | switch s.state { | ||
223 | case streamInit: | ||
224 | flags |= flagSYN | ||
225 | s.state = streamSYNSent | ||
226 | case streamSYNReceived: | ||
227 | flags |= flagACK | ||
228 | s.state = streamEstablished | ||
229 | } | ||
230 | return flags | ||
231 | } | ||
232 | |||
233 | // sendWindowUpdate potentially sends a window update enabling | ||
234 | // further writes to take place. Must be invoked with the lock. | ||
235 | func (s *Stream) sendWindowUpdate() error { | ||
236 | s.controlHdrLock.Lock() | ||
237 | defer s.controlHdrLock.Unlock() | ||
238 | |||
239 | // Determine the delta update | ||
240 | max := s.session.config.MaxStreamWindowSize | ||
241 | delta := max - atomic.LoadUint32(&s.recvWindow) | ||
242 | |||
243 | // Determine the flags if any | ||
244 | flags := s.sendFlags() | ||
245 | |||
246 | // Check if we can omit the update | ||
247 | if delta < (max/2) && flags == 0 { | ||
248 | return nil | ||
249 | } | ||
250 | |||
251 | // Update our window | ||
252 | atomic.AddUint32(&s.recvWindow, delta) | ||
253 | |||
254 | // Send the header | ||
255 | s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta) | ||
256 | if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil { | ||
257 | return err | ||
258 | } | ||
259 | return nil | ||
260 | } | ||
261 | |||
262 | // sendClose is used to send a FIN | ||
263 | func (s *Stream) sendClose() error { | ||
264 | s.controlHdrLock.Lock() | ||
265 | defer s.controlHdrLock.Unlock() | ||
266 | |||
267 | flags := s.sendFlags() | ||
268 | flags |= flagFIN | ||
269 | s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0) | ||
270 | if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil { | ||
271 | return err | ||
272 | } | ||
273 | return nil | ||
274 | } | ||
275 | |||
276 | // Close is used to close the stream | ||
277 | func (s *Stream) Close() error { | ||
278 | closeStream := false | ||
279 | s.stateLock.Lock() | ||
280 | switch s.state { | ||
281 | // Opened means we need to signal a close | ||
282 | case streamSYNSent: | ||
283 | fallthrough | ||
284 | case streamSYNReceived: | ||
285 | fallthrough | ||
286 | case streamEstablished: | ||
287 | s.state = streamLocalClose | ||
288 | goto SEND_CLOSE | ||
289 | |||
290 | case streamLocalClose: | ||
291 | case streamRemoteClose: | ||
292 | s.state = streamClosed | ||
293 | closeStream = true | ||
294 | goto SEND_CLOSE | ||
295 | |||
296 | case streamClosed: | ||
297 | case streamReset: | ||
298 | default: | ||
299 | panic("unhandled state") | ||
300 | } | ||
301 | s.stateLock.Unlock() | ||
302 | return nil | ||
303 | SEND_CLOSE: | ||
304 | s.stateLock.Unlock() | ||
305 | s.sendClose() | ||
306 | s.notifyWaiting() | ||
307 | if closeStream { | ||
308 | s.session.closeStream(s.id) | ||
309 | } | ||
310 | return nil | ||
311 | } | ||
312 | |||
313 | // forceClose is used for when the session is exiting | ||
314 | func (s *Stream) forceClose() { | ||
315 | s.stateLock.Lock() | ||
316 | s.state = streamClosed | ||
317 | s.stateLock.Unlock() | ||
318 | s.notifyWaiting() | ||
319 | } | ||
320 | |||
321 | // processFlags is used to update the state of the stream | ||
322 | // based on set flags, if any. Lock must be held | ||
323 | func (s *Stream) processFlags(flags uint16) error { | ||
324 | // Close the stream without holding the state lock | ||
325 | closeStream := false | ||
326 | defer func() { | ||
327 | if closeStream { | ||
328 | s.session.closeStream(s.id) | ||
329 | } | ||
330 | }() | ||
331 | |||
332 | s.stateLock.Lock() | ||
333 | defer s.stateLock.Unlock() | ||
334 | if flags&flagACK == flagACK { | ||
335 | if s.state == streamSYNSent { | ||
336 | s.state = streamEstablished | ||
337 | } | ||
338 | s.session.establishStream(s.id) | ||
339 | } | ||
340 | if flags&flagFIN == flagFIN { | ||
341 | switch s.state { | ||
342 | case streamSYNSent: | ||
343 | fallthrough | ||
344 | case streamSYNReceived: | ||
345 | fallthrough | ||
346 | case streamEstablished: | ||
347 | s.state = streamRemoteClose | ||
348 | s.notifyWaiting() | ||
349 | case streamLocalClose: | ||
350 | s.state = streamClosed | ||
351 | closeStream = true | ||
352 | s.notifyWaiting() | ||
353 | default: | ||
354 | s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state) | ||
355 | return ErrUnexpectedFlag | ||
356 | } | ||
357 | } | ||
358 | if flags&flagRST == flagRST { | ||
359 | s.state = streamReset | ||
360 | closeStream = true | ||
361 | s.notifyWaiting() | ||
362 | } | ||
363 | return nil | ||
364 | } | ||
365 | |||
366 | // notifyWaiting notifies all the waiting channels | ||
367 | func (s *Stream) notifyWaiting() { | ||
368 | asyncNotify(s.recvNotifyCh) | ||
369 | asyncNotify(s.sendNotifyCh) | ||
370 | } | ||
371 | |||
372 | // incrSendWindow updates the size of our send window | ||
373 | func (s *Stream) incrSendWindow(hdr header, flags uint16) error { | ||
374 | if err := s.processFlags(flags); err != nil { | ||
375 | return err | ||
376 | } | ||
377 | |||
378 | // Increase window, unblock a sender | ||
379 | atomic.AddUint32(&s.sendWindow, hdr.Length()) | ||
380 | asyncNotify(s.sendNotifyCh) | ||
381 | return nil | ||
382 | } | ||
383 | |||
384 | // readData is used to handle a data frame | ||
385 | func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { | ||
386 | if err := s.processFlags(flags); err != nil { | ||
387 | return err | ||
388 | } | ||
389 | |||
390 | // Check that our recv window is not exceeded | ||
391 | length := hdr.Length() | ||
392 | if length == 0 { | ||
393 | return nil | ||
394 | } | ||
395 | if remain := atomic.LoadUint32(&s.recvWindow); length > remain { | ||
396 | s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, remain, length) | ||
397 | return ErrRecvWindowExceeded | ||
398 | } | ||
399 | |||
400 | // Wrap in a limited reader | ||
401 | conn = &io.LimitedReader{R: conn, N: int64(length)} | ||
402 | |||
403 | // Copy into buffer | ||
404 | s.recvLock.Lock() | ||
405 | if s.recvBuf == nil { | ||
406 | // Allocate the receive buffer just-in-time to fit the full data frame. | ||
407 | // This way we can read in the whole packet without further allocations. | ||
408 | s.recvBuf = bytes.NewBuffer(make([]byte, 0, length)) | ||
409 | } | ||
410 | if _, err := io.Copy(s.recvBuf, conn); err != nil { | ||
411 | s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err) | ||
412 | s.recvLock.Unlock() | ||
413 | return err | ||
414 | } | ||
415 | |||
416 | // Decrement the receive window | ||
417 | atomic.AddUint32(&s.recvWindow, ^uint32(length-1)) | ||
418 | s.recvLock.Unlock() | ||
419 | |||
420 | // Unblock any readers | ||
421 | asyncNotify(s.recvNotifyCh) | ||
422 | return nil | ||
423 | } | ||
424 | |||
425 | // SetDeadline sets the read and write deadlines | ||
426 | func (s *Stream) SetDeadline(t time.Time) error { | ||
427 | if err := s.SetReadDeadline(t); err != nil { | ||
428 | return err | ||
429 | } | ||
430 | if err := s.SetWriteDeadline(t); err != nil { | ||
431 | return err | ||
432 | } | ||
433 | return nil | ||
434 | } | ||
435 | |||
436 | // SetReadDeadline sets the deadline for future Read calls. | ||
437 | func (s *Stream) SetReadDeadline(t time.Time) error { | ||
438 | s.readDeadline = t | ||
439 | return nil | ||
440 | } | ||
441 | |||
442 | // SetWriteDeadline sets the deadline for future Write calls | ||
443 | func (s *Stream) SetWriteDeadline(t time.Time) error { | ||
444 | s.writeDeadline = t | ||
445 | return nil | ||
446 | } | ||
447 | |||
448 | // Shrink is used to compact the amount of buffers utilized | ||
449 | // This is useful when using Yamux in a connection pool to reduce | ||
450 | // the idle memory utilization. | ||
451 | func (s *Stream) Shrink() { | ||
452 | s.recvLock.Lock() | ||
453 | if s.recvBuf != nil && s.recvBuf.Len() == 0 { | ||
454 | s.recvBuf = nil | ||
455 | } | ||
456 | s.recvLock.Unlock() | ||
457 | } | ||