]>
Commit | Line | Data |
---|---|---|
bae9f6d2 JC |
1 | package corehandlers |
2 | ||
3 | import ( | |
4 | "bytes" | |
5 | "fmt" | |
6 | "io" | |
7 | "io/ioutil" | |
8 | "net/http" | |
9 | "net/url" | |
10 | "regexp" | |
11 | "runtime" | |
12 | "strconv" | |
13 | "time" | |
14 | ||
15 | "github.com/aws/aws-sdk-go/aws" | |
16 | "github.com/aws/aws-sdk-go/aws/awserr" | |
17 | "github.com/aws/aws-sdk-go/aws/credentials" | |
18 | "github.com/aws/aws-sdk-go/aws/request" | |
19 | ) | |
20 | ||
// lener is satisfied by any value that can report its own length through a
// Len method. BuildContentLengthHandler uses it to size request bodies
// without having to seek them.
type lener interface {
	Len() int
}
25 | ||
26 | // BuildContentLengthHandler builds the content length of a request based on the body, | |
27 | // or will use the HTTPRequest.Header's "Content-Length" if defined. If unable | |
28 | // to determine request body length and no "Content-Length" was specified it will panic. | |
29 | // | |
30 | // The Content-Length will only be added to the request if the length of the body | |
31 | // is greater than 0. If the body is empty or the current `Content-Length` | |
32 | // header is <= 0, the header will also be stripped. | |
33 | var BuildContentLengthHandler = request.NamedHandler{Name: "core.BuildContentLengthHandler", Fn: func(r *request.Request) { | |
34 | var length int64 | |
35 | ||
36 | if slength := r.HTTPRequest.Header.Get("Content-Length"); slength != "" { | |
37 | length, _ = strconv.ParseInt(slength, 10, 64) | |
38 | } else { | |
39 | switch body := r.Body.(type) { | |
40 | case nil: | |
41 | length = 0 | |
42 | case lener: | |
43 | length = int64(body.Len()) | |
44 | case io.Seeker: | |
45 | r.BodyStart, _ = body.Seek(0, 1) | |
46 | end, _ := body.Seek(0, 2) | |
47 | body.Seek(r.BodyStart, 0) // make sure to seek back to original location | |
48 | length = end - r.BodyStart | |
49 | default: | |
50 | panic("Cannot get length of body, must provide `ContentLength`") | |
51 | } | |
52 | } | |
53 | ||
54 | if length > 0 { | |
55 | r.HTTPRequest.ContentLength = length | |
56 | r.HTTPRequest.Header.Set("Content-Length", fmt.Sprintf("%d", length)) | |
57 | } else { | |
58 | r.HTTPRequest.ContentLength = 0 | |
59 | r.HTTPRequest.Header.Del("Content-Length") | |
60 | } | |
61 | }} | |
62 | ||
63 | // SDKVersionUserAgentHandler is a request handler for adding the SDK Version to the user agent. | |
64 | var SDKVersionUserAgentHandler = request.NamedHandler{ | |
65 | Name: "core.SDKVersionUserAgentHandler", | |
66 | Fn: request.MakeAddToUserAgentHandler(aws.SDKName, aws.SDKVersion, | |
67 | runtime.Version(), runtime.GOOS, runtime.GOARCH), | |
68 | } | |
69 | ||
// reStatusCode captures three digits at the very start of a string.
// handleSendError applies it to an error message to recover an HTTP status
// code when no response object is available.
var reStatusCode = regexp.MustCompile(`^(\d{3})`)
71 | ||
72 | // ValidateReqSigHandler is a request handler to ensure that the request's | |
73 | // signature doesn't expire before it is sent. This can happen when a request | |
74 | // is built and signed significantly before it is sent. Or significant delays | |
75 | // occur when retrying requests that would cause the signature to expire. | |
76 | var ValidateReqSigHandler = request.NamedHandler{ | |
77 | Name: "core.ValidateReqSigHandler", | |
78 | Fn: func(r *request.Request) { | |
79 | // Unsigned requests are not signed | |
80 | if r.Config.Credentials == credentials.AnonymousCredentials { | |
81 | return | |
82 | } | |
83 | ||
84 | signedTime := r.Time | |
85 | if !r.LastSignedAt.IsZero() { | |
86 | signedTime = r.LastSignedAt | |
87 | } | |
88 | ||
89 | // 10 minutes to allow for some clock skew/delays in transmission. | |
90 | // Would be improved with aws/aws-sdk-go#423 | |
91 | if signedTime.Add(10 * time.Minute).After(time.Now()) { | |
92 | return | |
93 | } | |
94 | ||
95 | fmt.Println("request expired, resigning") | |
96 | r.Sign() | |
97 | }, | |
98 | } | |
99 | ||
100 | // SendHandler is a request handler to send service request using HTTP client. | |
101 | var SendHandler = request.NamedHandler{ | |
102 | Name: "core.SendHandler", | |
103 | Fn: func(r *request.Request) { | |
104 | sender := sendFollowRedirects | |
105 | if r.DisableFollowRedirects { | |
106 | sender = sendWithoutFollowRedirects | |
107 | } | |
108 | ||
109 | var err error | |
110 | r.HTTPResponse, err = sender(r) | |
111 | if err != nil { | |
112 | handleSendError(r, err) | |
113 | } | |
114 | }, | |
115 | } | |
116 | ||
117 | func sendFollowRedirects(r *request.Request) (*http.Response, error) { | |
118 | return r.Config.HTTPClient.Do(r.HTTPRequest) | |
119 | } | |
120 | ||
121 | func sendWithoutFollowRedirects(r *request.Request) (*http.Response, error) { | |
122 | transport := r.Config.HTTPClient.Transport | |
123 | if transport == nil { | |
124 | transport = http.DefaultTransport | |
125 | } | |
126 | ||
127 | return transport.RoundTrip(r.HTTPRequest) | |
128 | } | |
129 | ||
130 | func handleSendError(r *request.Request, err error) { | |
131 | // Prevent leaking if an HTTPResponse was returned. Clean up | |
132 | // the body. | |
133 | if r.HTTPResponse != nil { | |
134 | r.HTTPResponse.Body.Close() | |
135 | } | |
136 | // Capture the case where url.Error is returned for error processing | |
137 | // response. e.g. 301 without location header comes back as string | |
138 | // error and r.HTTPResponse is nil. Other URL redirect errors will | |
139 | // comeback in a similar method. | |
140 | if e, ok := err.(*url.Error); ok && e.Err != nil { | |
141 | if s := reStatusCode.FindStringSubmatch(e.Err.Error()); s != nil { | |
142 | code, _ := strconv.ParseInt(s[1], 10, 64) | |
143 | r.HTTPResponse = &http.Response{ | |
144 | StatusCode: int(code), | |
145 | Status: http.StatusText(int(code)), | |
146 | Body: ioutil.NopCloser(bytes.NewReader([]byte{})), | |
147 | } | |
148 | return | |
149 | } | |
150 | } | |
151 | if r.HTTPResponse == nil { | |
152 | // Add a dummy request response object to ensure the HTTPResponse | |
153 | // value is consistent. | |
154 | r.HTTPResponse = &http.Response{ | |
155 | StatusCode: int(0), | |
156 | Status: http.StatusText(int(0)), | |
157 | Body: ioutil.NopCloser(bytes.NewReader([]byte{})), | |
158 | } | |
159 | } | |
160 | // Catch all other request errors. | |
161 | r.Error = awserr.New("RequestError", "send request failed", err) | |
162 | r.Retryable = aws.Bool(true) // network errors are retryable | |
163 | ||
164 | // Override the error with a context canceled error, if that was canceled. | |
165 | ctx := r.Context() | |
166 | select { | |
167 | case <-ctx.Done(): | |
168 | r.Error = awserr.New(request.CanceledErrorCode, | |
169 | "request context canceled", ctx.Err()) | |
170 | r.Retryable = aws.Bool(false) | |
171 | default: | |
172 | } | |
173 | } | |
174 | ||
175 | // ValidateResponseHandler is a request handler to validate service response. | |
176 | var ValidateResponseHandler = request.NamedHandler{Name: "core.ValidateResponseHandler", Fn: func(r *request.Request) { | |
177 | if r.HTTPResponse.StatusCode == 0 || r.HTTPResponse.StatusCode >= 300 { | |
178 | // this may be replaced by an UnmarshalError handler | |
179 | r.Error = awserr.New("UnknownError", "unknown error", nil) | |
180 | } | |
181 | }} | |
182 | ||
183 | // AfterRetryHandler performs final checks to determine if the request should | |
184 | // be retried and how long to delay. | |
185 | var AfterRetryHandler = request.NamedHandler{Name: "core.AfterRetryHandler", Fn: func(r *request.Request) { | |
186 | // If one of the other handlers already set the retry state | |
187 | // we don't want to override it based on the service's state | |
188 | if r.Retryable == nil || aws.BoolValue(r.Config.EnforceShouldRetryCheck) { | |
189 | r.Retryable = aws.Bool(r.ShouldRetry(r)) | |
190 | } | |
191 | ||
192 | if r.WillRetry() { | |
193 | r.RetryDelay = r.RetryRules(r) | |
194 | ||
195 | if sleepFn := r.Config.SleepDelay; sleepFn != nil { | |
196 | // Support SleepDelay for backwards compatibility and testing | |
197 | sleepFn(r.RetryDelay) | |
198 | } else if err := aws.SleepWithContext(r.Context(), r.RetryDelay); err != nil { | |
199 | r.Error = awserr.New(request.CanceledErrorCode, | |
200 | "request context canceled", err) | |
201 | r.Retryable = aws.Bool(false) | |
202 | return | |
203 | } | |
204 | ||
205 | // when the expired token exception occurs the credentials | |
206 | // need to be expired locally so that the next request to | |
207 | // get credentials will trigger a credentials refresh. | |
208 | if r.IsErrorExpired() { | |
209 | r.Config.Credentials.Expire() | |
210 | } | |
211 | ||
212 | r.RetryCount++ | |
213 | r.Error = nil | |
214 | } | |
215 | }} | |
216 | ||
217 | // ValidateEndpointHandler is a request handler to validate a request had the | |
218 | // appropriate Region and Endpoint set. Will set r.Error if the endpoint or | |
219 | // region is not valid. | |
220 | var ValidateEndpointHandler = request.NamedHandler{Name: "core.ValidateEndpointHandler", Fn: func(r *request.Request) { | |
221 | if r.ClientInfo.SigningRegion == "" && aws.StringValue(r.Config.Region) == "" { | |
222 | r.Error = aws.ErrMissingRegion | |
223 | } else if r.ClientInfo.Endpoint == "" { | |
224 | r.Error = aws.ErrMissingEndpoint | |
225 | } | |
226 | }} |