]>
Commit | Line | Data |
---|---|---|
bae9f6d2 JC |
1 | package plugin |
2 | ||
3 | import ( | |
4 | "encoding/binary" | |
5 | "fmt" | |
6 | "log" | |
7 | "net" | |
8 | "sync" | |
9 | "sync/atomic" | |
10 | "time" | |
11 | ||
12 | "github.com/hashicorp/yamux" | |
13 | ) | |
14 | ||
15 | // MuxBroker is responsible for brokering multiplexed connections by unique ID. | |
16 | // | |
17 | // It is used by plugins to multiplex multiple RPC connections and data | |
18 | // streams on top of a single connection between the plugin process and the | |
19 | // host process. | |
20 | // | |
21 | // This allows a plugin to request a channel with a specific ID to connect to | |
22 | // or accept a connection from, and the broker handles the details of | |
23 | // holding these channels open while they're being negotiated. | |
24 | // | |
25 | // The Plugin interface has access to these for both Server and Client. | |
26 | // The broker can be used by either (optionally) to reserve and connect to | |
27 | // new multiplexed streams. This is useful for complex args and return values, | |
28 | // or anything else you might need a data stream for. | |
29 | type MuxBroker struct { | |
30 | nextId uint32 | |
31 | session *yamux.Session | |
32 | streams map[uint32]*muxBrokerPending | |
33 | ||
34 | sync.Mutex | |
35 | } | |
36 | ||
// muxBrokerPending is the rendezvous slot for a single stream ID.
//
// Run pushes the incoming stream into ch. Accept receives it from ch and
// then closes doneCh, which tells the timeoutWait goroutine that the stream
// was claimed.
type muxBrokerPending struct {
	ch     chan net.Conn // the incoming stream waiting to be accepted
	doneCh chan struct{} // closed by Accept after it has received from ch
}
41 | ||
42 | func newMuxBroker(s *yamux.Session) *MuxBroker { | |
43 | return &MuxBroker{ | |
44 | session: s, | |
45 | streams: make(map[uint32]*muxBrokerPending), | |
46 | } | |
47 | } | |
48 | ||
49 | // Accept accepts a connection by ID. | |
50 | // | |
51 | // This should not be called multiple times with the same ID at one time. | |
52 | func (m *MuxBroker) Accept(id uint32) (net.Conn, error) { | |
53 | var c net.Conn | |
54 | p := m.getStream(id) | |
55 | select { | |
56 | case c = <-p.ch: | |
57 | close(p.doneCh) | |
58 | case <-time.After(5 * time.Second): | |
59 | m.Lock() | |
60 | defer m.Unlock() | |
61 | delete(m.streams, id) | |
62 | ||
63 | return nil, fmt.Errorf("timeout waiting for accept") | |
64 | } | |
65 | ||
66 | // Ack our connection | |
67 | if err := binary.Write(c, binary.LittleEndian, id); err != nil { | |
68 | c.Close() | |
69 | return nil, err | |
70 | } | |
71 | ||
72 | return c, nil | |
73 | } | |
74 | ||
75 | // AcceptAndServe is used to accept a specific stream ID and immediately | |
76 | // serve an RPC server on that stream ID. This is used to easily serve | |
77 | // complex arguments. | |
78 | // | |
79 | // The served interface is always registered to the "Plugin" name. | |
80 | func (m *MuxBroker) AcceptAndServe(id uint32, v interface{}) { | |
81 | conn, err := m.Accept(id) | |
82 | if err != nil { | |
83 | log.Printf("[ERR] plugin: plugin acceptAndServe error: %s", err) | |
84 | return | |
85 | } | |
86 | ||
87 | serve(conn, "Plugin", v) | |
88 | } | |
89 | ||
90 | // Close closes the connection and all sub-connections. | |
91 | func (m *MuxBroker) Close() error { | |
92 | return m.session.Close() | |
93 | } | |
94 | ||
95 | // Dial opens a connection by ID. | |
96 | func (m *MuxBroker) Dial(id uint32) (net.Conn, error) { | |
97 | // Open the stream | |
98 | stream, err := m.session.OpenStream() | |
99 | if err != nil { | |
100 | return nil, err | |
101 | } | |
102 | ||
103 | // Write the stream ID onto the wire. | |
104 | if err := binary.Write(stream, binary.LittleEndian, id); err != nil { | |
105 | stream.Close() | |
106 | return nil, err | |
107 | } | |
108 | ||
109 | // Read the ack that we connected. Then we're off! | |
110 | var ack uint32 | |
111 | if err := binary.Read(stream, binary.LittleEndian, &ack); err != nil { | |
112 | stream.Close() | |
113 | return nil, err | |
114 | } | |
115 | if ack != id { | |
116 | stream.Close() | |
117 | return nil, fmt.Errorf("bad ack: %d (expected %d)", ack, id) | |
118 | } | |
119 | ||
120 | return stream, nil | |
121 | } | |
122 | ||
123 | // NextId returns a unique ID to use next. | |
124 | // | |
125 | // It is possible for very long-running plugin hosts to wrap this value, | |
126 | // though it would require a very large amount of RPC calls. In practice | |
127 | // we've never seen it happen. | |
128 | func (m *MuxBroker) NextId() uint32 { | |
129 | return atomic.AddUint32(&m.nextId, 1) | |
130 | } | |
131 | ||
132 | // Run starts the brokering and should be executed in a goroutine, since it | |
133 | // blocks forever, or until the session closes. | |
134 | // | |
135 | // Uses of MuxBroker never need to call this. It is called internally by | |
136 | // the plugin host/client. | |
137 | func (m *MuxBroker) Run() { | |
138 | for { | |
139 | stream, err := m.session.AcceptStream() | |
140 | if err != nil { | |
141 | // Once we receive an error, just exit | |
142 | break | |
143 | } | |
144 | ||
145 | // Read the stream ID from the stream | |
146 | var id uint32 | |
147 | if err := binary.Read(stream, binary.LittleEndian, &id); err != nil { | |
148 | stream.Close() | |
149 | continue | |
150 | } | |
151 | ||
152 | // Initialize the waiter | |
153 | p := m.getStream(id) | |
154 | select { | |
155 | case p.ch <- stream: | |
156 | default: | |
157 | } | |
158 | ||
159 | // Wait for a timeout | |
160 | go m.timeoutWait(id, p) | |
161 | } | |
162 | } | |
163 | ||
164 | func (m *MuxBroker) getStream(id uint32) *muxBrokerPending { | |
165 | m.Lock() | |
166 | defer m.Unlock() | |
167 | ||
168 | p, ok := m.streams[id] | |
169 | if ok { | |
170 | return p | |
171 | } | |
172 | ||
173 | m.streams[id] = &muxBrokerPending{ | |
174 | ch: make(chan net.Conn, 1), | |
175 | doneCh: make(chan struct{}), | |
176 | } | |
177 | return m.streams[id] | |
178 | } | |
179 | ||
180 | func (m *MuxBroker) timeoutWait(id uint32, p *muxBrokerPending) { | |
181 | // Wait for the stream to either be picked up and connected, or | |
182 | // for a timeout. | |
183 | timeout := false | |
184 | select { | |
185 | case <-p.doneCh: | |
186 | case <-time.After(5 * time.Second): | |
187 | timeout = true | |
188 | } | |
189 | ||
190 | m.Lock() | |
191 | defer m.Unlock() | |
192 | ||
193 | // Delete the stream so no one else can grab it | |
194 | delete(m.streams, id) | |
195 | ||
196 | // If we timed out, then check if we have a channel in the buffer, | |
197 | // and if so, close it. | |
198 | if timeout { | |
199 | select { | |
200 | case s := <-p.ch: | |
201 | s.Close() | |
202 | } | |
203 | } | |
204 | } |