From 98de336a2130aa102e52af4d8150cb3c05101d15 Mon Sep 17 00:00:00 2001 From: Tanmay Chaudhry Date: Tue, 17 Apr 2018 19:39:22 +0530 Subject: proxy: Enabled configurable timeout (#2070) * Enabled configurable Timeout for the proxy directive * Added Test for reverse for proxy timeout * Removed Duplication in proxy constructors * Remove indirection from multiple constructors and refactor into one * Fix inconsistent error message and refactor dialer initialization --- caddyhttp/proxy/proxy.go | 10 ++++- caddyhttp/proxy/proxy_test.go | 81 +++++++++++++++++++++++++----------- caddyhttp/proxy/reverseproxy.go | 34 ++++++++++----- caddyhttp/proxy/reverseproxy_test.go | 3 +- caddyhttp/proxy/upstream.go | 18 +++++++- 5 files changed, 107 insertions(+), 39 deletions(-) diff --git a/caddyhttp/proxy/proxy.go b/caddyhttp/proxy/proxy.go index e66ad776..c47a1597 100644 --- a/caddyhttp/proxy/proxy.go +++ b/caddyhttp/proxy/proxy.go @@ -58,6 +58,10 @@ type Upstream interface { // Gets the number of upstream hosts. GetHostCount() int + // Gets how long to wait before timing out + // the request + GetTimeout() time.Duration + // Stops the upstream from proxying requests to shutdown goroutines cleanly. Stop() error } @@ -187,7 +191,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { if nameURL, err := url.Parse(host.Name); err == nil { outreq.Host = nameURL.Host if proxy == nil { - proxy = NewSingleHostReverseProxy(nameURL, host.WithoutPathPrefix, http.DefaultMaxIdleConnsPerHost) + proxy = NewSingleHostReverseProxy(nameURL, + host.WithoutPathPrefix, + http.DefaultMaxIdleConnsPerHost, + upstream.GetTimeout(), + ) } // use upstream credentials by default diff --git a/caddyhttp/proxy/proxy_test.go b/caddyhttp/proxy/proxy_test.go index 213f8741..79c23887 100644 --- a/caddyhttp/proxy/proxy_test.go +++ b/caddyhttp/proxy/proxy_test.go @@ -122,7 +122,7 @@ func TestReverseProxy(t *testing.T) { // set up proxy p := &Proxy{ Next: httpserver.EmptyNext, // prevents panic in some cases when test fails - Upstreams: []Upstream{newFakeUpstream(backend.URL, false)}, + Upstreams: []Upstream{newFakeUpstream(backend.URL, false, 30*time.Second)}, } // Create the fake request body. @@ -202,7 +202,7 @@ func TestReverseProxyInsecureSkipVerify(t *testing.T) { // set up proxy p := &Proxy{ Next: httpserver.EmptyNext, // prevents panic in some cases when test fails - Upstreams: []Upstream{newFakeUpstream(backend.URL, true)}, + Upstreams: []Upstream{newFakeUpstream(backend.URL, true, 30*time.Second)}, } // create request and response recorder @@ -287,6 +287,31 @@ func TestReverseProxyMaxConnLimit(t *testing.T) { jobs.Wait() } +func TestReverseProxyTimeout(t *testing.T) { + timeout := 2 * time.Second + errorMargin := 100 * time.Millisecond + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + // set up proxy + p := &Proxy{ + Next: httpserver.EmptyNext, // prevents panic in some cases when test fails + Upstreams: []Upstream{newFakeUpstream("https://8.8.8.8", true, timeout)}, + } + + // create request and response recorder + r := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + start := time.Now() + p.ServeHTTP(w, r) + took := time.Since(start) + + if took > timeout+errorMargin { + t.Errorf("Expected timeout ~ %v but got %v", timeout, took) + } +} + func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) { // Capture the expected panic defer func() { @@ -301,7 +326,7 @@ func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) { defer wsNop.Close() // Get proxy to use for the test - p := newWebSocketTestProxy(wsNop.URL, false) + p := newWebSocketTestProxy(wsNop.URL, false, 30*time.Second) // Create client request r := httptest.NewRequest("GET", "/", nil) @@ -331,7 +356,7 @@ func TestWebSocketReverseProxyBackendShutDown(t *testing.T) { }() // Get proxy to use for the test - p := newWebSocketTestProxy(backend.URL, false) + p := newWebSocketTestProxy(backend.URL, false, 30*time.Second) backendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { p.ServeHTTP(w, r) })) @@ -360,7 +385,7 @@ func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) { defer wsNop.Close() // Get proxy to use for the test - p := newWebSocketTestProxy(wsNop.URL, false) + p := newWebSocketTestProxy(wsNop.URL, false, 30*time.Second) // Create client request r := httptest.NewRequest("GET", "/", nil) @@ -407,7 +432,7 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) { defer wsEcho.Close() // Get proxy to use for the test - p := newWebSocketTestProxy(wsEcho.URL, false) + p := newWebSocketTestProxy(wsEcho.URL, false, 30*time.Second) // This is a full end-end test, so the proxy handler // has to be part of a server listening on a port. Our @@ -452,7 +477,7 @@ func TestWebSocketReverseProxyFromWSSClient(t *testing.T) { })) defer wsEcho.Close() - p := newWebSocketTestProxy(wsEcho.URL, true) + p := newWebSocketTestProxy(wsEcho.URL, true, 30*time.Second) echoProxy := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { p.ServeHTTP(w, r) @@ -528,7 +553,7 @@ func TestUnixSocketProxy(t *testing.T) { defer ts.Close() url := strings.Replace(ts.URL, "http://", "unix:", 1) - p := newWebSocketTestProxy(url, false) + p := newWebSocketTestProxy(url, false, 30*time.Second) echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { p.ServeHTTP(w, r) @@ -686,7 +711,7 @@ func TestUpstreamHeadersUpdate(t *testing.T) { })) defer backend.Close() - upstream := newFakeUpstream(backend.URL, false) + upstream := newFakeUpstream(backend.URL, false, 30*time.Second) upstream.host.UpstreamHeaders = http.Header{ "Connection": {"{>Connection}"}, "Upgrade": {"{>Upgrade}"}, @@ -753,7 +778,7 @@ func TestDownstreamHeadersUpdate(t *testing.T) { })) defer backend.Close() - upstream := newFakeUpstream(backend.URL, false) + upstream := newFakeUpstream(backend.URL, false, 30*time.Second) upstream.host.DownstreamHeaders = http.Header{ "+Merge-Me": {"Merge-Value"}, "+Add-Me": {"Add-Value"}, @@ -893,7 +918,7 @@ func TestHostSimpleProxyNoHeaderForward(t *testing.T) { // set up proxy p := &Proxy{ Next: httpserver.EmptyNext, // prevents panic in some cases when test fails - Upstreams: []Upstream{newFakeUpstream(backend.URL, false)}, + Upstreams: []Upstream{newFakeUpstream(backend.URL, false, 30*time.Second)}, } r := httptest.NewRequest("GET", "/", nil) @@ -982,7 +1007,7 @@ func TestHostHeaderReplacedUsingForward(t *testing.T) { })) defer backend.Close() - upstream := newFakeUpstream(backend.URL, false) + upstream := newFakeUpstream(backend.URL, false, 30*time.Second) proxyHostHeader := "test2.com" upstream.host.UpstreamHeaders = http.Header{"Host": []string{proxyHostHeader}} // set up proxy @@ -1044,7 +1069,7 @@ func basicAuthTestcase(t *testing.T, upstreamUser, clientUser *url.Userinfo) { p := &Proxy{ Next: httpserver.EmptyNext, - Upstreams: []Upstream{newFakeUpstream(backURL.String(), false)}, + Upstreams: []Upstream{newFakeUpstream(backURL.String(), false, 30*time.Second)}, } r, err := http.NewRequest("GET", "/foo", nil) if err != nil { @@ -1179,7 +1204,7 @@ func TestProxyDirectorURL(t *testing.T) { continue } - NewSingleHostReverseProxy(targetURL, c.without, 0).Director(req) + NewSingleHostReverseProxy(targetURL, c.without, 0, 30*time.Second).Director(req) if expect, got := c.expectURL, req.URL.String(); expect != got { t.Errorf("case %d url not equal: expect %q, but got %q", i, expect, got) @@ -1326,7 +1351,7 @@ func TestCancelRequest(t *testing.T) { // set up proxy p := &Proxy{ Next: httpserver.EmptyNext, // prevents panic in some cases when test fails - Upstreams: []Upstream{newFakeUpstream(backend.URL, false)}, + Upstreams: []Upstream{newFakeUpstream(backend.URL, false, 30*time.Second)}, } // setup request with cancel ctx @@ -1375,14 +1400,15 @@ func (r *noopReader) Read(b []byte) (int, error) { return n, nil } -func newFakeUpstream(name string, insecure bool) *fakeUpstream { +func newFakeUpstream(name string, insecure bool, timeout time.Duration) *fakeUpstream { uri, _ := url.Parse(name) u := &fakeUpstream{ - name: name, - from: "/", + name: name, + from: "/", + timeout: timeout, host: &UpstreamHost{ Name: name, - ReverseProxy: NewSingleHostReverseProxy(uri, "", http.DefaultMaxIdleConnsPerHost), + ReverseProxy: NewSingleHostReverseProxy(uri, "", http.DefaultMaxIdleConnsPerHost, timeout), }, } if insecure { @@ -1396,6 +1422,7 @@ type fakeUpstream struct { host *UpstreamHost from string without string + timeout time.Duration } func (u *fakeUpstream) From() string { @@ -1410,7 +1437,7 @@ func (u *fakeUpstream) Select(r *http.Request) *UpstreamHost { } u.host = &UpstreamHost{ Name: u.name, - ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost), + ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost, u.GetTimeout()), } } return u.host @@ -1419,6 +1446,7 @@ func (u *fakeUpstream) Select(r *http.Request) *UpstreamHost { func (u *fakeUpstream) AllowedPath(requestPath string) bool { return true } func (u *fakeUpstream) GetTryDuration() time.Duration { return 1 * time.Second } func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond } +func (u *fakeUpstream) GetTimeout() time.Duration { return u.timeout } func (u *fakeUpstream) GetHostCount() int { return 1 } func (u *fakeUpstream) Stop() error { return nil } @@ -1426,13 +1454,14 @@ func (u *fakeUpstream) Stop() error { return nil } // redirect to the specified backendAddr. The function // also sets up the rules/environment for testing WebSocket // proxy. -func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy { +func newWebSocketTestProxy(backendAddr string, insecure bool, timeout time.Duration) *Proxy { return &Proxy{ Next: httpserver.EmptyNext, // prevents panic in some cases when test fails Upstreams: []Upstream{&fakeWsUpstream{ name: backendAddr, without: "", insecure: insecure, + timeout: timeout, }}, } } @@ -1440,7 +1469,7 @@ func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy { func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy { return &Proxy{ Next: httpserver.EmptyNext, // prevents panic in some cases when test fails - Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: prefix}}, + Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: prefix, timeout: 30 * time.Second}}, } } @@ -1448,6 +1477,7 @@ type fakeWsUpstream struct { name string without string insecure bool + timeout time.Duration } func (u *fakeWsUpstream) From() string { @@ -1458,7 +1488,7 @@ func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost { uri, _ := url.Parse(u.name) host := &UpstreamHost{ Name: u.name, - ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost), + ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost, u.GetTimeout()), UpstreamHeaders: http.Header{ "Connection": {"{>Connection}"}, "Upgrade": {"{>Upgrade}"}}, @@ -1472,6 +1502,7 @@ func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost { func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true } func (u *fakeWsUpstream) GetTryDuration() time.Duration { return 1 * time.Second } func (u *fakeWsUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond } +func (u *fakeWsUpstream) GetTimeout() time.Duration { return u.timeout } func (u *fakeWsUpstream) GetHostCount() int { return 1 } func (u *fakeWsUpstream) Stop() error { return nil } @@ -1517,7 +1548,7 @@ func BenchmarkProxy(b *testing.B) { })) defer backend.Close() - upstream := newFakeUpstream(backend.URL, false) + upstream := newFakeUpstream(backend.URL, false, 30*time.Second) upstream.host.UpstreamHeaders = http.Header{ "Hostname": {"{hostname}"}, "Host": {"{host}"}, @@ -1560,7 +1591,7 @@ func TestChunkedWebSocketReverseProxy(t *testing.T) { defer wsNop.Close() // Get proxy to use for the test - p := newWebSocketTestProxy(wsNop.URL, false) + p := newWebSocketTestProxy(wsNop.URL, false, 30*time.Second) // Create client request r := httptest.NewRequest("GET", "/", nil) diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index d48894ff..c528cf45 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -94,6 +94,10 @@ type ReverseProxy struct { // If zero, no periodic flushing is done. FlushInterval time.Duration + // dialer is used when values from the + // defaultDialer need to be overridden per Proxy + dialer *net.Dialer + srvResolver srvResolver } @@ -103,13 +107,13 @@ type ReverseProxy struct { // What we need is just the path, so if "unix:/var/run/www.socket" // was the proxy directive, the parsed hostName would be // "unix:///var/run/www.socket", hence the ambiguous trimming. -func socketDial(hostName string) func(network, addr string) (conn net.Conn, err error) { +func socketDial(hostName string, timeout time.Duration) func(network, addr string) (conn net.Conn, err error) { return func(network, addr string) (conn net.Conn, err error) { - return net.Dial("unix", hostName[len("unix://"):]) + return net.DialTimeout("unix", hostName[len("unix://"):], timeout) } } -func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string) (conn net.Conn, err error) { +func (rp *ReverseProxy) srvDialerFunc(locator string, timeout time.Duration) func(network, addr string) (conn net.Conn, err error) { service := locator if strings.HasPrefix(locator, "srv://") { service = locator[6:] @@ -122,7 +126,7 @@ func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string) if err != nil { return nil, err } - return net.Dial("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port)) + return net.DialTimeout("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port), timeout) } } @@ -144,7 +148,7 @@ func singleJoiningSlash(a, b string) string { // the target request will be for /base/dir. // Without logic: target's path is "/", incoming is "/api/messages", // without is "/api", then the target request will be for /messages. -func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *ReverseProxy { +func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int, timeout time.Duration) *ReverseProxy { targetQuery := target.RawQuery director := func(req *http.Request) { if target.Scheme == "unix" { @@ -226,15 +230,21 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * } } + dialer := *defaultDialer + if timeout != defaultDialer.Timeout { + dialer.Timeout = timeout + } + rp := &ReverseProxy{ Director: director, FlushInterval: 250 * time.Millisecond, // flushing good for streaming & server-sent events srvResolver: net.DefaultResolver, + dialer: &dialer, } if target.Scheme == "unix" { rp.Transport = &http.Transport{ - Dial: socketDial(target.String()), + Dial: socketDial(target.String(), timeout), } } else if target.Scheme == "quic" { rp.Transport = &h2quic.RoundTripper{ @@ -244,9 +254,9 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * }, } } else if keepalive != http.DefaultMaxIdleConnsPerHost || strings.HasPrefix(target.Scheme, "srv") { - dialFunc := defaultDialer.Dial + dialFunc := rp.dialer.Dial if strings.HasPrefix(target.Scheme, "srv") { - dialFunc = rp.srvDialerFunc(target.String()) + dialFunc = rp.srvDialerFunc(target.String(), timeout) } transport := &http.Transport{ @@ -275,7 +285,7 @@ func (rp *ReverseProxy) UseInsecureTransport() { if rp.Transport == nil { transport := &http.Transport{ Proxy: http.ProxyFromEnvironment, - Dial: defaultDialer.Dial, + Dial: rp.dialer.Dial, TLSHandshakeTimeout: defaultCryptoHandshakeTimeout, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } @@ -306,7 +316,9 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, if requestIsWebsocket(outreq) { transport = newConnHijackerTransport(transport) } else if transport == nil { - transport = http.DefaultTransport + transport = &http.Transport{ + Dial: rp.dialer.Dial, + } } rp.Director(outreq) @@ -361,7 +373,7 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, } bufferPool.Put(hj.Replay) } else { - backendConn, err = net.Dial("tcp", outreq.URL.Host) + backendConn, err = net.DialTimeout("tcp", outreq.URL.Host, rp.dialer.Timeout) if err != nil { return err } diff --git a/caddyhttp/proxy/reverseproxy_test.go b/caddyhttp/proxy/reverseproxy_test.go index 2d1d80df..8b01054e 100644 --- a/caddyhttp/proxy/reverseproxy_test.go +++ b/caddyhttp/proxy/reverseproxy_test.go @@ -21,6 +21,7 @@ import ( "net/url" "strconv" "testing" + "time" ) const ( @@ -66,7 +67,7 @@ func TestSingleSRVHostReverseProxy(t *testing.T) { } port := uint16(pp) - rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost) + rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost, 30*time.Second) rp.srvResolver = testResolver{ result: []*net.SRV{ {Target: upstream.Hostname(), Port: port, Priority: 1, Weight: 1}, diff --git a/caddyhttp/proxy/upstream.go b/caddyhttp/proxy/upstream.go index df93d390..8e5395c6 100644 --- a/caddyhttp/proxy/upstream.go +++ b/caddyhttp/proxy/upstream.go @@ -49,6 +49,7 @@ type staticUpstream struct { Hosts HostPool Policy Policy KeepAlive int + Timeout time.Duration FailTimeout time.Duration TryDuration time.Duration TryInterval time.Duration @@ -92,6 +93,7 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error) TryInterval: 250 * time.Millisecond, MaxConns: 0, KeepAlive: http.DefaultMaxIdleConnsPerHost, + Timeout: 30 * time.Second, resolver: net.DefaultResolver, } @@ -225,7 +227,7 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { return nil, err } - uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix, u.KeepAlive) + uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix, u.KeepAlive, u.Timeout) if u.insecureSkipVerify { uh.ReverseProxy.UseInsecureTransport() } @@ -464,6 +466,15 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error { return c.ArgErr() } u.KeepAlive = n + case "timeout": + if !c.NextArg() { + return c.ArgErr() + } + dur, err := time.ParseDuration(c.Val()) + if err != nil { + return c.Errf("unable to parse timeout duration '%s'", c.Val()) + } + u.Timeout = dur default: return c.Errf("unknown property '%s'", c.Val()) } @@ -619,6 +630,11 @@ func (u *staticUpstream) GetTryInterval() time.Duration { return u.TryInterval } +// GetTimeout returns u.Timeout. +func (u *staticUpstream) GetTimeout() time.Duration { + return u.Timeout +} + func (u *staticUpstream) GetHostCount() int { return len(u.Hosts) } -- cgit v1.2.3-70-g09d2