]>
Commit | Line | Data |
---|---|---|
bae9f6d2 JC |
1 | // Copyright 2013 The Go Authors. All rights reserved. |
2 | // Use of this source code is governed by a BSD-style | |
3 | // license that can be found in the LICENSE file. | |
4 | ||
5 | package ssh | |
6 | ||
7 | import ( | |
8 | "encoding/binary" | |
9 | "fmt" | |
10 | "io" | |
11 | "log" | |
12 | "sync" | |
13 | "sync/atomic" | |
14 | ) | |
15 | ||
// debugMux is a compile-time switch for tracing the connection protocol.
// When true, the mux logs every message it sends through sendMessage and
// every packet it decodes in onePacket.
const (
	debugMux = false
)
19 | ||
20 | // chanList is a thread safe channel list. | |
21 | type chanList struct { | |
22 | // protects concurrent access to chans | |
23 | sync.Mutex | |
24 | ||
25 | // chans are indexed by the local id of the channel, which the | |
26 | // other side should send in the PeersId field. | |
27 | chans []*channel | |
28 | ||
29 | // This is a debugging aid: it offsets all IDs by this | |
30 | // amount. This helps distinguish otherwise identical | |
31 | // server/client muxes | |
32 | offset uint32 | |
33 | } | |
34 | ||
35 | // Assigns a channel ID to the given channel. | |
36 | func (c *chanList) add(ch *channel) uint32 { | |
37 | c.Lock() | |
38 | defer c.Unlock() | |
39 | for i := range c.chans { | |
40 | if c.chans[i] == nil { | |
41 | c.chans[i] = ch | |
42 | return uint32(i) + c.offset | |
43 | } | |
44 | } | |
45 | c.chans = append(c.chans, ch) | |
46 | return uint32(len(c.chans)-1) + c.offset | |
47 | } | |
48 | ||
49 | // getChan returns the channel for the given ID. | |
50 | func (c *chanList) getChan(id uint32) *channel { | |
51 | id -= c.offset | |
52 | ||
53 | c.Lock() | |
54 | defer c.Unlock() | |
55 | if id < uint32(len(c.chans)) { | |
56 | return c.chans[id] | |
57 | } | |
58 | return nil | |
59 | } | |
60 | ||
61 | func (c *chanList) remove(id uint32) { | |
62 | id -= c.offset | |
63 | c.Lock() | |
64 | if id < uint32(len(c.chans)) { | |
65 | c.chans[id] = nil | |
66 | } | |
67 | c.Unlock() | |
68 | } | |
69 | ||
70 | // dropAll forgets all channels it knows, returning them in a slice. | |
71 | func (c *chanList) dropAll() []*channel { | |
72 | c.Lock() | |
73 | defer c.Unlock() | |
74 | var r []*channel | |
75 | ||
76 | for _, ch := range c.chans { | |
77 | if ch == nil { | |
78 | continue | |
79 | } | |
80 | r = append(r, ch) | |
81 | } | |
82 | c.chans = nil | |
83 | return r | |
84 | } | |
85 | ||
86 | // mux represents the state for the SSH connection protocol, which | |
87 | // multiplexes many channels onto a single packet transport. | |
88 | type mux struct { | |
89 | conn packetConn | |
90 | chanList chanList | |
91 | ||
92 | incomingChannels chan NewChannel | |
93 | ||
94 | globalSentMu sync.Mutex | |
95 | globalResponses chan interface{} | |
96 | incomingRequests chan *Request | |
97 | ||
98 | errCond *sync.Cond | |
99 | err error | |
100 | } | |
101 | ||
// globalOff is the counter newMux bumps atomically when debugMux is on, so
// every mux created while debugging gets a distinct chanList offset.
var (
	globalOff uint32
)
105 | ||
106 | func (m *mux) Wait() error { | |
107 | m.errCond.L.Lock() | |
108 | defer m.errCond.L.Unlock() | |
109 | for m.err == nil { | |
110 | m.errCond.Wait() | |
111 | } | |
112 | return m.err | |
113 | } | |
114 | ||
115 | // newMux returns a mux that runs over the given connection. | |
116 | func newMux(p packetConn) *mux { | |
117 | m := &mux{ | |
118 | conn: p, | |
119 | incomingChannels: make(chan NewChannel, chanSize), | |
120 | globalResponses: make(chan interface{}, 1), | |
121 | incomingRequests: make(chan *Request, chanSize), | |
122 | errCond: newCond(), | |
123 | } | |
124 | if debugMux { | |
125 | m.chanList.offset = atomic.AddUint32(&globalOff, 1) | |
126 | } | |
127 | ||
128 | go m.loop() | |
129 | return m | |
130 | } | |
131 | ||
132 | func (m *mux) sendMessage(msg interface{}) error { | |
133 | p := Marshal(msg) | |
134 | if debugMux { | |
135 | log.Printf("send global(%d): %#v", m.chanList.offset, msg) | |
136 | } | |
137 | return m.conn.writePacket(p) | |
138 | } | |
139 | ||
140 | func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { | |
141 | if wantReply { | |
142 | m.globalSentMu.Lock() | |
143 | defer m.globalSentMu.Unlock() | |
144 | } | |
145 | ||
146 | if err := m.sendMessage(globalRequestMsg{ | |
147 | Type: name, | |
148 | WantReply: wantReply, | |
149 | Data: payload, | |
150 | }); err != nil { | |
151 | return false, nil, err | |
152 | } | |
153 | ||
154 | if !wantReply { | |
155 | return false, nil, nil | |
156 | } | |
157 | ||
158 | msg, ok := <-m.globalResponses | |
159 | if !ok { | |
160 | return false, nil, io.EOF | |
161 | } | |
162 | switch msg := msg.(type) { | |
163 | case *globalRequestFailureMsg: | |
164 | return false, msg.Data, nil | |
165 | case *globalRequestSuccessMsg: | |
166 | return true, msg.Data, nil | |
167 | default: | |
168 | return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg) | |
169 | } | |
170 | } | |
171 | ||
172 | // ackRequest must be called after processing a global request that | |
173 | // has WantReply set. | |
174 | func (m *mux) ackRequest(ok bool, data []byte) error { | |
175 | if ok { | |
176 | return m.sendMessage(globalRequestSuccessMsg{Data: data}) | |
177 | } | |
178 | return m.sendMessage(globalRequestFailureMsg{Data: data}) | |
179 | } | |
180 | ||
181 | func (m *mux) Close() error { | |
182 | return m.conn.Close() | |
183 | } | |
184 | ||
185 | // loop runs the connection machine. It will process packets until an | |
186 | // error is encountered. To synchronize on loop exit, use mux.Wait. | |
187 | func (m *mux) loop() { | |
188 | var err error | |
189 | for err == nil { | |
190 | err = m.onePacket() | |
191 | } | |
192 | ||
193 | for _, ch := range m.chanList.dropAll() { | |
194 | ch.close() | |
195 | } | |
196 | ||
197 | close(m.incomingChannels) | |
198 | close(m.incomingRequests) | |
199 | close(m.globalResponses) | |
200 | ||
201 | m.conn.Close() | |
202 | ||
203 | m.errCond.L.Lock() | |
204 | m.err = err | |
205 | m.errCond.Broadcast() | |
206 | m.errCond.L.Unlock() | |
207 | ||
208 | if debugMux { | |
209 | log.Println("loop exit", err) | |
210 | } | |
211 | } | |
212 | ||
213 | // onePacket reads and processes one packet. | |
214 | func (m *mux) onePacket() error { | |
215 | packet, err := m.conn.readPacket() | |
216 | if err != nil { | |
217 | return err | |
218 | } | |
219 | ||
220 | if debugMux { | |
221 | if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData { | |
222 | log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet)) | |
223 | } else { | |
224 | p, _ := decode(packet) | |
225 | log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet)) | |
226 | } | |
227 | } | |
228 | ||
229 | switch packet[0] { | |
230 | case msgChannelOpen: | |
231 | return m.handleChannelOpen(packet) | |
232 | case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: | |
233 | return m.handleGlobalPacket(packet) | |
234 | } | |
235 | ||
236 | // assume a channel packet. | |
237 | if len(packet) < 5 { | |
238 | return parseError(packet[0]) | |
239 | } | |
240 | id := binary.BigEndian.Uint32(packet[1:]) | |
241 | ch := m.chanList.getChan(id) | |
242 | if ch == nil { | |
243 | return fmt.Errorf("ssh: invalid channel %d", id) | |
244 | } | |
245 | ||
246 | return ch.handlePacket(packet) | |
247 | } | |
248 | ||
249 | func (m *mux) handleGlobalPacket(packet []byte) error { | |
250 | msg, err := decode(packet) | |
251 | if err != nil { | |
252 | return err | |
253 | } | |
254 | ||
255 | switch msg := msg.(type) { | |
256 | case *globalRequestMsg: | |
257 | m.incomingRequests <- &Request{ | |
258 | Type: msg.Type, | |
259 | WantReply: msg.WantReply, | |
260 | Payload: msg.Data, | |
261 | mux: m, | |
262 | } | |
263 | case *globalRequestSuccessMsg, *globalRequestFailureMsg: | |
264 | m.globalResponses <- msg | |
265 | default: | |
266 | panic(fmt.Sprintf("not a global message %#v", msg)) | |
267 | } | |
268 | ||
269 | return nil | |
270 | } | |
271 | ||
272 | // handleChannelOpen schedules a channel to be Accept()ed. | |
273 | func (m *mux) handleChannelOpen(packet []byte) error { | |
274 | var msg channelOpenMsg | |
275 | if err := Unmarshal(packet, &msg); err != nil { | |
276 | return err | |
277 | } | |
278 | ||
279 | if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { | |
280 | failMsg := channelOpenFailureMsg{ | |
281 | PeersId: msg.PeersId, | |
282 | Reason: ConnectionFailed, | |
283 | Message: "invalid request", | |
284 | Language: "en_US.UTF-8", | |
285 | } | |
286 | return m.sendMessage(failMsg) | |
287 | } | |
288 | ||
289 | c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) | |
290 | c.remoteId = msg.PeersId | |
291 | c.maxRemotePayload = msg.MaxPacketSize | |
292 | c.remoteWin.add(msg.PeersWindow) | |
293 | m.incomingChannels <- c | |
294 | return nil | |
295 | } | |
296 | ||
297 | func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) { | |
298 | ch, err := m.openChannel(chanType, extra) | |
299 | if err != nil { | |
300 | return nil, nil, err | |
301 | } | |
302 | ||
303 | return ch, ch.incomingRequests, nil | |
304 | } | |
305 | ||
306 | func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) { | |
307 | ch := m.newChannel(chanType, channelOutbound, extra) | |
308 | ||
309 | ch.maxIncomingPayload = channelMaxPacket | |
310 | ||
311 | open := channelOpenMsg{ | |
312 | ChanType: chanType, | |
313 | PeersWindow: ch.myWindow, | |
314 | MaxPacketSize: ch.maxIncomingPayload, | |
315 | TypeSpecificData: extra, | |
316 | PeersId: ch.localId, | |
317 | } | |
318 | if err := m.sendMessage(open); err != nil { | |
319 | return nil, err | |
320 | } | |
321 | ||
322 | switch msg := (<-ch.msg).(type) { | |
323 | case *channelOpenConfirmMsg: | |
324 | return ch, nil | |
325 | case *channelOpenFailureMsg: | |
326 | return nil, &OpenChannelError{msg.Reason, msg.Message} | |
327 | default: | |
328 | return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg) | |
329 | } | |
330 | } |