aboutsummaryrefslogtreecommitdiffhomepage
path: root/vendor/google.golang.org/grpc/proxy.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/google.golang.org/grpc/proxy.go')
-rw-r--r--vendor/google.golang.org/grpc/proxy.go52
1 files changed, 37 insertions, 15 deletions
diff --git a/vendor/google.golang.org/grpc/proxy.go b/vendor/google.golang.org/grpc/proxy.go
index 2d40236..f8f69bf 100644
--- a/vendor/google.golang.org/grpc/proxy.go
+++ b/vendor/google.golang.org/grpc/proxy.go
@@ -20,6 +20,8 @@ package grpc
20 20
21import ( 21import (
22 "bufio" 22 "bufio"
23 "context"
24 "encoding/base64"
23 "errors" 25 "errors"
24 "fmt" 26 "fmt"
25 "io" 27 "io"
@@ -27,10 +29,10 @@ import (
27 "net/http" 29 "net/http"
28 "net/http/httputil" 30 "net/http/httputil"
29 "net/url" 31 "net/url"
30
31 "golang.org/x/net/context"
32) 32)
33 33
34const proxyAuthHeaderKey = "Proxy-Authorization"
35
34var ( 36var (
35 // errDisabled indicates that proxy is disabled for the address. 37 // errDisabled indicates that proxy is disabled for the address.
36 errDisabled = errors.New("proxy is disabled for the address") 38 errDisabled = errors.New("proxy is disabled for the address")
@@ -38,7 +40,7 @@ var (
38 httpProxyFromEnvironment = http.ProxyFromEnvironment 40 httpProxyFromEnvironment = http.ProxyFromEnvironment
39) 41)
40 42
41func mapAddress(ctx context.Context, address string) (string, error) { 43func mapAddress(ctx context.Context, address string) (*url.URL, error) {
42 req := &http.Request{ 44 req := &http.Request{
43 URL: &url.URL{ 45 URL: &url.URL{
44 Scheme: "https", 46 Scheme: "https",
@@ -47,12 +49,12 @@ func mapAddress(ctx context.Context, address string) (string, error) {
47 } 49 }
48 url, err := httpProxyFromEnvironment(req) 50 url, err := httpProxyFromEnvironment(req)
49 if err != nil { 51 if err != nil {
50 return "", err 52 return nil, err
51 } 53 }
52 if url == nil { 54 if url == nil {
53 return "", errDisabled 55 return nil, errDisabled
54 } 56 }
55 return url.Host, nil 57 return url, nil
56} 58}
57 59
58// To read a response from a net.Conn, http.ReadResponse() takes a bufio.Reader. 60// To read a response from a net.Conn, http.ReadResponse() takes a bufio.Reader.
@@ -69,18 +71,28 @@ func (c *bufConn) Read(b []byte) (int, error) {
69 return c.r.Read(b) 71 return c.r.Read(b)
70} 72}
71 73
72func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, addr string) (_ net.Conn, err error) { 74func basicAuth(username, password string) string {
75 auth := username + ":" + password
76 return base64.StdEncoding.EncodeToString([]byte(auth))
77}
78
79func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr string, proxyURL *url.URL) (_ net.Conn, err error) {
73 defer func() { 80 defer func() {
74 if err != nil { 81 if err != nil {
75 conn.Close() 82 conn.Close()
76 } 83 }
77 }() 84 }()
78 85
79 req := (&http.Request{ 86 req := &http.Request{
80 Method: http.MethodConnect, 87 Method: http.MethodConnect,
81 URL: &url.URL{Host: addr}, 88 URL: &url.URL{Host: backendAddr},
82 Header: map[string][]string{"User-Agent": {grpcUA}}, 89 Header: map[string][]string{"User-Agent": {grpcUA}},
83 }) 90 }
91 if t := proxyURL.User; t != nil {
92 u := t.Username()
93 p, _ := t.Password()
94 req.Header.Add(proxyAuthHeaderKey, "Basic "+basicAuth(u, p))
95 }
84 96
85 if err := sendHTTPRequest(ctx, req, conn); err != nil { 97 if err := sendHTTPRequest(ctx, req, conn); err != nil {
86 return nil, fmt.Errorf("failed to write the HTTP request: %v", err) 98 return nil, fmt.Errorf("failed to write the HTTP request: %v", err)
@@ -108,23 +120,33 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, addr string) (_
108// provided dialer, does HTTP CONNECT handshake and returns the connection. 120// provided dialer, does HTTP CONNECT handshake and returns the connection.
109func newProxyDialer(dialer func(context.Context, string) (net.Conn, error)) func(context.Context, string) (net.Conn, error) { 121func newProxyDialer(dialer func(context.Context, string) (net.Conn, error)) func(context.Context, string) (net.Conn, error) {
110 return func(ctx context.Context, addr string) (conn net.Conn, err error) { 122 return func(ctx context.Context, addr string) (conn net.Conn, err error) {
111 var skipHandshake bool 123 var newAddr string
112 newAddr, err := mapAddress(ctx, addr) 124 proxyURL, err := mapAddress(ctx, addr)
113 if err != nil { 125 if err != nil {
114 if err != errDisabled { 126 if err != errDisabled {
115 return nil, err 127 return nil, err
116 } 128 }
117 skipHandshake = true
118 newAddr = addr 129 newAddr = addr
130 } else {
131 newAddr = proxyURL.Host
119 } 132 }
120 133
121 conn, err = dialer(ctx, newAddr) 134 conn, err = dialer(ctx, newAddr)
122 if err != nil { 135 if err != nil {
123 return 136 return
124 } 137 }
125 if !skipHandshake { 138 if proxyURL != nil {
126 conn, err = doHTTPConnectHandshake(ctx, conn, addr) 139 // proxy is disabled if proxyURL is nil.
140 conn, err = doHTTPConnectHandshake(ctx, conn, addr, proxyURL)
127 } 141 }
128 return 142 return
129 } 143 }
130} 144}
145
146func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) error {
147 req = req.WithContext(ctx)
148 if err := req.Write(conn); err != nil {
149 return fmt.Errorf("failed to write the HTTP request: %v", err)
150 }
151 return nil
152}