From f2f5b5a1cc46ed16ae6dd3c0043303cdf6704b62 Mon Sep 17 00:00:00 2001 From: Thomas Gosteli Date: Mon, 15 Jun 2026 02:54:05 +0200 Subject: [PATCH] feat(forward): add doh support (#8004) * chore(pkg/proxy): prepare for DoH implementation Signed-off-by: Thomas Gosteli * chore(pkg/proxy): prepare for DoH implementation Signed-off-by: Thomas Gosteli * feat(proxy): implement basic DoH resolution Signed-off-by: Thomas Gosteli * feat(forward): implement DoH forwarding Signed-off-by: Thomas Gosteli * feat(proxy): add basic DoH health checker Signed-off-by: Thomas Gosteli * chore: align http transport with Go's DefaultTransport and resolve some of the TODOs Signed-off-by: Thomas Gosteli * docs(forward): add basic documentation for DoH Signed-off-by: Thomas Gosteli * chore: add basic tests to cover DoH Signed-off-by: Thomas Gosteli * chore(health): unify default timeout to 1s Signed-off-by: Thomas Gosteli * feat(forward): make doh method configurable Signed-off-by: Thomas Gosteli * chore: remove maxIdleConnsPerHost setting & update docs Signed-off-by: Thomas Gosteli * chore(forward): reject https upstreams with path Signed-off-by: Thomas Gosteli --------- Signed-off-by: Thomas Gosteli --- plugin/forward/README.md | 27 +++++-- plugin/forward/forward.go | 4 +- plugin/forward/proxy_test.go | 48 +++++++++++++ plugin/forward/setup.go | 36 +++++++++- plugin/forward/setup_test.go | 3 +- plugin/pkg/doh/doh.go | 11 ++- plugin/pkg/proxy/connect.go | 56 +++++++++++++-- plugin/pkg/proxy/health.go | 122 ++++++++++++++++++++++++++++++-- plugin/pkg/proxy/health_test.go | 49 +++++++++++++ plugin/pkg/proxy/persistent.go | 7 +- plugin/pkg/proxy/proxy.go | 21 +++++- plugin/pkg/proxy/type.go | 15 ++-- 12 files changed, 363 insertions(+), 36 deletions(-) diff --git a/plugin/forward/README.md b/plugin/forward/README.md index 74278414c..e900851d8 100644 --- a/plugin/forward/README.md +++ b/plugin/forward/README.md @@ -6,8 +6,8 @@ ## Description -The *forward* plugin re-uses already opened sockets to the upstreams. It supports UDP, TCP and -DNS-over-TLS and uses in band health checking. +The *forward* plugin re-uses already opened sockets to the upstreams. It supports UDP, TCP, +DNS-over-TLS, DNS-over-HTTPS and uses in band health checking. When it detects an error a health check is performed. This checks runs in a loop, performing each check at a *0.5s* interval for as long as the upstream reports unhealthy. Once healthy we stop @@ -30,8 +30,8 @@ forward FROM TO... * **FROM** is the base domain to match for the request to be forwarded. Domains using CIDR notation that expand to multiple reverse zones are not fully supported; only the first expanded zone is used. * **TO...** are the destination endpoints to forward to. The **TO** syntax allows you to specify - a protocol, `tls://9.9.9.9` or `dns://` (or no protocol) for plain DNS. The number of upstreams is - limited to 15. In addition to IP addresses and files (like `/etc/resolv.conf`), **TO** can also be + a protocol, `tls://9.9.9.9`, `https://9.9.9.9` (DoH defaults to `/dns-query` path) or `dns://` (or no protocol) + for plain DNS. The number of upstreams is limited to 15. In addition to IP addresses and files (like `/etc/resolv.conf`), **TO** can also be a hostname (e.g., `my-dns.svc.cluster.local`). Hostnames are resolved to IP addresses at startup. See the `resolver` option below. @@ -49,6 +49,7 @@ forward FROM TO... { max_idle_conns INTEGER max_fails INTEGER max_connect_attempts INTEGER + doh_method GET|POST tls CERT KEY CA tls_servername NAME policy random|round_robin|sequential @@ -75,6 +76,7 @@ forward FROM TO... { performed for a single incoming DNS request. Default value of 0 means no per-request cap. * `expire` **DURATION**, expire (cached) connections after this time, the default is 10s. +* `doh_method` **GET|POST**, whether to use GET or POST http method for DoH requests (defaults to POST). * `max_idle_conns` **INTEGER**, maximum number of idle connections to cache per upstream for reuse. Default is 0, which means unlimited. * `tls` **CERT** **KEY** **CA** define the TLS properties for TLS connection. From 0 to 3 arguments can be @@ -148,7 +150,7 @@ If monitoring is enabled (via the *prometheus* plugin) then the following metric * `coredns_proxy_conn_cache_misses_total{proxy_name="forward", to, proto}` - count of connection cache misses per upstream and protocol. Where `to` is one of the upstream servers (**TO** from the config), `rcode` is the returned RCODE -from the upstream, `proto` is the transport protocol like `udp`, `tcp`, `tcp-tls`. +from the upstream, `proto` is the transport protocol like `udp`, `tcp`, `tcp-tls`, `https`. The following metrics have recently been deprecated: * `coredns_forward_healthcheck_failures_total{to, rcode}` @@ -247,6 +249,19 @@ service with health checks. } ~~~ +The same configuration but using DNS-over-HTTPS (DoH) protocol. Note that the implementation uses the default `/dns-query` +path (custom paths are not supported). + +~~~ corefile +. { + forward . https://9.9.9.9 { + tls_servername dns.quad9.net + health_check 5s + } + cache 30 +} +~~~ + Or configure other domain name for health check requests ~~~ corefile @@ -330,3 +345,5 @@ Forward to an upstream identified by hostname, using a specific resolver to look ## See Also [RFC 7858](https://tools.ietf.org/html/rfc7858) for DNS over TLS. + +[RFC 8484](https://tools.ietf.org/html/rfc8484) for DNS over HTTPS. diff --git a/plugin/forward/forward.go b/plugin/forward/forward.go index 0b48da462..e97c4503c 100644 --- a/plugin/forward/forward.go +++ b/plugin/forward/forward.go @@ -8,6 +8,7 @@ import ( "context" "crypto/tls" "errors" + "net/http" "sync/atomic" "time" @@ -52,6 +53,7 @@ type Forward struct { expire time.Duration maxAge time.Duration maxIdleConns int + dohMethod string maxConcurrent int64 failfastUnhealthyUpstreams bool failoverRcodes []int @@ -74,7 +76,7 @@ type Forward struct { // New returns a new Forward. func New() *Forward { - f := &Forward{maxfails: 2, tlsConfig: new(tls.Config), expire: defaultExpire, p: new(random), from: ".", hcInterval: hcInterval, opts: proxyPkg.Options{ForceTCP: false, PreferUDP: false, HCRecursionDesired: true, HCDomain: "."}} + f := &Forward{maxfails: 2, tlsConfig: new(tls.Config), expire: defaultExpire, p: new(random), from: ".", hcInterval: hcInterval, dohMethod: http.MethodPost, opts: proxyPkg.Options{ForceTCP: false, PreferUDP: false, HCRecursionDesired: true, HCDomain: "."}} return f } diff --git a/plugin/forward/proxy_test.go b/plugin/forward/proxy_test.go index daf5f964c..86e3ff71e 100644 --- a/plugin/forward/proxy_test.go +++ b/plugin/forward/proxy_test.go @@ -2,10 +2,13 @@ package forward import ( "context" + "net/http" + "net/http/httptest" "testing" "github.com/coredns/caddy" "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/doh" "github.com/coredns/coredns/plugin/test" "github.com/miekg/dns" @@ -68,3 +71,48 @@ func TestProxyTLSFail(t *testing.T) { t.Fatal("Expected *not* to receive reply, but got one") } } + +func TestProxyHTTPS(t *testing.T) { + s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + msg, err := doh.RequestToMsg(r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + ret := new(dns.Msg) + reply := ret.SetReply(msg) + reply.Answer = append(reply.Answer, test.A("example.org. IN A 127.0.0.1")) + + buf, err := reply.Pack() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", doh.MimeType) + w.Write(buf) + })) + defer s.Close() + + c := caddy.NewTestController("dns", "forward . "+s.URL) + fs, err := parseForward(c) + if err != nil { + t.Errorf("Failed to create forwarder: %s", err) + } + f := fs[0] + f.proxies[0].SetHTTPClient(s.Client()) + f.OnStartup() + defer f.OnShutdown() + + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + if _, err := f.ServeDNS(context.TODO(), rec, m); err != nil { + t.Fatal("Expected to receive reply, but didn't") + } + if x := rec.Msg.Answer[0].Header().Name; x != "example.org." { + t.Errorf("Expected %s, got %s", "example.org.", x) + } +} diff --git a/plugin/forward/setup.go b/plugin/forward/setup.go index 245aba0c0..2010131c2 100644 --- a/plugin/forward/setup.go +++ b/plugin/forward/setup.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "net/http" "path/filepath" "strconv" "strings" @@ -158,6 +159,14 @@ func parseStanza(c *caddy.Controller) (*Forward, error) { return f, fmt.Errorf("max_age (%s) must not be less than expire (%s)", f.maxAge, f.expire) } + // Reject HTTPS upstreams that include a path, the doh implementation default to /dns-query path. + for _, addr := range to { + trans, h := parse.Transport(addr) + if trans == transport.HTTPS && strings.Contains(h, "/") { + return f, fmt.Errorf("paths are not allowed in HTTPS upstream addresses (the /dns-query path is used by default): %s", addr) + } + } + // Classify TO addresses in order, preserving config ordering. entries, err := classifyToAddrs(to) if err != nil { @@ -177,7 +186,7 @@ func parseStanza(c *caddy.Controller) (*Forward, error) { tlsServerNames := make([]string, len(toHosts)) perServerNameProxyCount := make(map[string]int) transports := make([]string, len(toHosts)) - allowedTrans := map[string]bool{"dns": true, "tls": true} + allowedTrans := map[string]bool{"dns": true, "tls": true, "https": true} for i, hostWithZone := range toHosts { host, serverName := splitZone(hostWithZone) trans, h := parse.Transport(host) @@ -223,6 +232,21 @@ func parseStanza(c *caddy.Controller) (*Forward, error) { f.proxies[i].SetTLSConfig(f.tlsConfig) } } + + if transports[i] == transport.HTTPS { + httpTransport := http.DefaultTransport.(*http.Transport).Clone() + httpTransport.TLSClientConfig = f.tlsConfig + httpTransport.MaxIdleConns = f.maxIdleConns + httpTransport.MaxIdleConnsPerHost = f.maxIdleConns + + c := http.Client{ + Transport: httpTransport, + Timeout: 2 * time.Second, + } + f.proxies[i].SetHTTPClient(&c) + f.proxies[i].SetDOHRequestOptions(f.dohMethod) + } + f.proxies[i].SetExpire(f.expire) f.proxies[i].SetMaxAge(f.maxAge) f.proxies[i].SetMaxIdleConns(f.maxIdleConns) @@ -365,6 +389,16 @@ func parseBlock(c *caddy.Controller, f *Forward) error { return fmt.Errorf("max_idle_conns can't be negative: %d", n) } f.maxIdleConns = n + case "doh_method": + if !c.NextArg() { + return c.ArgErr() + } + switch c.Val() { + case http.MethodPost, http.MethodGet: + f.dohMethod = c.Val() + default: + return fmt.Errorf("doh_method must be either %s or %s", http.MethodPost, http.MethodGet) + } case "policy": if !c.NextArg() { return c.ArgErr() diff --git a/plugin/forward/setup_test.go b/plugin/forward/setup_test.go index 2b86f355d..b13150d52 100644 --- a/plugin/forward/setup_test.go +++ b/plugin/forward/setup_test.go @@ -45,11 +45,12 @@ func TestSetup(t *testing.T) { {`forward . ::1 forward com ::2`, false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "plugin"}, {"forward . tls://[2400:3200::1%dns.alidns.com]:853 {\ntls\n}\n", false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward . https://127.0.0.1 \n", false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, ""}, // negative + {"forward . https://1.1.1.1/ \n", true, "", nil, 0, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "paths are not allowed in HTTPS upstream addresses"}, {"forward . a27.0.0.1", true, "", nil, 0, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "failed to resolve"}, {"forward . 127.0.0.1 {\nblaatl\n}\n", true, "", nil, 0, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "unknown property"}, {"forward . 127.0.0.1 {\nhealth_check 0.5s domain\n}\n", true, "", nil, 0, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "Wrong argument count or unexpected line ending after 'domain'"}, - {"forward . https://127.0.0.1 \n", true, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "'https' is not supported as a destination protocol in forward: https://127.0.0.1"}, {"forward xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx 127.0.0.1 \n", true, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "unable to normalize 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'"}, } diff --git a/plugin/pkg/doh/doh.go b/plugin/pkg/doh/doh.go index 94d18e7f4..45d25a0f3 100644 --- a/plugin/pkg/doh/doh.go +++ b/plugin/pkg/doh/doh.go @@ -2,6 +2,7 @@ package doh import ( "bytes" + "context" "encoding/base64" "fmt" "io" @@ -23,6 +24,10 @@ const Path = "/dns-query" // be prefixed with https:// by default, unless it's already prefixed with // either http:// or https://. func NewRequest(method, url string, m *dns.Msg) (*http.Request, error) { + return NewRequestWithContext(context.Background(), method, url, m) +} + +func NewRequestWithContext(ctx context.Context, method, url string, m *dns.Msg) (*http.Request, error) { buf, err := m.Pack() if err != nil { return nil, err @@ -36,7 +41,8 @@ func NewRequest(method, url string, m *dns.Msg) (*http.Request, error) { case http.MethodGet: b64 := base64.RawURLEncoding.EncodeToString(buf) - req, err := http.NewRequest( + req, err := http.NewRequestWithContext( + ctx, http.MethodGet, fmt.Sprintf("%s%s?dns=%s", url, Path, b64), nil, @@ -50,7 +56,8 @@ func NewRequest(method, url string, m *dns.Msg) (*http.Request, error) { return req, nil case http.MethodPost: - req, err := http.NewRequest( + req, err := http.NewRequestWithContext( + ctx, http.MethodPost, fmt.Sprintf("%s%s", url, Path), bytes.NewReader(buf), diff --git a/plugin/pkg/proxy/connect.go b/plugin/pkg/proxy/connect.go index ec68debd8..cf788fdd3 100644 --- a/plugin/pkg/proxy/connect.go +++ b/plugin/pkg/proxy/connect.go @@ -6,12 +6,15 @@ package proxy import ( "context" "errors" + "fmt" "io" "strconv" "strings" "sync/atomic" "time" + "github.com/coredns/coredns/plugin/pkg/doh" + "github.com/coredns/coredns/plugin/pkg/transport" "github.com/coredns/coredns/request" "github.com/miekg/dns" @@ -102,10 +105,7 @@ func (t *Transport) Dial(proto string) (*persistConn, bool, error) { return &persistConn{c: conn, created: time.Now()}, false, err } -// Connect selects an upstream, sends the request and waits for a response. -func (p *Proxy) Connect(_ctx context.Context, state request.Request, opts Options) (*dns.Msg, error) { - start := time.Now() - +func (p *Proxy) lookupDNS(_ctx context.Context, state request.Request, opts Options) (*dns.Msg, error) { var proto string switch { case opts.ForceTCP: // TCP flag has precedence over UDP flag @@ -172,11 +172,55 @@ func (p *Proxy) Connect(_ctx context.Context, state request.Request, opts Option break } } + p.transport.Yield(pc) + + return ret, nil +} + +func (p *Proxy) lookupDoH(ctx context.Context, state request.Request, _ Options) (*dns.Msg, error) { + req, err := doh.NewRequestWithContext(ctx, p.dohMethod, p.addr, state.Req) + if err != nil { + return nil, err + } + + resp, err := p.transport.httpClient.Do(req) + if err != nil { + return nil, err + } + + // ResponseToMsg always closes the body via defer resp.Body.Close(). + ret, err := doh.ResponseToMsg(resp) + if err != nil { + return nil, err + } + + return ret, nil +} + +// Connect selects an upstream, sends the request and waits for a response. +func (p *Proxy) Connect(ctx context.Context, state request.Request, opts Options) (*dns.Msg, error) { + start := time.Now() + originId := state.Req.Id + + var ( + ret *dns.Msg + err error + ) + switch p.protocol { + case transport.HTTPS: + ret, err = p.lookupDoH(ctx, state, opts) + case transport.DNS, transport.TLS: + ret, err = p.lookupDNS(ctx, state, opts) + default: + return nil, fmt.Errorf("transport %s not supported to proxy", p.protocol) + } + if err != nil { + return nil, err + } + // recovery the origin Id after upstream. ret.Id = originId - p.transport.Yield(pc) - rc, ok := dns.RcodeToString[ret.Rcode] if !ok { rc = strconv.Itoa(ret.Rcode) diff --git a/plugin/pkg/proxy/health.go b/plugin/pkg/proxy/health.go index ff3dbd6ce..b8f9fb887 100644 --- a/plugin/pkg/proxy/health.go +++ b/plugin/pkg/proxy/health.go @@ -1,10 +1,13 @@ package proxy import ( + "context" "crypto/tls" + "net/http" "sync/atomic" "time" + "github.com/coredns/coredns/plugin/pkg/doh" "github.com/coredns/coredns/plugin/pkg/log" "github.com/coredns/coredns/plugin/pkg/transport" @@ -36,14 +39,16 @@ type dnsHc struct { proxyName string } +const defaultTimeout = 1 * time.Second + // NewHealthChecker returns a new HealthChecker based on transport. -func NewHealthChecker(proxyName, trans string, recursionDesired bool, domain string) HealthChecker { - switch trans { +func NewHealthChecker(proxyName, protocol string, recursionDesired bool, domain string) HealthChecker { + switch protocol { case transport.DNS, transport.TLS: c := new(dns.Client) c.Net = "udp" - c.ReadTimeout = 1 * time.Second - c.WriteTimeout = 1 * time.Second + c.ReadTimeout = defaultTimeout + c.WriteTimeout = defaultTimeout return &dnsHc{ c: c, @@ -51,9 +56,22 @@ func NewHealthChecker(proxyName, trans string, recursionDesired bool, domain str domain: domain, proxyName: proxyName, } + case transport.HTTPS: + httpTransport := http.DefaultTransport.(*http.Transport).Clone() + httpTransport.TLSClientConfig = new(tls.Config) + + return &dohHc{ + client: &http.Client{ + Transport: httpTransport, + Timeout: defaultTimeout, + }, + recursionDesired: recursionDesired, + domain: domain, + proxyName: proxyName, + } } - log.Warningf("No healthchecker for transport %q", trans) + log.Warningf("No healthchecker for transport %q", protocol) return nil } @@ -132,3 +150,97 @@ func (h *dnsHc) send(addr string) error { return err } + +// dohHc is a health checker for a DNS-over-HTTPS (DoH) endpoint. +type dohHc struct { + client *http.Client + recursionDesired bool + domain string + proxyName string +} + +func (h *dohHc) Check(p *Proxy) error { + err := h.send(p.addr) + if err != nil { + healthcheckFailureCount.WithLabelValues(p.proxyName, p.addr).Add(1) + p.incrementFails() + return err + } + + atomic.StoreUint32(&p.fails, 0) + return nil +} + +func (h *dohHc) send(addr string) error { + ping := new(dns.Msg) + ping.SetQuestion(h.domain, dns.TypeNS) + ping.RecursionDesired = h.recursionDesired + + ctx, cancel := context.WithTimeout(context.Background(), h.client.Timeout) + defer cancel() + + req, err := doh.NewRequestWithContext(ctx, http.MethodPost, addr, ping) + if err != nil { + return err + } + + resp, err := h.client.Do(req) + if err != nil { + return err + } + + // ResponseToMsg always closes the body via defer resp.Body.Close(). + m, err := doh.ResponseToMsg(resp) + if err != nil { + return err + } + + // If we got a header, we're alright. + if m.Response || m.Opcode == dns.OpcodeQuery { + return nil + } + + return nil +} + +func (h *dohHc) SetTLSConfig(cfg *tls.Config) { + h.client.Transport.(*http.Transport).TLSClientConfig = cfg +} + +func (h *dohHc) GetTLSConfig() *tls.Config { + return h.client.Transport.(*http.Transport).TLSClientConfig +} + +func (h *dohHc) SetRecursionDesired(recursionDesired bool) { + h.recursionDesired = recursionDesired +} +func (h *dohHc) GetRecursionDesired() bool { + return h.recursionDesired +} + +func (h *dohHc) SetDomain(domain string) { + h.domain = domain +} +func (h *dohHc) GetDomain() string { + return h.domain +} + +func (h *dohHc) SetTCPTransport() { + // no-op for DoH +} + +func (h *dohHc) GetReadTimeout() time.Duration { + return h.client.Transport.(*http.Transport).ResponseHeaderTimeout +} + +func (h *dohHc) SetReadTimeout(t time.Duration) { + h.client.Transport.(*http.Transport).ResponseHeaderTimeout = t +} + +func (h *dohHc) GetWriteTimeout() time.Duration { + return h.client.Timeout +} + +func (h *dohHc) SetWriteTimeout(t time.Duration) { + h.client.Timeout = t +} diff --git a/plugin/pkg/proxy/health_test.go b/plugin/pkg/proxy/health_test.go index 577139654..5e5bdb940 100644 --- a/plugin/pkg/proxy/health_test.go +++ b/plugin/pkg/proxy/health_test.go @@ -1,11 +1,14 @@ package proxy import ( + "net/http" + "net/http/httptest" "sync/atomic" "testing" "time" "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/doh" "github.com/coredns/coredns/plugin/pkg/transport" "github.com/miekg/dns" @@ -72,6 +75,52 @@ func TestHealthTCP(t *testing.T) { } } +func TestHealthHTTPS(t *testing.T) { + i := uint32(0) + s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + msg, err := doh.RequestToMsg(r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if msg.Question[0].Name == "." && msg.RecursionDesired == true { + atomic.AddUint32(&i, 1) + } + + ret := new(dns.Msg) + ret.SetReply(msg) + + buf, err := ret.Pack() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", doh.MimeType) + w.Write(buf) + })) + defer s.Close() + + hc := NewHealthChecker("TestHealthHTTPS", transport.HTTPS, true, ".") + hc.SetTLSConfig(s.Client().Transport.(*http.Transport).TLSClientConfig) + hc.SetReadTimeout(10 * time.Millisecond) + hc.SetWriteTimeout(10 * time.Millisecond) + + p := NewProxy("TestHealthHTTPS", s.URL, transport.HTTPS) + p.readTimeout = 10 * time.Millisecond + err := hc.Check(p) + if err != nil { + t.Fatalf("check failed: %v", err) + } + + time.Sleep(20 * time.Millisecond) + i1 := atomic.LoadUint32(&i) + if i1 != 1 { + t.Errorf("Expected number of health checks with RecursionDesired==true to be %d, got %d", 1, i1) + } +} + func TestHealthNoRecursion(t *testing.T) { i := uint32(0) s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { diff --git a/plugin/pkg/proxy/persistent.go b/plugin/pkg/proxy/persistent.go index 74c5d8d8d..64501654b 100644 --- a/plugin/pkg/proxy/persistent.go +++ b/plugin/pkg/proxy/persistent.go @@ -2,6 +2,7 @@ package proxy import ( "crypto/tls" + "net/http" "sort" "sync" "time" @@ -19,12 +20,13 @@ type persistConn struct { // Transport hold the persistent cache. type Transport struct { avgDialTime int64 // kind of average time of dial time - conns [typeTotalCount][]*persistConn // Buckets for udp, tcp and tcp-tls. + conns [typeTotalCount][]*persistConn // Buckets for udp and tcp connections expire time.Duration // After this duration an idle connection is expired. maxAge time.Duration // After this duration a connection is closed regardless of activity; 0 means unlimited. - maxIdleConns int // Max idle connections per transport type; 0 means unlimited. + maxIdleConns int // Max idle connections per protocol type; 0 means unlimited. addr string tlsConfig *tls.Config + httpClient *http.Client proxyName string mu sync.Mutex @@ -40,6 +42,7 @@ func newTransport(proxyName, addr string) *Transport { stop: make(chan struct{}), proxyName: proxyName, } + return t } diff --git a/plugin/pkg/proxy/proxy.go b/plugin/pkg/proxy/proxy.go index d455da2f9..676786068 100644 --- a/plugin/pkg/proxy/proxy.go +++ b/plugin/pkg/proxy/proxy.go @@ -2,6 +2,7 @@ package proxy import ( "crypto/tls" + "net/http" "runtime" "sync/atomic" "time" @@ -17,6 +18,9 @@ type Proxy struct { proxyName string transport *Transport + protocol string + + dohMethod string readTimeout time.Duration @@ -26,14 +30,16 @@ type Proxy struct { } // NewProxy returns a new proxy. -func NewProxy(proxyName, addr, trans string) *Proxy { +func NewProxy(proxyName, addr, protocol string) *Proxy { p := &Proxy{ addr: addr, fails: 0, probe: up.New(), readTimeout: 2 * time.Second, transport: newTransport(proxyName, addr), - health: NewHealthChecker(proxyName, trans, true, "."), + protocol: protocol, + dohMethod: http.MethodPost, + health: NewHealthChecker(proxyName, protocol, true, "."), proxyName: proxyName, } @@ -47,6 +53,9 @@ func (p *Proxy) Addr() string { return p.addr } func (p *Proxy) SetTLSConfig(cfg *tls.Config) { p.transport.SetTLSConfig(cfg) p.health.SetTLSConfig(cfg) + if p.transport.httpClient != nil { + p.transport.httpClient.Transport.(*http.Transport).TLSClientConfig = cfg + } } // SetExpire sets the expire duration in the lower p.transport. @@ -60,6 +69,14 @@ func (p *Proxy) SetMaxAge(maxAge time.Duration) { p.transport.SetMaxAge(maxAge) // A value of 0 means unlimited (default). func (p *Proxy) SetMaxIdleConns(n int) { p.transport.SetMaxIdleConns(n) } +func (p *Proxy) SetHTTPClient(client *http.Client) { + p.transport.httpClient = client +} + +func (p *Proxy) SetDOHRequestOptions(method string) { + p.dohMethod = method +} + func (p *Proxy) GetHealthchecker() HealthChecker { return p.health } diff --git a/plugin/pkg/proxy/type.go b/plugin/pkg/proxy/type.go index 10f3a4639..1bc249523 100644 --- a/plugin/pkg/proxy/type.go +++ b/plugin/pkg/proxy/type.go @@ -9,7 +9,6 @@ type transportType int const ( typeUDP transportType = iota typeTCP - typeTLS typeTotalCount // keep this last ) @@ -17,13 +16,11 @@ func stringToTransportType(s string) transportType { switch s { case "udp": return typeUDP - case "tcp": + case "tcp", "tcp-tls": return typeTCP - case "tcp-tls": - return typeTLS + default: + return typeUDP } - - return typeUDP } func (t *Transport) transportTypeFromConn(pc *persistConn) transportType { @@ -31,9 +28,5 @@ func (t *Transport) transportTypeFromConn(pc *persistConn) transportType { return typeUDP } - if t.tlsConfig == nil { - return typeTCP - } - - return typeTLS + return typeTCP }