]>
Commit | Line | Data |
---|---|---|
9b12e4fe JC |
1 | // Copyright 2015 go-dockerclient authors. All rights reserved. |
2 | // Use of this source code is governed by a BSD-style | |
3 | // license that can be found in the LICENSE file. | |
4 | ||
5 | // Package docker provides a client for the Docker remote API. | |
6 | // | |
7 | // See https://goo.gl/G3plxW for more details on the remote API. | |
8 | package docker | |
9 | ||
10 | import ( | |
11 | "bufio" | |
12 | "bytes" | |
13 | "crypto/tls" | |
14 | "crypto/x509" | |
15 | "encoding/json" | |
16 | "errors" | |
17 | "fmt" | |
18 | "io" | |
19 | "io/ioutil" | |
20 | "net" | |
21 | "net/http" | |
22 | "net/http/httputil" | |
23 | "net/url" | |
24 | "os" | |
25 | "path/filepath" | |
26 | "reflect" | |
27 | "runtime" | |
28 | "strconv" | |
29 | "strings" | |
30 | "sync/atomic" | |
31 | "time" | |
32 | ||
33 | "github.com/fsouza/go-dockerclient/external/github.com/docker/docker/opts" | |
34 | "github.com/fsouza/go-dockerclient/external/github.com/docker/docker/pkg/homedir" | |
35 | "github.com/fsouza/go-dockerclient/external/github.com/docker/docker/pkg/stdcopy" | |
36 | "github.com/fsouza/go-dockerclient/external/github.com/hashicorp/go-cleanhttp" | |
37 | ) | |
38 | ||
// userAgent is sent as the User-Agent header on every request the client
// builds.
const (
	userAgent = "go-dockerclient"
)
40 | ||
41 | var ( | |
42 | // ErrInvalidEndpoint is returned when the endpoint is not a valid HTTP URL. | |
43 | ErrInvalidEndpoint = errors.New("invalid endpoint") | |
44 | ||
45 | // ErrConnectionRefused is returned when the client cannot connect to the given endpoint. | |
46 | ErrConnectionRefused = errors.New("cannot connect to Docker endpoint") | |
47 | ||
48 | // ErrInactivityTimeout is returned when a streamable call has been inactive for some time. | |
49 | ErrInactivityTimeout = errors.New("inactivity time exceeded timeout") | |
50 | ||
51 | apiVersion112, _ = NewAPIVersion("1.12") | |
52 | ||
53 | apiVersion119, _ = NewAPIVersion("1.19") | |
54 | ) | |
55 | ||
// APIVersion is an internal representation of a version of the Remote API,
// held as its dot-separated integer components ("1.19" -> [1 19]).
type APIVersion []int

// NewAPIVersion parses input into an APIVersion.
//
// input must contain at least one "." and be made of dot-separated integer
// components, e.g. "1.19" or "1.12.0". Everything from the first "-" onwards
// is ignored, so "1.12-rc1" parses as 1.12. An error is returned when input
// has no "." or when any component is not an integer.
func NewAPIVersion(input string) (APIVersion, error) {
	if !strings.Contains(input, ".") {
		return nil, fmt.Errorf("Unable to parse version %q", input)
	}
	base := strings.Split(input, "-")[0]
	parts := strings.Split(base, ".")
	version := make(APIVersion, 0, len(parts))
	for _, part := range parts {
		n, err := strconv.Atoi(part)
		if err != nil {
			return nil, fmt.Errorf("Unable to parse version %q: %q is not an integer", input, part)
		}
		version = append(version, n)
	}
	return version, nil
}

// String renders the version as its components joined with ".".
func (version APIVersion) String() string {
	parts := make([]string, len(version))
	for i, n := range version {
		parts[i] = strconv.Itoa(n)
	}
	return strings.Join(parts, ".")
}

// LessThan reports whether version sorts strictly before other.
func (version APIVersion) LessThan(other APIVersion) bool {
	return version.compare(other) < 0
}

// LessThanOrEqualTo reports whether version sorts before or equal to other.
func (version APIVersion) LessThanOrEqualTo(other APIVersion) bool {
	return !version.GreaterThan(other)
}

// GreaterThan reports whether version sorts strictly after other.
func (version APIVersion) GreaterThan(other APIVersion) bool {
	return version.compare(other) > 0
}

// GreaterThanOrEqualTo reports whether version sorts after or equal to other.
func (version APIVersion) GreaterThanOrEqualTo(other APIVersion) bool {
	return !version.LessThan(other)
}

