]>
Commit | Line | Data |
---|---|---|
15c0b25d AP |
1 | // Copyright 2015 The Go Authors. All rights reserved. |
2 | // Use of this source code is governed by a BSD-style | |
3 | // license that can be found in the LICENSE file. | |
4 | ||
5 | // +build go1.6 | |
6 | ||
7 | package http2 | |
8 | ||
9 | import ( | |
10 | "crypto/tls" | |
11 | "fmt" | |
12 | "net/http" | |
13 | ) | |
14 | ||
15 | func configureTransport(t1 *http.Transport) (*Transport, error) { | |
16 | connPool := new(clientConnPool) | |
17 | t2 := &Transport{ | |
18 | ConnPool: noDialClientConnPool{connPool}, | |
19 | t1: t1, | |
20 | } | |
21 | connPool.t = t2 | |
22 | if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil { | |
23 | return nil, err | |
24 | } | |
25 | if t1.TLSClientConfig == nil { | |
26 | t1.TLSClientConfig = new(tls.Config) | |
27 | } | |
28 | if !strSliceContains(t1.TLSClientConfig.NextProtos, "h2") { | |
29 | t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...) | |
30 | } | |
31 | if !strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") { | |
32 | t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1") | |
33 | } | |
34 | upgradeFn := func(authority string, c *tls.Conn) http.RoundTripper { | |
35 | addr := authorityAddr("https", authority) | |
36 | if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil { | |
37 | go c.Close() | |
38 | return erringRoundTripper{err} | |
39 | } else if !used { | |
40 | // Turns out we don't need this c. | |
41 | // For example, two goroutines made requests to the same host | |
42 | // at the same time, both kicking off TCP dials. (since protocol | |
43 | // was unknown) | |
44 | go c.Close() | |
45 | } | |
46 | return t2 | |
47 | } | |
48 | if m := t1.TLSNextProto; len(m) == 0 { | |
49 | t1.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{ | |
50 | "h2": upgradeFn, | |
51 | } | |
52 | } else { | |
53 | m["h2"] = upgradeFn | |
54 | } | |
55 | return t2, nil | |
56 | } | |
57 | ||
// registerHTTPSProtocol registers rt for the "https" scheme on t via
// Transport.RegisterProtocol. If that call panics (for example because
// "https" is already registered on t), the panic value is recovered and
// returned as an error formatted with %v instead of propagating.
func registerHTTPSProtocol(t *http.Transport, rt http.RoundTripper) (err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		// Assigning to the named result turns the panic into a normal return.
		err = fmt.Errorf("%v", r)
	}()
	t.RegisterProtocol("https", rt)
	return nil
}
69 | ||
70 | // noDialH2RoundTripper is a RoundTripper which only tries to complete the request | |
71 | // if there's already has a cached connection to the host. | |
72 | type noDialH2RoundTripper struct{ t *Transport } | |
73 | ||
74 | func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { | |
75 | res, err := rt.t.RoundTrip(req) | |
76 | if err == ErrNoCachedConn { | |
77 | return nil, http.ErrSkipAltProtocol | |
78 | } | |
79 | return res, err | |
80 | } |