mirror of
https://github.com/coredns/coredns.git
synced 2026-04-09 13:35:33 -04:00
core: Add full TSIG verification in DoH transport (#8013)
* core: Add full TSIG verification in DoH transport This PR add full TSIG verification in DoH using dns.TsigVerify() 7943 --------- Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
@@ -17,6 +17,9 @@ type DoHWriter struct {
|
|||||||
// request is the HTTP request we're currently handling.
|
// request is the HTTP request we're currently handling.
|
||||||
request *http.Request
|
request *http.Request
|
||||||
|
|
||||||
|
// tsigStatus stores the TSIG verification result for the request.
|
||||||
|
tsigStatus error
|
||||||
|
|
||||||
// Msg is a response to be written to the client.
|
// Msg is a response to be written to the client.
|
||||||
Msg *dns.Msg
|
Msg *dns.Msg
|
||||||
}
|
}
|
||||||
@@ -58,9 +61,9 @@ func (d *DoHWriter) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TsigStatus no-op implementation.
|
// TsigStatus returns the TSIG verification status for this request.
|
||||||
func (d *DoHWriter) TsigStatus() error {
|
func (d *DoHWriter) TsigStatus() error {
|
||||||
return nil
|
return d.tsigStatus
|
||||||
}
|
}
|
||||||
|
|
||||||
// TsigTimersOnly no-op implementation.
|
// TsigTimersOnly no-op implementation.
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"github.com/coredns/coredns/plugin/pkg/reuseport"
|
"github.com/coredns/coredns/plugin/pkg/reuseport"
|
||||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
"github.com/pires/go-proxyproto"
|
"github.com/pires/go-proxyproto"
|
||||||
"golang.org/x/net/netutil"
|
"golang.org/x/net/netutil"
|
||||||
)
|
)
|
||||||
@@ -192,7 +193,7 @@ func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
msg, err := doh.RequestToMsg(r)
|
msg, raw, err := doh.RequestToMsgWire(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
s.countResponse(http.StatusBadRequest)
|
s.countResponse(http.StatusBadRequest)
|
||||||
@@ -208,6 +209,16 @@ func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
request: r,
|
request: r,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if tsig := msg.IsTsig(); tsig != nil {
|
||||||
|
if s.tsigSecret == nil {
|
||||||
|
dw.tsigStatus = dns.ErrSecret
|
||||||
|
} else if secret, ok := s.tsigSecret[tsig.Hdr.Name]; !ok {
|
||||||
|
dw.tsigStatus = dns.ErrSecret
|
||||||
|
} else {
|
||||||
|
dw.tsigStatus = dns.TsigVerify(raw, secret, "", false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// We just call the normal chain handler - all error handling is done there.
|
// We just call the normal chain handler - all error handling is done there.
|
||||||
// We should expect a packet to be returned that we can send to the client.
|
// We should expect a packet to be returned that we can send to the client.
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -22,6 +23,12 @@ var (
|
|||||||
validator = func(r *http.Request) bool { return validPath.MatchString(r.URL.Path) }
|
validator = func(r *http.Request) bool { return validPath.MatchString(r.URL.Path) }
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
testTSIGKeyName = "tsig-key."
|
||||||
|
testTSIGSecret = "MTIzNA=="
|
||||||
|
testTSIGWrongSecret = "NTY3OA=="
|
||||||
|
)
|
||||||
|
|
||||||
func testServerHTTPS(t *testing.T, path string, validator func(*http.Request) bool) *http.Response {
|
func testServerHTTPS(t *testing.T, path string, validator func(*http.Request) bool) *http.Response {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
c := Config{
|
c := Config{
|
||||||
@@ -396,3 +403,183 @@ func TestHTTPRequestContextPropagation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type tsigStatusPlugin struct{}
|
||||||
|
|
||||||
|
func (p *tsigStatusPlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetReply(r)
|
||||||
|
m.Authoritative = true
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case r.IsTsig() == nil:
|
||||||
|
m.Rcode = dns.RcodeRefused
|
||||||
|
case w.TsigStatus() != nil:
|
||||||
|
m.Rcode = dns.RcodeNotAuth
|
||||||
|
default:
|
||||||
|
m.Rcode = dns.RcodeSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.WriteMsg(m); err != nil {
|
||||||
|
return dns.RcodeServerFailure, err
|
||||||
|
}
|
||||||
|
return dns.RcodeSuccess, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *tsigStatusPlugin) Name() string { return "tsig_status" }
|
||||||
|
|
||||||
|
func testConfigWithTSIGStatusPlugin() *Config {
|
||||||
|
c := &Config{
|
||||||
|
Zone: "example.com.",
|
||||||
|
Transport: "https",
|
||||||
|
TLSConfig: &tls.Config{},
|
||||||
|
ListenHosts: []string{"127.0.0.1"},
|
||||||
|
Port: "443",
|
||||||
|
TsigSecret: map[string]string{
|
||||||
|
testTSIGKeyName: testTSIGSecret,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
c.AddPlugin(func(_next plugin.Handler) plugin.Handler { return &tsigStatusPlugin{} })
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func testServerHTTPSMsg(t *testing.T, cfg *Config, req *dns.Msg) *dns.Msg {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
s, err := NewServerHTTPS("127.0.0.1:443", []*Config{cfg})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("could not create HTTPS server:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err := req.Pack()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/dns-query", bytes.NewReader(buf))
|
||||||
|
r.RemoteAddr = "127.0.0.1:12345"
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
s.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
res := w.Result()
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if res.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected HTTP status: got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m := new(dns.Msg)
|
||||||
|
if err := m.Unpack(body); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func testServerHTTPSRaw(t *testing.T, cfg *Config, buf []byte) *dns.Msg {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
s, err := NewServerHTTPS("127.0.0.1:443", []*Config{cfg})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("could not create HTTPS server:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/dns-query", bytes.NewReader(buf))
|
||||||
|
r.RemoteAddr = "127.0.0.1:12345"
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
s.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
res := w.Result()
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if res.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected HTTP status: got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m := new(dns.Msg)
|
||||||
|
if err := m.Unpack(body); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func forgedTSIGMsg() *dns.Msg {
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
m.Extra = append(m.Extra, &dns.TSIG{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: "bogus-key.",
|
||||||
|
Rrtype: dns.TypeTSIG,
|
||||||
|
Class: dns.ClassANY,
|
||||||
|
Ttl: 0,
|
||||||
|
},
|
||||||
|
Algorithm: dns.HmacSHA256,
|
||||||
|
TimeSigned: uint64(time.Now().Unix()),
|
||||||
|
Fudge: 300,
|
||||||
|
MACSize: 32,
|
||||||
|
MAC: strings.Repeat("00", 32),
|
||||||
|
OrigId: m.Id,
|
||||||
|
Error: dns.RcodeSuccess,
|
||||||
|
})
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustSignedTSIGQueryBytes(t *testing.T, keyName, secret string) []byte {
|
||||||
|
t.Helper()
|
||||||
|
return mustPackSignedTSIGQuery(t, keyName, secret, time.Now().Unix())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeHTTPRejectsUnsignedTSIGRequiredRequest(t *testing.T) {
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
resp := testServerHTTPSMsg(t, testConfigWithTSIGStatusPlugin(), m)
|
||||||
|
if resp.Rcode != dns.RcodeRefused {
|
||||||
|
t.Fatalf("expected REFUSED for unsigned request, got %s", dns.RcodeToString[resp.Rcode])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeHTTPRejectsTSIGWithUnknownKey(t *testing.T) {
|
||||||
|
resp := testServerHTTPSMsg(t, testConfigWithTSIGStatusPlugin(), forgedTSIGMsg())
|
||||||
|
|
||||||
|
if resp.Rcode != dns.RcodeNotAuth {
|
||||||
|
t.Fatalf("expected NOTAUTH for unknown TSIG key, got %s", dns.RcodeToString[resp.Rcode])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeHTTPRejectsTSIGWithBadMAC(t *testing.T) {
|
||||||
|
buf := mustSignedTSIGQueryBytes(t, testTSIGKeyName, testTSIGWrongSecret)
|
||||||
|
|
||||||
|
resp := testServerHTTPSRaw(t, testConfigWithTSIGStatusPlugin(), buf)
|
||||||
|
if resp.Rcode != dns.RcodeNotAuth {
|
||||||
|
t.Fatalf("expected NOTAUTH for bad TSIG MAC, got %s", dns.RcodeToString[resp.Rcode])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeHTTPAcceptsValidTSIG(t *testing.T) {
|
||||||
|
buf := mustSignedTSIGQueryBytes(t, testTSIGKeyName, testTSIGSecret)
|
||||||
|
|
||||||
|
resp := testServerHTTPSRaw(t, testConfigWithTSIGStatusPlugin(), buf)
|
||||||
|
if resp.Rcode != dns.RcodeSuccess {
|
||||||
|
t.Fatalf("expected NOERROR for valid TSIG, got %s", dns.RcodeToString[resp.Rcode])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDoHWriterTsigStatusReturnsStoredStatus(t *testing.T) {
|
||||||
|
dw := &DoHWriter{tsigStatus: dns.ErrSecret}
|
||||||
|
if dw.TsigStatus() != dns.ErrSecret {
|
||||||
|
t.Fatal("expected TsigStatus to return stored tsigStatus")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -77,6 +77,13 @@ func ResponseToMsg(resp *http.Response) (*dns.Msg, error) {
|
|||||||
|
|
||||||
// RequestToMsg converts a http.Request to a dns message.
|
// RequestToMsg converts a http.Request to a dns message.
|
||||||
func RequestToMsg(req *http.Request) (*dns.Msg, error) {
|
func RequestToMsg(req *http.Request) (*dns.Msg, error) {
|
||||||
|
msg, _, err := RequestToMsgWire(req)
|
||||||
|
return msg, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestToMsgWire converts a http.Request to a dns message and returns the
|
||||||
|
// original DNS wire bytes from the request.
|
||||||
|
func RequestToMsgWire(req *http.Request) (*dns.Msg, []byte, error) {
|
||||||
switch req.Method {
|
switch req.Method {
|
||||||
case http.MethodGet:
|
case http.MethodGet:
|
||||||
return requestToMsgGet(req)
|
return requestToMsgGet(req)
|
||||||
@@ -85,55 +92,60 @@ func RequestToMsg(req *http.Request) (*dns.Msg, error) {
|
|||||||
return requestToMsgPost(req)
|
return requestToMsgPost(req)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("method not allowed: %s", req.Method)
|
return nil, nil, fmt.Errorf("method not allowed: %s", req.Method)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// requestToMsgPost extracts the dns message from the request body.
|
// requestToMsgPost extracts the dns message from the request body.
|
||||||
func requestToMsgPost(req *http.Request) (*dns.Msg, error) {
|
func requestToMsgPost(req *http.Request) (*dns.Msg, []byte, error) {
|
||||||
defer req.Body.Close()
|
defer req.Body.Close()
|
||||||
return toMsg(req.Body)
|
return toMsgWire(req.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
const maxDNSQuerySize = 65536
|
const maxDNSQuerySize = 65536
|
||||||
const maxBase64Len = (maxDNSQuerySize*8 + 5) / 6
|
const maxBase64Len = (maxDNSQuerySize*8 + 5) / 6
|
||||||
|
|
||||||
// requestToMsgGet extract the dns message from the GET request.
|
// requestToMsgGet extract the dns message from the GET request.
|
||||||
func requestToMsgGet(req *http.Request) (*dns.Msg, error) {
|
func requestToMsgGet(req *http.Request) (*dns.Msg, []byte, error) {
|
||||||
values := req.URL.Query()
|
values := req.URL.Query()
|
||||||
b64, ok := values["dns"]
|
b64, ok := values["dns"]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("no 'dns' query parameter found")
|
return nil, nil, fmt.Errorf("no 'dns' query parameter found")
|
||||||
}
|
}
|
||||||
if len(b64) != 1 {
|
if len(b64) != 1 {
|
||||||
return nil, fmt.Errorf("multiple 'dns' query values found")
|
return nil, nil, fmt.Errorf("multiple 'dns' query values found")
|
||||||
}
|
}
|
||||||
if len(b64[0]) > maxBase64Len {
|
if len(b64[0]) > maxBase64Len {
|
||||||
return nil, fmt.Errorf("dns query too large")
|
return nil, nil, fmt.Errorf("dns query too large")
|
||||||
}
|
}
|
||||||
return base64ToMsg(b64[0])
|
return base64ToMsgWire(b64[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
func toMsg(r io.ReadCloser) (*dns.Msg, error) {
|
func toMsg(r io.ReadCloser) (*dns.Msg, error) {
|
||||||
buf, err := io.ReadAll(http.MaxBytesReader(nil, r, maxDNSQuerySize))
|
m, _, err := toMsgWire(r)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
m := new(dns.Msg)
|
|
||||||
err = m.Unpack(buf)
|
|
||||||
return m, err
|
return m, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func base64ToMsg(b64 string) (*dns.Msg, error) {
|
func toMsgWire(r io.ReadCloser) (*dns.Msg, []byte, error) {
|
||||||
|
buf, err := io.ReadAll(http.MaxBytesReader(nil, r, maxDNSQuerySize))
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
m := new(dns.Msg)
|
||||||
|
err = m.Unpack(buf)
|
||||||
|
return m, buf, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func base64ToMsgWire(b64 string) (*dns.Msg, []byte, error) {
|
||||||
buf, err := b64Enc.DecodeString(b64)
|
buf, err := b64Enc.DecodeString(b64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
err = m.Unpack(buf)
|
err = m.Unpack(buf)
|
||||||
|
|
||||||
return m, err
|
return m, buf, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var b64Enc = base64.RawURLEncoding
|
var b64Enc = base64.RawURLEncoding
|
||||||
|
|||||||
Reference in New Issue
Block a user