// compare returns -1, 0 or 1 when version is lower than, equal to, or greater
// than other. Components are compared pairwise over the common prefix; when
// that prefix is equal, the version with more components is the greater one
// (so 1.12.0 > 1.12).
func (version APIVersion) compare(other APIVersion) int {
	common := len(version)
	if len(other) < common {
		common = len(other)
	}
	for i := 0; i < common; i++ {
		switch {
		case version[i] < other[i]:
			return -1
		case version[i] > other[i]:
			return 1
		}
	}
	switch {
	case len(version) > len(other):
		return 1
	case len(version) < len(other):
		return -1
	default:
		return 0
	}
}
131 | ||
132 | // Client is the basic type of this package. It provides methods for | |
133 | // interaction with the API. | |
134 | type Client struct { | |
135 | SkipServerVersionCheck bool | |
136 | HTTPClient *http.Client | |
137 | TLSConfig *tls.Config | |
138 | Dialer *net.Dialer | |
139 | ||
140 | endpoint string | |
141 | endpointURL *url.URL | |
142 | eventMonitor *eventMonitoringState | |
143 | requestedAPIVersion APIVersion | |
144 | serverAPIVersion APIVersion | |
145 | expectedAPIVersion APIVersion | |
146 | unixHTTPClient *http.Client | |
147 | } | |
148 | ||
149 | // NewClient returns a Client instance ready for communication with the given | |
150 | // server endpoint. It will use the latest remote API version available in the | |
151 | // server. | |
152 | func NewClient(endpoint string) (*Client, error) { | |
153 | client, err := NewVersionedClient(endpoint, "") | |
154 | if err != nil { | |
155 | return nil, err | |
156 | } | |
157 | client.SkipServerVersionCheck = true | |
158 | return client, nil | |
159 | } | |
160 | ||
161 | // NewTLSClient returns a Client instance ready for TLS communications with the givens | |
162 | // server endpoint, key and certificates . It will use the latest remote API version | |
163 | // available in the server. | |
164 | func NewTLSClient(endpoint string, cert, key, ca string) (*Client, error) { | |
165 | client, err := NewVersionedTLSClient(endpoint, cert, key, ca, "") | |
166 | if err != nil { | |
167 | return nil, err | |
168 | } | |
169 | client.SkipServerVersionCheck = true | |
170 | return client, nil | |
171 | } | |
172 | ||
173 | // NewTLSClientFromBytes returns a Client instance ready for TLS communications with the givens | |
174 | // server endpoint, key and certificates (passed inline to the function as opposed to being | |
175 | // read from a local file). It will use the latest remote API version available in the server. | |
176 | func NewTLSClientFromBytes(endpoint string, certPEMBlock, keyPEMBlock, caPEMCert []byte) (*Client, error) { | |
177 | client, err := NewVersionedTLSClientFromBytes(endpoint, certPEMBlock, keyPEMBlock, caPEMCert, "") | |
178 | if err != nil { | |
179 | return nil, err | |
180 | } | |
181 | client.SkipServerVersionCheck = true | |
182 | return client, nil | |
183 | } | |
184 | ||
185 | // NewVersionedClient returns a Client instance ready for communication with | |
186 | // the given server endpoint, using a specific remote API version. | |
187 | func NewVersionedClient(endpoint string, apiVersionString string) (*Client, error) { | |
188 | u, err := parseEndpoint(endpoint, false) | |
189 | if err != nil { | |
190 | return nil, err | |
191 | } | |
192 | var requestedAPIVersion APIVersion | |
193 | if strings.Contains(apiVersionString, ".") { | |
194 | requestedAPIVersion, err = NewAPIVersion(apiVersionString) | |
195 | if err != nil { | |
196 | return nil, err | |
197 | } | |
198 | } | |
199 | return &Client{ | |
200 | HTTPClient: cleanhttp.DefaultClient(), | |
201 | Dialer: &net.Dialer{}, | |
202 | endpoint: endpoint, | |
203 | endpointURL: u, | |
204 | eventMonitor: new(eventMonitoringState), | |
205 | requestedAPIVersion: requestedAPIVersion, | |
206 | }, nil | |
207 | } | |
208 | ||
209 | // NewVersionnedTLSClient has been DEPRECATED, please use NewVersionedTLSClient. | |
210 | func NewVersionnedTLSClient(endpoint string, cert, key, ca, apiVersionString string) (*Client, error) { | |
211 | return NewVersionedTLSClient(endpoint, cert, key, ca, apiVersionString) | |
212 | } | |
213 | ||
214 | // NewVersionedTLSClient returns a Client instance ready for TLS communications with the givens | |
215 | // server endpoint, key and certificates, using a specific remote API version. | |
216 | func NewVersionedTLSClient(endpoint string, cert, key, ca, apiVersionString string) (*Client, error) { | |
217 | certPEMBlock, err := ioutil.ReadFile(cert) | |
218 | if err != nil { | |
219 | return nil, err | |
220 | } | |
221 | keyPEMBlock, err := ioutil.ReadFile(key) | |
222 | if err != nil { | |
223 | return nil, err | |
224 | } | |
225 | caPEMCert, err := ioutil.ReadFile(ca) | |
226 | if err != nil { | |
227 | return nil, err | |
228 | } | |
229 | return NewVersionedTLSClientFromBytes(endpoint, certPEMBlock, keyPEMBlock, caPEMCert, apiVersionString) | |
230 | } | |
231 | ||
232 | // NewClientFromEnv returns a Client instance ready for communication created from | |
233 | // Docker's default logic for the environment variables DOCKER_HOST, DOCKER_TLS_VERIFY, and DOCKER_CERT_PATH. | |
234 | // | |
235 | // See https://github.com/docker/docker/blob/1f963af697e8df3a78217f6fdbf67b8123a7db94/docker/docker.go#L68. | |
236 | // See https://github.com/docker/compose/blob/81707ef1ad94403789166d2fe042c8a718a4c748/compose/cli/docker_client.py#L7. | |
237 | func NewClientFromEnv() (*Client, error) { | |
238 | client, err := NewVersionedClientFromEnv("") | |
239 | if err != nil { | |
240 | return nil, err | |
241 | } | |
242 | client.SkipServerVersionCheck = true | |
243 | return client, nil | |
244 | } | |
245 | ||
246 | // NewVersionedClientFromEnv returns a Client instance ready for TLS communications created from | |
247 | // Docker's default logic for the environment variables DOCKER_HOST, DOCKER_TLS_VERIFY, and DOCKER_CERT_PATH, | |
248 | // and using a specific remote API version. | |
249 | // | |
250 | // See https://github.com/docker/docker/blob/1f963af697e8df3a78217f6fdbf67b8123a7db94/docker/docker.go#L68. | |
251 | // See https://github.com/docker/compose/blob/81707ef1ad94403789166d2fe042c8a718a4c748/compose/cli/docker_client.py#L7. | |
252 | func NewVersionedClientFromEnv(apiVersionString string) (*Client, error) { | |
253 | dockerEnv, err := getDockerEnv() | |
254 | if err != nil { | |
255 | return nil, err | |
256 | } | |
257 | dockerHost := dockerEnv.dockerHost | |
258 | if dockerEnv.dockerTLSVerify { | |
259 | parts := strings.SplitN(dockerEnv.dockerHost, "://", 2) | |
260 | if len(parts) != 2 { | |
261 | return nil, fmt.Errorf("could not split %s into two parts by ://", dockerHost) | |
262 | } | |
263 | cert := filepath.Join(dockerEnv.dockerCertPath, "cert.pem") | |
264 | key := filepath.Join(dockerEnv.dockerCertPath, "key.pem") | |
265 | ca := filepath.Join(dockerEnv.dockerCertPath, "ca.pem") | |
266 | return NewVersionedTLSClient(dockerEnv.dockerHost, cert, key, ca, apiVersionString) | |
267 | } | |
268 | return NewVersionedClient(dockerEnv.dockerHost, apiVersionString) | |
269 | } | |
270 | ||
271 | // NewVersionedTLSClientFromBytes returns a Client instance ready for TLS communications with the givens | |
272 | // server endpoint, key and certificates (passed inline to the function as opposed to being | |
273 | // read from a local file), using a specific remote API version. | |
274 | func NewVersionedTLSClientFromBytes(endpoint string, certPEMBlock, keyPEMBlock, caPEMCert []byte, apiVersionString string) (*Client, error) { | |
275 | u, err := parseEndpoint(endpoint, true) | |
276 | if err != nil { | |
277 | return nil, err | |
278 | } | |
279 | var requestedAPIVersion APIVersion | |
280 | if strings.Contains(apiVersionString, ".") { | |
281 | requestedAPIVersion, err = NewAPIVersion(apiVersionString) | |
282 | if err != nil { | |
283 | return nil, err | |
284 | } | |
285 | } | |
286 | if certPEMBlock == nil || keyPEMBlock == nil { | |
287 | return nil, errors.New("Both cert and key are required") | |
288 | } | |
289 | tlsCert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) | |
290 | if err != nil { | |
291 | return nil, err | |
292 | } | |
293 | tlsConfig := &tls.Config{Certificates: []tls.Certificate{tlsCert}} | |
294 | if caPEMCert == nil { | |
295 | tlsConfig.InsecureSkipVerify = true | |
296 | } else { | |
297 | caPool := x509.NewCertPool() | |
298 | if !caPool.AppendCertsFromPEM(caPEMCert) { | |
299 | return nil, errors.New("Could not add RootCA pem") | |
300 | } | |
301 | tlsConfig.RootCAs = caPool | |
302 | } | |
303 | tr := cleanhttp.DefaultTransport() | |
304 | tr.TLSClientConfig = tlsConfig | |
305 | if err != nil { | |
306 | return nil, err | |
307 | } | |
308 | return &Client{ | |
309 | HTTPClient: &http.Client{Transport: tr}, | |
310 | TLSConfig: tlsConfig, | |
311 | Dialer: &net.Dialer{}, | |
312 | endpoint: endpoint, | |
313 | endpointURL: u, | |
314 | eventMonitor: new(eventMonitoringState), | |
315 | requestedAPIVersion: requestedAPIVersion, | |
316 | }, nil | |
317 | } | |
318 | ||
319 | func (c *Client) checkAPIVersion() error { | |
320 | serverAPIVersionString, err := c.getServerAPIVersionString() | |
321 | if err != nil { | |
322 | return err | |
323 | } | |
324 | c.serverAPIVersion, err = NewAPIVersion(serverAPIVersionString) | |
325 | if err != nil { | |
326 | return err | |
327 | } | |
328 | if c.requestedAPIVersion == nil { | |
329 | c.expectedAPIVersion = c.serverAPIVersion | |
330 | } else { | |
331 | c.expectedAPIVersion = c.requestedAPIVersion | |
332 | } | |
333 | return nil | |
334 | } | |
335 | ||
336 | // Endpoint returns the current endpoint. It's useful for getting the endpoint | |
337 | // when using functions that get this data from the environment (like | |
338 | // NewClientFromEnv. | |
339 | func (c *Client) Endpoint() string { | |
340 | return c.endpoint | |
341 | } | |
342 | ||
343 | // Ping pings the docker server | |
344 | // | |
345 | // See https://goo.gl/kQCfJj for more details. | |
346 | func (c *Client) Ping() error { | |
347 | path := "/_ping" | |
348 | resp, err := c.do("GET", path, doOptions{}) | |
349 | if err != nil { | |
350 | return err | |
351 | } | |
352 | if resp.StatusCode != http.StatusOK { | |
353 | return newError(resp) | |
354 | } | |
355 | resp.Body.Close() | |
356 | return nil | |
357 | } | |
358 | ||
359 | func (c *Client) getServerAPIVersionString() (version string, err error) { | |
360 | resp, err := c.do("GET", "/version", doOptions{}) | |
361 | if err != nil { | |
362 | return "", err | |
363 | } | |
364 | defer resp.Body.Close() | |
365 | if resp.StatusCode != http.StatusOK { | |
366 | return "", fmt.Errorf("Received unexpected status %d while trying to retrieve the server version", resp.StatusCode) | |
367 | } | |
368 | var versionResponse map[string]interface{} | |
369 | if err := json.NewDecoder(resp.Body).Decode(&versionResponse); err != nil { | |
370 | return "", err | |
371 | } | |
372 | if version, ok := (versionResponse["ApiVersion"]).(string); ok { | |
373 | return version, nil | |
374 | } | |
375 | return "", nil | |
376 | } | |
377 | ||
// doOptions configures a single Client.do request.
type doOptions struct {
	// data is JSON-encoded as the request body when non-nil.
	data interface{}
	// forceJSON sends a JSON body even when data is nil (encoded as "null").
	forceJSON bool
	// headers are set on the request after the default headers.
	headers map[string]string
}
383 | ||
384 | func (c *Client) do(method, path string, doOptions doOptions) (*http.Response, error) { | |
385 | var params io.Reader | |
386 | if doOptions.data != nil || doOptions.forceJSON { | |
387 | buf, err := json.Marshal(doOptions.data) | |
388 | if err != nil { | |
389 | return nil, err | |
390 | } | |
391 | params = bytes.NewBuffer(buf) | |
392 | } | |
393 | if path != "/version" && !c.SkipServerVersionCheck && c.expectedAPIVersion == nil { | |
394 | err := c.checkAPIVersion() | |
395 | if err != nil { | |
396 | return nil, err | |
397 | } | |
398 | } | |
399 | httpClient := c.HTTPClient | |
400 | protocol := c.endpointURL.Scheme | |
401 | var u string | |
402 | if protocol == "unix" { | |
403 | httpClient = c.unixClient() | |
404 | u = c.getFakeUnixURL(path) | |
405 | } else { | |
406 | u = c.getURL(path) | |
407 | } | |
408 | req, err := http.NewRequest(method, u, params) | |
409 | if err != nil { | |
410 | return nil, err | |
411 | } | |
412 | req.Header.Set("User-Agent", userAgent) | |
413 | if doOptions.data != nil { | |
414 | req.Header.Set("Content-Type", "application/json") | |
415 | } else if method == "POST" { | |
416 | req.Header.Set("Content-Type", "plain/text") | |
417 | } | |
418 | ||
419 | for k, v := range doOptions.headers { | |
420 | req.Header.Set(k, v) | |
421 | } | |
422 | resp, err := httpClient.Do(req) | |
423 | if err != nil { | |
424 | if strings.Contains(err.Error(), "connection refused") { | |
425 | return nil, ErrConnectionRefused | |
426 | } | |
427 | return nil, err | |
428 | } | |
429 | if resp.StatusCode < 200 || resp.StatusCode >= 400 { | |
430 | return nil, newError(resp) | |
431 | } | |
432 | return resp, nil | |
433 | } | |
434 | ||
// streamOptions configures a Client.stream request.
type streamOptions struct {
	// setRawTerminal copies the body verbatim to stdout instead of
	// demultiplexing it into stdout/stderr.
	setRawTerminal bool
	// rawJSONStream copies a JSON body verbatim to stdout without decoding.
	rawJSONStream bool
	// useJSONDecoder forces JSON-message decoding regardless of Content-Type.
	useJSONDecoder bool
	// headers are set on the request after the default headers.
	headers map[string]string
	// in is the request body.
	in io.Reader
	// stdout and stderr receive the streamed output.
	stdout io.Writer
	stderr io.Writer
	// timeout is the initial connection timeout
	timeout time.Duration
	// inactivityTimeout aborts the stream when no data is received for this
	// long; it is reset every time new data arrives.
	inactivityTimeout time.Duration
}
449 | ||
450 | func (c *Client) stream(method, path string, streamOptions streamOptions) error { | |
451 | if (method == "POST" || method == "PUT") && streamOptions.in == nil { | |
452 | streamOptions.in = bytes.NewReader(nil) | |
453 | } | |
454 | if path != "/version" && !c.SkipServerVersionCheck && c.expectedAPIVersion == nil { | |
455 | err := c.checkAPIVersion() | |
456 | if err != nil { | |
457 | return err | |
458 | } | |
459 | } | |
460 | req, err := http.NewRequest(method, c.getURL(path), streamOptions.in) | |
461 | if err != nil { | |
462 | return err | |
463 | } | |
464 | req.Header.Set("User-Agent", userAgent) | |
465 | if method == "POST" { | |
466 | req.Header.Set("Content-Type", "plain/text") | |
467 | } | |
468 | for key, val := range streamOptions.headers { | |
469 | req.Header.Set(key, val) | |
470 | } | |
471 | var resp *http.Response | |
472 | protocol := c.endpointURL.Scheme | |
473 | address := c.endpointURL.Path | |
474 | if streamOptions.stdout == nil { | |
475 | streamOptions.stdout = ioutil.Discard | |
476 | } | |
477 | if streamOptions.stderr == nil { | |
478 | streamOptions.stderr = ioutil.Discard | |
479 | } | |
480 | cancelRequest := cancelable(c.HTTPClient, req) | |
481 | if protocol == "unix" { | |
482 | dial, err := c.Dialer.Dial(protocol, address) | |
483 | if err != nil { | |
484 | return err | |
485 | } | |
486 | cancelRequest = func() { dial.Close() } | |
487 | defer dial.Close() | |
488 | breader := bufio.NewReader(dial) | |
489 | err = req.Write(dial) | |
490 | if err != nil { | |
491 | return err | |
492 | } | |
493 | ||
494 | // ReadResponse may hang if server does not replay | |
495 | if streamOptions.timeout > 0 { | |
496 | dial.SetDeadline(time.Now().Add(streamOptions.timeout)) | |
497 | } | |
498 | ||
499 | if resp, err = http.ReadResponse(breader, req); err != nil { | |
500 | // Cancel timeout for future I/O operations | |
501 | if streamOptions.timeout > 0 { | |
502 | dial.SetDeadline(time.Time{}) | |
503 | } | |
504 | if strings.Contains(err.Error(), "connection refused") { | |
505 | return ErrConnectionRefused | |
506 | } | |
507 | return err | |
508 | } | |
509 | } else { | |
510 | if resp, err = c.HTTPClient.Do(req); err != nil { | |
511 | if strings.Contains(err.Error(), "connection refused") { | |
512 | return ErrConnectionRefused | |
513 | } | |
514 | return err | |
515 | } | |
516 | } | |
517 | defer resp.Body.Close() | |
518 | if resp.StatusCode < 200 || resp.StatusCode >= 400 { | |
519 | return newError(resp) | |
520 | } | |
521 | var canceled uint32 | |
522 | if streamOptions.inactivityTimeout > 0 { | |
523 | ch := handleInactivityTimeout(&streamOptions, cancelRequest, &canceled) | |
524 | defer close(ch) | |
525 | } | |
526 | err = handleStreamResponse(resp, &streamOptions) | |
527 | if err != nil { | |
528 | if atomic.LoadUint32(&canceled) != 0 { | |
529 | return ErrInactivityTimeout | |
530 | } | |
531 | return err | |
532 | } | |
533 | return nil | |
534 | } | |
535 | ||
536 | func handleStreamResponse(resp *http.Response, streamOptions *streamOptions) error { | |
537 | var err error | |
538 | if !streamOptions.useJSONDecoder && resp.Header.Get("Content-Type") != "application/json" { | |
539 | if streamOptions.setRawTerminal { | |
540 | _, err = io.Copy(streamOptions.stdout, resp.Body) | |
541 | } else { | |
542 | _, err = stdcopy.StdCopy(streamOptions.stdout, streamOptions.stderr, resp.Body) | |
543 | } | |
544 | return err | |
545 | } | |
546 | // if we want to get raw json stream, just copy it back to output | |
547 | // without decoding it | |
548 | if streamOptions.rawJSONStream { | |
549 | _, err = io.Copy(streamOptions.stdout, resp.Body) | |
550 | return err | |
551 | } | |
552 | dec := json.NewDecoder(resp.Body) | |
553 | for { | |
554 | var m jsonMessage | |
555 | if err := dec.Decode(&m); err == io.EOF { | |
556 | break | |
557 | } else if err != nil { | |
558 | return err | |
559 | } | |
560 | if m.Stream != "" { | |
561 | fmt.Fprint(streamOptions.stdout, m.Stream) | |
562 | } else if m.Progress != "" { | |
563 | fmt.Fprintf(streamOptions.stdout, "%s %s\r", m.Status, m.Progress) | |
564 | } else if m.Error != "" { | |
565 | return errors.New(m.Error) | |
566 | } | |
567 | if m.Status != "" { | |
568 | fmt.Fprintln(streamOptions.stdout, m.Status) | |
569 | } | |
570 | } | |
571 | return nil | |
572 | } | |
573 | ||
// proxyWriter wraps an io.Writer and counts Write calls, so another goroutine
// can tell whether output is still flowing. The counter is only accessed
// atomically.
type proxyWriter struct {
	io.Writer
	calls uint64
}

