mirror of
https://github.com/coredns/coredns.git
synced 2026-06-15 13:40:11 -04:00
feat(forward): add doh support (#8004)
* chore(pkg/proxy): prepare for DoH implementation Signed-off-by: Thomas Gosteli <thomas.gosteli@protonmail.ch> * chore(pkg/proxy): prepare for DoH implementation Signed-off-by: Thomas Gosteli <thomas.gosteli@protonmail.ch> * feat(proxy): implement basic DoH resolution Signed-off-by: Thomas Gosteli <thomas.gosteli@protonmail.ch> * feat(forward): implement DoH forwarding Signed-off-by: Thomas Gosteli <thomas.gosteli@protonmail.ch> * feat(proxy): add basic DoH health checker Signed-off-by: Thomas Gosteli <thomas.gosteli@protonmail.ch> * chore: align http transport with Go's DefaultTransport and resolve some of the TODOs Signed-off-by: Thomas Gosteli <thomas.gosteli@protonmail.ch> * docs(forward): add basic documentation for DoH Signed-off-by: Thomas Gosteli <thomas.gosteli@protonmail.ch> * chore: add basic tests to cover DoH Signed-off-by: Thomas Gosteli <thomas.gosteli@protonmail.ch> * chore(health): unify default timeout to 1s Signed-off-by: Thomas Gosteli <thomas.gosteli@protonmail.ch> * feat(forward): make doh method configurable Signed-off-by: Thomas Gosteli <thomas.gosteli@protonmail.ch> * chore: remove maxIdleConnsPerHost setting & update docs Signed-off-by: Thomas Gosteli <thomas.gosteli@protonmail.ch> * chore(forward): reject https upstreams with path Signed-off-by: Thomas Gosteli <thomas.gosteli@protonmail.ch> --------- Signed-off-by: Thomas Gosteli <thomas.gosteli@protonmail.ch>
This commit is contained in:
@@ -2,6 +2,7 @@ package doh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -23,6 +24,10 @@ const Path = "/dns-query"
|
||||
// be prefixed with https:// by default, unless it's already prefixed with
|
||||
// either http:// or https://.
|
||||
func NewRequest(method, url string, m *dns.Msg) (*http.Request, error) {
|
||||
return NewRequestWithContext(context.Background(), method, url, m)
|
||||
}
|
||||
|
||||
func NewRequestWithContext(ctx context.Context, method, url string, m *dns.Msg) (*http.Request, error) {
|
||||
buf, err := m.Pack()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -36,7 +41,8 @@ func NewRequest(method, url string, m *dns.Msg) (*http.Request, error) {
|
||||
case http.MethodGet:
|
||||
b64 := base64.RawURLEncoding.EncodeToString(buf)
|
||||
|
||||
req, err := http.NewRequest(
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodGet,
|
||||
fmt.Sprintf("%s%s?dns=%s", url, Path, b64),
|
||||
nil,
|
||||
@@ -50,7 +56,8 @@ func NewRequest(method, url string, m *dns.Msg) (*http.Request, error) {
|
||||
return req, nil
|
||||
|
||||
case http.MethodPost:
|
||||
req, err := http.NewRequest(
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodPost,
|
||||
fmt.Sprintf("%s%s", url, Path),
|
||||
bytes.NewReader(buf),
|
||||
|
||||
@@ -6,12 +6,15 @@ package proxy
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/doh"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@@ -102,10 +105,7 @@ func (t *Transport) Dial(proto string) (*persistConn, bool, error) {
|
||||
return &persistConn{c: conn, created: time.Now()}, false, err
|
||||
}
|
||||
|
||||
// Connect selects an upstream, sends the request and waits for a response.
|
||||
func (p *Proxy) Connect(_ctx context.Context, state request.Request, opts Options) (*dns.Msg, error) {
|
||||
start := time.Now()
|
||||
|
||||
func (p *Proxy) lookupDNS(_ctx context.Context, state request.Request, opts Options) (*dns.Msg, error) {
|
||||
var proto string
|
||||
switch {
|
||||
case opts.ForceTCP: // TCP flag has precedence over UDP flag
|
||||
@@ -172,11 +172,55 @@ func (p *Proxy) Connect(_ctx context.Context, state request.Request, opts Option
|
||||
break
|
||||
}
|
||||
}
|
||||
p.transport.Yield(pc)
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (p *Proxy) lookupDoH(ctx context.Context, state request.Request, _ Options) (*dns.Msg, error) {
|
||||
req, err := doh.NewRequestWithContext(ctx, p.dohMethod, p.addr, state.Req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := p.transport.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// ResponseToMsg always closes the body via defer resp.Body.Close().
|
||||
ret, err := doh.ResponseToMsg(resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// Connect selects an upstream, sends the request and waits for a response.
|
||||
func (p *Proxy) Connect(ctx context.Context, state request.Request, opts Options) (*dns.Msg, error) {
|
||||
start := time.Now()
|
||||
originId := state.Req.Id
|
||||
|
||||
var (
|
||||
ret *dns.Msg
|
||||
err error
|
||||
)
|
||||
switch p.protocol {
|
||||
case transport.HTTPS:
|
||||
ret, err = p.lookupDoH(ctx, state, opts)
|
||||
case transport.DNS, transport.TLS:
|
||||
ret, err = p.lookupDNS(ctx, state, opts)
|
||||
default:
|
||||
return nil, fmt.Errorf("transport %s not supported to proxy", p.protocol)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// recovery the origin Id after upstream.
|
||||
ret.Id = originId
|
||||
|
||||
p.transport.Yield(pc)
|
||||
|
||||
rc, ok := dns.RcodeToString[ret.Rcode]
|
||||
if !ok {
|
||||
rc = strconv.Itoa(ret.Rcode)
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/doh"
|
||||
"github.com/coredns/coredns/plugin/pkg/log"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
|
||||
@@ -36,14 +39,16 @@ type dnsHc struct {
|
||||
proxyName string
|
||||
}
|
||||
|
||||
const defaultTimeout = 1 * time.Second
|
||||
|
||||
// NewHealthChecker returns a new HealthChecker based on transport.
|
||||
func NewHealthChecker(proxyName, trans string, recursionDesired bool, domain string) HealthChecker {
|
||||
switch trans {
|
||||
func NewHealthChecker(proxyName, protocol string, recursionDesired bool, domain string) HealthChecker {
|
||||
switch protocol {
|
||||
case transport.DNS, transport.TLS:
|
||||
c := new(dns.Client)
|
||||
c.Net = "udp"
|
||||
c.ReadTimeout = 1 * time.Second
|
||||
c.WriteTimeout = 1 * time.Second
|
||||
c.ReadTimeout = defaultTimeout
|
||||
c.WriteTimeout = defaultTimeout
|
||||
|
||||
return &dnsHc{
|
||||
c: c,
|
||||
@@ -51,9 +56,22 @@ func NewHealthChecker(proxyName, trans string, recursionDesired bool, domain str
|
||||
domain: domain,
|
||||
proxyName: proxyName,
|
||||
}
|
||||
case transport.HTTPS:
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.TLSClientConfig = new(tls.Config)
|
||||
|
||||
return &dohHc{
|
||||
client: &http.Client{
|
||||
Transport: httpTransport,
|
||||
Timeout: defaultTimeout,
|
||||
},
|
||||
recursionDesired: recursionDesired,
|
||||
domain: domain,
|
||||
proxyName: proxyName,
|
||||
}
|
||||
}
|
||||
|
||||
log.Warningf("No healthchecker for transport %q", trans)
|
||||
log.Warningf("No healthchecker for transport %q", protocol)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -132,3 +150,97 @@ func (h *dnsHc) send(addr string) error {
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// dohHc is a health checker for a DNS-over-HTTPS (DoH) endpoint.
|
||||
type dohHc struct {
|
||||
client *http.Client
|
||||
recursionDesired bool
|
||||
domain string
|
||||
proxyName string
|
||||
}
|
||||
|
||||
func (h *dohHc) Check(p *Proxy) error {
|
||||
err := h.send(p.addr)
|
||||
if err != nil {
|
||||
healthcheckFailureCount.WithLabelValues(p.proxyName, p.addr).Add(1)
|
||||
p.incrementFails()
|
||||
return err
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&p.fails, 0)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *dohHc) send(addr string) error {
|
||||
ping := new(dns.Msg)
|
||||
ping.SetQuestion(h.domain, dns.TypeNS)
|
||||
ping.RecursionDesired = h.recursionDesired
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), h.client.Timeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := doh.NewRequestWithContext(ctx, http.MethodPost, addr, ping)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := h.client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ResponseToMsg always closes the body via defer resp.Body.Close().
|
||||
m, err := doh.ResponseToMsg(resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If we got a header, we're alright.
|
||||
if m.Response || m.Opcode == dns.OpcodeQuery {
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *dohHc) SetTLSConfig(cfg *tls.Config) {
|
||||
h.client.Transport.(*http.Transport).TLSClientConfig = cfg
|
||||
}
|
||||
|
||||
func (h *dohHc) GetTLSConfig() *tls.Config {
|
||||
return h.client.Transport.(*http.Transport).TLSClientConfig
|
||||
}
|
||||
|
||||
func (h *dohHc) SetRecursionDesired(recursionDesired bool) {
|
||||
h.recursionDesired = recursionDesired
|
||||
}
|
||||
func (h *dohHc) GetRecursionDesired() bool {
|
||||
return h.recursionDesired
|
||||
}
|
||||
|
||||
func (h *dohHc) SetDomain(domain string) {
|
||||
h.domain = domain
|
||||
}
|
||||
func (h *dohHc) GetDomain() string {
|
||||
return h.domain
|
||||
}
|
||||
|
||||
func (h *dohHc) SetTCPTransport() {
|
||||
// no-op for DoH
|
||||
}
|
||||
|
||||
func (h *dohHc) GetReadTimeout() time.Duration {
|
||||
return h.client.Transport.(*http.Transport).ResponseHeaderTimeout
|
||||
}
|
||||
|
||||
func (h *dohHc) SetReadTimeout(t time.Duration) {
|
||||
h.client.Transport.(*http.Transport).ResponseHeaderTimeout = t
|
||||
}
|
||||
|
||||
func (h *dohHc) GetWriteTimeout() time.Duration {
|
||||
return h.client.Timeout
|
||||
}
|
||||
|
||||
func (h *dohHc) SetWriteTimeout(t time.Duration) {
|
||||
h.client.Timeout = t
|
||||
}
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
||||
"github.com/coredns/coredns/plugin/pkg/doh"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@@ -72,6 +75,52 @@ func TestHealthTCP(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthHTTPS(t *testing.T) {
|
||||
i := uint32(0)
|
||||
s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
msg, err := doh.RequestToMsg(r)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if msg.Question[0].Name == "." && msg.RecursionDesired == true {
|
||||
atomic.AddUint32(&i, 1)
|
||||
}
|
||||
|
||||
ret := new(dns.Msg)
|
||||
ret.SetReply(msg)
|
||||
|
||||
buf, err := ret.Pack()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", doh.MimeType)
|
||||
w.Write(buf)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
hc := NewHealthChecker("TestHealthHTTPS", transport.HTTPS, true, ".")
|
||||
hc.SetTLSConfig(s.Client().Transport.(*http.Transport).TLSClientConfig)
|
||||
hc.SetReadTimeout(10 * time.Millisecond)
|
||||
hc.SetWriteTimeout(10 * time.Millisecond)
|
||||
|
||||
p := NewProxy("TestHealthHTTPS", s.URL, transport.HTTPS)
|
||||
p.readTimeout = 10 * time.Millisecond
|
||||
err := hc.Check(p)
|
||||
if err != nil {
|
||||
t.Fatalf("check failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
i1 := atomic.LoadUint32(&i)
|
||||
if i1 != 1 {
|
||||
t.Errorf("Expected number of health checks with RecursionDesired==true to be %d, got %d", 1, i1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthNoRecursion(t *testing.T) {
|
||||
i := uint32(0)
|
||||
s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
@@ -2,6 +2,7 @@ package proxy
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -19,12 +20,13 @@ type persistConn struct {
|
||||
// Transport hold the persistent cache.
|
||||
type Transport struct {
|
||||
avgDialTime int64 // kind of average time of dial time
|
||||
conns [typeTotalCount][]*persistConn // Buckets for udp, tcp and tcp-tls.
|
||||
conns [typeTotalCount][]*persistConn // Buckets for udp and tcp connections
|
||||
expire time.Duration // After this duration an idle connection is expired.
|
||||
maxAge time.Duration // After this duration a connection is closed regardless of activity; 0 means unlimited.
|
||||
maxIdleConns int // Max idle connections per transport type; 0 means unlimited.
|
||||
maxIdleConns int // Max idle connections per protocol type; 0 means unlimited.
|
||||
addr string
|
||||
tlsConfig *tls.Config
|
||||
httpClient *http.Client
|
||||
proxyName string
|
||||
|
||||
mu sync.Mutex
|
||||
@@ -40,6 +42,7 @@ func newTransport(proxyName, addr string) *Transport {
|
||||
stop: make(chan struct{}),
|
||||
proxyName: proxyName,
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package proxy
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -17,6 +18,9 @@ type Proxy struct {
|
||||
proxyName string
|
||||
|
||||
transport *Transport
|
||||
protocol string
|
||||
|
||||
dohMethod string
|
||||
|
||||
readTimeout time.Duration
|
||||
|
||||
@@ -26,14 +30,16 @@ type Proxy struct {
|
||||
}
|
||||
|
||||
// NewProxy returns a new proxy.
|
||||
func NewProxy(proxyName, addr, trans string) *Proxy {
|
||||
func NewProxy(proxyName, addr, protocol string) *Proxy {
|
||||
p := &Proxy{
|
||||
addr: addr,
|
||||
fails: 0,
|
||||
probe: up.New(),
|
||||
readTimeout: 2 * time.Second,
|
||||
transport: newTransport(proxyName, addr),
|
||||
health: NewHealthChecker(proxyName, trans, true, "."),
|
||||
protocol: protocol,
|
||||
dohMethod: http.MethodPost,
|
||||
health: NewHealthChecker(proxyName, protocol, true, "."),
|
||||
proxyName: proxyName,
|
||||
}
|
||||
|
||||
@@ -47,6 +53,9 @@ func (p *Proxy) Addr() string { return p.addr }
|
||||
func (p *Proxy) SetTLSConfig(cfg *tls.Config) {
|
||||
p.transport.SetTLSConfig(cfg)
|
||||
p.health.SetTLSConfig(cfg)
|
||||
if p.transport.httpClient != nil {
|
||||
p.transport.httpClient.Transport.(*http.Transport).TLSClientConfig = cfg
|
||||
}
|
||||
}
|
||||
|
||||
// SetExpire sets the expire duration in the lower p.transport.
|
||||
@@ -60,6 +69,14 @@ func (p *Proxy) SetMaxAge(maxAge time.Duration) { p.transport.SetMaxAge(maxAge)
|
||||
// A value of 0 means unlimited (default).
|
||||
func (p *Proxy) SetMaxIdleConns(n int) { p.transport.SetMaxIdleConns(n) }
|
||||
|
||||
func (p *Proxy) SetHTTPClient(client *http.Client) {
|
||||
p.transport.httpClient = client
|
||||
}
|
||||
|
||||
func (p *Proxy) SetDOHRequestOptions(method string) {
|
||||
p.dohMethod = method
|
||||
}
|
||||
|
||||
func (p *Proxy) GetHealthchecker() HealthChecker {
|
||||
return p.health
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ type transportType int
|
||||
const (
|
||||
typeUDP transportType = iota
|
||||
typeTCP
|
||||
typeTLS
|
||||
typeTotalCount // keep this last
|
||||
)
|
||||
|
||||
@@ -17,13 +16,11 @@ func stringToTransportType(s string) transportType {
|
||||
switch s {
|
||||
case "udp":
|
||||
return typeUDP
|
||||
case "tcp":
|
||||
case "tcp", "tcp-tls":
|
||||
return typeTCP
|
||||
case "tcp-tls":
|
||||
return typeTLS
|
||||
default:
|
||||
return typeUDP
|
||||
}
|
||||
|
||||
return typeUDP
|
||||
}
|
||||
|
||||
func (t *Transport) transportTypeFromConn(pc *persistConn) transportType {
|
||||
@@ -31,9 +28,5 @@ func (t *Transport) transportTypeFromConn(pc *persistConn) transportType {
|
||||
return typeUDP
|
||||
}
|
||||
|
||||
if t.tlsConfig == nil {
|
||||
return typeTCP
|
||||
}
|
||||
|
||||
return typeTLS
|
||||
return typeTCP
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user