// callCount reports how many times Write has been called so far.
func (w *proxyWriter) callCount() uint64 {
	return atomic.LoadUint64(&w.calls)
}

// Write records the call, then forwards data to the wrapped writer and
// returns its result unchanged.
func (w *proxyWriter) Write(data []byte) (int, error) {
	atomic.AddUint64(&w.calls, 1)
	n, err := w.Writer.Write(data)
	return n, err
}
587 | ||
588 | func handleInactivityTimeout(options *streamOptions, cancelRequest func(), canceled *uint32) chan<- struct{} { | |
589 | done := make(chan struct{}) | |
590 | proxyStdout := &proxyWriter{Writer: options.stdout} | |
591 | proxyStderr := &proxyWriter{Writer: options.stderr} | |
592 | options.stdout = proxyStdout | |
593 | options.stderr = proxyStderr | |
594 | go func() { | |
595 | var lastCallCount uint64 | |
596 | for { | |
597 | select { | |
598 | case <-time.After(options.inactivityTimeout): | |
599 | case <-done: | |
600 | return | |
601 | } | |
602 | curCallCount := proxyStdout.callCount() + proxyStderr.callCount() | |
603 | if curCallCount == lastCallCount { | |
604 | atomic.AddUint32(canceled, 1) | |
605 | cancelRequest() | |
606 | return | |
607 | } | |
608 | lastCallCount = curCallCount | |
609 | } | |
610 | }() | |
611 | return done | |
612 | } | |
613 | ||
// hijackOptions configures a Client.hijack request.
type hijackOptions struct {
	// success, when non-nil, is used for a handshake after the request is
	// sent: a value is sent on it, then one is received before streaming.
	success chan struct{}
	// setRawTerminal copies output verbatim to stdout instead of
	// demultiplexing it into stdout/stderr.
	setRawTerminal bool
	// in is copied to the hijacked connection; closed afterwards if it is an
	// io.Closer.
	in io.Reader
	// stdout and stderr receive output from the hijacked connection.
	stdout io.Writer
	stderr io.Writer
	// data is JSON-encoded as the request body when non-nil.
	data interface{}
}
622 | ||
// CloseWaiter is an interface with methods for closing the underlying
// resource and then waiting for it to finish processing.
type CloseWaiter interface {
	io.Closer
	Wait() error
}

// waiterFunc adapts a plain function to the Wait half of CloseWaiter.
type waiterFunc func() error

// Wait calls w.
func (w waiterFunc) Wait() error {
	return w()
}

// closerFunc adapts a plain function to io.Closer.
type closerFunc func() error

// Close calls c.
func (c closerFunc) Close() error {
	return c()
}
637 | ||
638 | func (c *Client) hijack(method, path string, hijackOptions hijackOptions) (CloseWaiter, error) { | |
639 | if path != "/version" && !c.SkipServerVersionCheck && c.expectedAPIVersion == nil { | |
640 | err := c.checkAPIVersion() | |
641 | if err != nil { | |
642 | return nil, err | |
643 | } | |
644 | } | |
645 | var params io.Reader | |
646 | if hijackOptions.data != nil { | |
647 | buf, err := json.Marshal(hijackOptions.data) | |
648 | if err != nil { | |
649 | return nil, err | |
650 | } | |
651 | params = bytes.NewBuffer(buf) | |
652 | } | |
653 | req, err := http.NewRequest(method, c.getURL(path), params) | |
654 | if err != nil { | |
655 | return nil, err | |
656 | } | |
657 | req.Header.Set("Content-Type", "application/json") | |
658 | req.Header.Set("Connection", "Upgrade") | |
659 | req.Header.Set("Upgrade", "tcp") | |
660 | protocol := c.endpointURL.Scheme | |
661 | address := c.endpointURL.Path | |
662 | if protocol != "unix" { | |
663 | protocol = "tcp" | |
664 | address = c.endpointURL.Host | |
665 | } | |
666 | var dial net.Conn | |
667 | if c.TLSConfig != nil && protocol != "unix" { | |
668 | dial, err = tlsDialWithDialer(c.Dialer, protocol, address, c.TLSConfig) | |
669 | if err != nil { | |
670 | return nil, err | |
671 | } | |
672 | } else { | |
673 | dial, err = c.Dialer.Dial(protocol, address) | |
674 | if err != nil { | |
675 | return nil, err | |
676 | } | |
677 | } | |
678 | ||
679 | errs := make(chan error) | |
680 | quit := make(chan struct{}) | |
681 | go func() { | |
682 | clientconn := httputil.NewClientConn(dial, nil) | |
683 | defer clientconn.Close() | |
684 | clientconn.Do(req) | |
685 | if hijackOptions.success != nil { | |
686 | hijackOptions.success <- struct{}{} | |
687 | <-hijackOptions.success | |
688 | } | |
689 | rwc, br := clientconn.Hijack() | |
690 | defer rwc.Close() | |
691 | ||
692 | errChanOut := make(chan error, 1) | |
693 | errChanIn := make(chan error, 1) | |
694 | if hijackOptions.stdout == nil && hijackOptions.stderr == nil { | |
695 | close(errChanOut) | |
696 | } else { | |
697 | // Only copy if hijackOptions.stdout and/or hijackOptions.stderr is actually set. | |
698 | // Otherwise, if the only stream you care about is stdin, your attach session | |
699 | // will "hang" until the container terminates, even though you're not reading | |
700 | // stdout/stderr | |
701 | if hijackOptions.stdout == nil { | |
702 | hijackOptions.stdout = ioutil.Discard | |
703 | } | |
704 | if hijackOptions.stderr == nil { | |
705 | hijackOptions.stderr = ioutil.Discard | |
706 | } | |
707 | ||
708 | go func() { | |
709 | defer func() { | |
710 | if hijackOptions.in != nil { | |
711 | if closer, ok := hijackOptions.in.(io.Closer); ok { | |
712 | closer.Close() | |
713 | } | |
714 | errChanIn <- nil | |
715 | } | |
716 | }() | |
717 | ||
718 | var err error | |
719 | if hijackOptions.setRawTerminal { | |
720 | _, err = io.Copy(hijackOptions.stdout, br) | |
721 | } else { | |
722 | _, err = stdcopy.StdCopy(hijackOptions.stdout, hijackOptions.stderr, br) | |
723 | } | |
724 | errChanOut <- err | |
725 | }() | |
726 | } | |
727 | ||
728 | go func() { | |
729 | var err error | |
730 | if hijackOptions.in != nil { | |
731 | _, err = io.Copy(rwc, hijackOptions.in) | |
732 | } | |
733 | errChanIn <- err | |
734 | rwc.(interface { | |
735 | CloseWrite() error | |
736 | }).CloseWrite() | |
737 | }() | |
738 | ||
739 | var errIn error | |
740 | select { | |
741 | case errIn = <-errChanIn: | |
742 | case <-quit: | |
743 | return | |
744 | } | |
745 | ||
746 | var errOut error | |
747 | select { | |
748 | case errOut = <-errChanOut: | |
749 | case <-quit: | |
750 | return | |
751 | } | |
752 | ||
753 | if errIn != nil { | |
754 | errs <- errIn | |
755 | } else { | |
756 | errs <- errOut | |
757 | } | |
758 | }() | |
759 | ||
760 | return struct { | |
761 | closerFunc | |
762 | waiterFunc | |
763 | }{ | |
764 | closerFunc(func() error { close(quit); return nil }), | |
765 | waiterFunc(func() error { return <-errs }), | |
766 | }, nil | |
767 | } | |
768 | ||
769 | func (c *Client) getURL(path string) string { | |
770 | urlStr := strings.TrimRight(c.endpointURL.String(), "/") | |
771 | if c.endpointURL.Scheme == "unix" { | |
772 | urlStr = "" | |
773 | } | |
774 | if c.requestedAPIVersion != nil { | |
775 | return fmt.Sprintf("%s/v%s%s", urlStr, c.requestedAPIVersion, path) | |
776 | } | |
777 | return fmt.Sprintf("%s%s", urlStr, path) | |
778 | } | |
779 | ||
780 | // getFakeUnixURL returns the URL needed to make an HTTP request over a UNIX | |
781 | // domain socket to the given path. | |
782 | func (c *Client) getFakeUnixURL(path string) string { | |
783 | u := *c.endpointURL // Copy. | |
784 | ||
785 | // Override URL so that net/http will not complain. | |
786 | u.Scheme = "http" | |
787 | u.Host = "unix.sock" // Doesn't matter what this is - it's not used. | |
788 | u.Path = "" | |
789 | urlStr := strings.TrimRight(u.String(), "/") | |
790 | if c.requestedAPIVersion != nil { | |
791 | return fmt.Sprintf("%s/v%s%s", urlStr, c.requestedAPIVersion, path) | |
792 | } | |
793 | return fmt.Sprintf("%s%s", urlStr, path) | |
794 | } | |
795 | ||
796 | func (c *Client) unixClient() *http.Client { | |
797 | if c.unixHTTPClient != nil { | |
798 | return c.unixHTTPClient | |
799 | } | |
800 | socketPath := c.endpointURL.Path | |
801 | tr := &http.Transport{ | |
802 | Dial: func(network, addr string) (net.Conn, error) { | |
803 | return c.Dialer.Dial("unix", socketPath) | |
804 | }, | |
805 | } | |
806 | cleanhttp.SetTransportFinalizer(tr) | |
807 | c.unixHTTPClient = &http.Client{Transport: tr} | |
808 | return c.unixHTTPClient | |
809 | } | |
810 | ||
// jsonMessage is one decoded message from a JSON progress stream.
type jsonMessage struct {
	// Status is a status line.
	Status string `json:"status,omitempty"`
	// Progress is a progress indicator, printed after Status.
	Progress string `json:"progress,omitempty"`
	// Error, when set, aborts stream handling with this message.
	Error string `json:"error,omitempty"`
	// Stream is raw output text.
	Stream string `json:"stream,omitempty"`
}
817 | ||
818 | func queryString(opts interface{}) string { | |
819 | if opts == nil { | |
820 | return "" | |
821 | } | |
822 | value := reflect.ValueOf(opts) | |
823 | if value.Kind() == reflect.Ptr { | |
824 | value = value.Elem() | |
825 | } | |
826 | if value.Kind() != reflect.Struct { | |
827 | return "" | |
828 | } | |
829 | items := url.Values(map[string][]string{}) | |
830 | for i := 0; i < value.NumField(); i++ { | |
831 | field := value.Type().Field(i) | |
832 | if field.PkgPath != "" { | |
833 | continue | |
834 | } | |
835 | key := field.Tag.Get("qs") | |
836 | if key == "" { | |
837 | key = strings.ToLower(field.Name) | |
838 | } else if key == "-" { | |
839 | continue | |
840 | } | |
841 | addQueryStringValue(items, key, value.Field(i)) | |
842 | } | |
843 | return items.Encode() | |
844 | } | |
845 | ||
// addQueryStringValue appends v to items under key, skipping zero-ish values:
//   - bool: "1" when true, omitted when false;
//   - signed ints and floats: decimal form when > 0, omitted otherwise;
//   - string: added when non-empty;
//   - non-nil pointer and non-empty map: JSON encoding (omitted if encoding fails);
//   - array/slice: each element added recursively under the same key.
//
// Other kinds are ignored.
func addQueryStringValue(items url.Values, key string, v reflect.Value) {
	switch v.Kind() {
	case reflect.Bool:
		if v.Bool() {
			items.Add(key, "1")
		}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		if n := v.Int(); n > 0 {
			items.Add(key, strconv.FormatInt(n, 10))
		}
	case reflect.Float32, reflect.Float64:
		if f := v.Float(); f > 0 {
			// Format at the value's own precision so a float32 such as 0.1
			// becomes "0.1" rather than "0.10000000149011612".
			items.Add(key, strconv.FormatFloat(f, 'f', -1, v.Type().Bits()))
		}
	case reflect.String:
		if s := v.String(); s != "" {
			items.Add(key, s)
		}
	case reflect.Ptr:
		if !v.IsNil() {
			if b, err := json.Marshal(v.Interface()); err == nil {
				items.Add(key, string(b))
			}
		}
	case reflect.Map:
		if len(v.MapKeys()) > 0 {
			if b, err := json.Marshal(v.Interface()); err == nil {
				items.Add(key, string(b))
			}
		}
	case reflect.Array, reflect.Slice:
		for i := 0; i < v.Len(); i++ {
			addQueryStringValue(items, key, v.Index(i))
		}
	}
}
885 | ||
// Error represents a failure reported by the API: the HTTP status code of the
// response together with its body.
type Error struct {
	Status  int
	Message string
}

// newError builds an *Error from resp, consuming and closing its body. If the
// body cannot be read, Message describes the read failure instead.
func newError(resp *http.Response) *Error {
	defer resp.Body.Close()
	apiErr := &Error{Status: resp.StatusCode}
	data, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		apiErr.Message = fmt.Sprintf("cannot read body, err: %v", err)
	} else {
		apiErr.Message = string(data)
	}
	return apiErr
}

// Error implements the error interface as "API error (<status>): <message>".
func (e *Error) Error() string {
	return fmt.Sprintf("API error (%d): %s", e.Status, e.Message)
}
904 | ||
905 | func parseEndpoint(endpoint string, tls bool) (*url.URL, error) { | |
906 | if endpoint != "" && !strings.Contains(endpoint, "://") { | |
907 | endpoint = "tcp://" + endpoint | |
908 | } | |
909 | u, err := url.Parse(endpoint) | |
910 | if err != nil { | |
911 | return nil, ErrInvalidEndpoint | |
912 | } | |
913 | if tls { | |
914 | u.Scheme = "https" | |
915 | } | |
916 | switch u.Scheme { | |
917 | case "unix": | |
918 | return u, nil | |
919 | case "http", "https", "tcp": | |
920 | _, port, err := net.SplitHostPort(u.Host) | |
921 | if err != nil { | |
922 | if e, ok := err.(*net.AddrError); ok { | |
923 | if e.Err == "missing port in address" { | |
924 | return u, nil | |
925 | } | |
926 | } | |
927 | return nil, ErrInvalidEndpoint | |
928 | } | |
929 | number, err := strconv.ParseInt(port, 10, 64) | |
930 | if err == nil && number > 0 && number < 65536 { | |
931 | if u.Scheme == "tcp" { | |
932 | if tls { | |
933 | u.Scheme = "https" | |
934 | } else { | |
935 | u.Scheme = "http" | |
936 | } | |
937 | } | |
938 | return u, nil | |
939 | } | |
940 | return nil, ErrInvalidEndpoint | |
941 | default: | |
942 | return nil, ErrInvalidEndpoint | |
943 | } | |
944 | } | |
945 | ||
// dockerEnv holds the Docker connection settings resolved from the
// environment by getDockerEnv.
type dockerEnv struct {
	// dockerHost is DOCKER_HOST, or the platform default when unset.
	dockerHost string
	// dockerTLSVerify is true when DOCKER_TLS_VERIFY is non-empty.
	dockerTLSVerify bool
	// dockerCertPath is the certificate directory, set only when TLS
	// verification is enabled.
	dockerCertPath string
}
951 | ||
952 | func getDockerEnv() (*dockerEnv, error) { | |
953 | dockerHost := os.Getenv("DOCKER_HOST") | |
954 | var err error | |
955 | if dockerHost == "" { | |
956 | dockerHost, err = DefaultDockerHost() | |
957 | if err != nil { | |
958 | return nil, err | |
959 | } | |
960 | } | |
961 | dockerTLSVerify := os.Getenv("DOCKER_TLS_VERIFY") != "" | |
962 | var dockerCertPath string | |
963 | if dockerTLSVerify { | |
964 | dockerCertPath = os.Getenv("DOCKER_CERT_PATH") | |
965 | if dockerCertPath == "" { | |
966 | home := homedir.Get() | |
967 | if home == "" { | |
968 | return nil, errors.New("environment variable HOME must be set if DOCKER_CERT_PATH is not set") | |
969 | } | |
970 | dockerCertPath = filepath.Join(home, ".docker") | |
971 | dockerCertPath, err = filepath.Abs(dockerCertPath) | |
972 | if err != nil { | |
973 | return nil, err | |
974 | } | |
975 | } | |
976 | } | |
977 | return &dockerEnv{ | |
978 | dockerHost: dockerHost, | |
979 | dockerTLSVerify: dockerTLSVerify, | |
980 | dockerCertPath: dockerCertPath, | |
981 | }, nil | |
982 | } | |
983 | ||
984 | // DefaultDockerHost returns the default docker socket for the current OS | |
985 | func DefaultDockerHost() (string, error) { | |
986 | var defaultHost string | |
987 | if runtime.GOOS == "windows" { | |
988 | // If we do not have a host, default to TCP socket on Windows | |
989 | defaultHost = fmt.Sprintf("tcp://%s:%d", opts.DefaultHTTPHost, opts.DefaultHTTPPort) | |
990 | } else { | |
991 | // If we do not have a host, default to unix socket | |
992 | defaultHost = fmt.Sprintf("unix://%s", opts.DefaultUnixSocket) | |
993 | } | |
994 | return opts.ValidateHost(defaultHost) | |
995 | } |