diff --git a/core/dnsserver/https.go b/core/dnsserver/https.go index 437c561f2..2430f20a7 100644 --- a/core/dnsserver/https.go +++ b/core/dnsserver/https.go @@ -17,6 +17,9 @@ type DoHWriter struct { // request is the HTTP request we're currently handling. request *http.Request + // tsigStatus stores the TSIG verification result for the request. + tsigStatus error + // Msg is a response to be written to the client. Msg *dns.Msg } @@ -58,9 +61,9 @@ func (d *DoHWriter) Close() error { return nil } -// TsigStatus no-op implementation. +// TsigStatus returns the TSIG verification status for this request. func (d *DoHWriter) TsigStatus() error { - return nil + return d.tsigStatus } // TsigTimersOnly no-op implementation. diff --git a/core/dnsserver/server_https.go b/core/dnsserver/server_https.go index 0df47c32b..2f0746b0d 100644 --- a/core/dnsserver/server_https.go +++ b/core/dnsserver/server_https.go @@ -19,6 +19,7 @@ import ( "github.com/coredns/coredns/plugin/pkg/reuseport" "github.com/coredns/coredns/plugin/pkg/transport" + "github.com/miekg/dns" "github.com/pires/go-proxyproto" "golang.org/x/net/netutil" ) @@ -192,7 +193,7 @@ func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - msg, err := doh.RequestToMsg(r) + msg, raw, err := doh.RequestToMsgWire(r) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) s.countResponse(http.StatusBadRequest) @@ -208,6 +209,16 @@ func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) { request: r, } + if tsig := msg.IsTsig(); tsig != nil { + if s.tsigSecret == nil { + dw.tsigStatus = dns.ErrSecret + } else if secret, ok := s.tsigSecret[tsig.Hdr.Name]; !ok { + dw.tsigStatus = dns.ErrSecret + } else { + dw.tsigStatus = dns.TsigVerify(raw, secret, "", false) + } + } + // We just call the normal chain handler - all error handling is done there. // We should expect a packet to be returned that we can send to the client. diff --git a/core/dnsserver/server_https_test.go b/core/dnsserver/server_https_test.go index 21dbaa84b..e36c3d1d1 100644 --- a/core/dnsserver/server_https_test.go +++ b/core/dnsserver/server_https_test.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httptest" "regexp" + "strings" "testing" "time" @@ -22,6 +23,12 @@ var ( validator = func(r *http.Request) bool { return validPath.MatchString(r.URL.Path) } ) +const ( + testTSIGKeyName = "tsig-key." + testTSIGSecret = "MTIzNA==" + testTSIGWrongSecret = "NTY3OA==" +) + func testServerHTTPS(t *testing.T, path string, validator func(*http.Request) bool) *http.Response { t.Helper() c := Config{ @@ -396,3 +403,183 @@ func TestHTTPRequestContextPropagation(t *testing.T) { } }) } + +type tsigStatusPlugin struct{} + +func (p *tsigStatusPlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetReply(r) + m.Authoritative = true + + switch { + case r.IsTsig() == nil: + m.Rcode = dns.RcodeRefused + case w.TsigStatus() != nil: + m.Rcode = dns.RcodeNotAuth + default: + m.Rcode = dns.RcodeSuccess + } + + if err := w.WriteMsg(m); err != nil { + return dns.RcodeServerFailure, err + } + return dns.RcodeSuccess, nil +} + +func (p *tsigStatusPlugin) Name() string { return "tsig_status" } + +func testConfigWithTSIGStatusPlugin() *Config { + c := &Config{ + Zone: "example.com.", + Transport: "https", + TLSConfig: &tls.Config{}, + ListenHosts: []string{"127.0.0.1"}, + Port: "443", + TsigSecret: map[string]string{ + testTSIGKeyName: testTSIGSecret, + }, + } + c.AddPlugin(func(_next plugin.Handler) plugin.Handler { return &tsigStatusPlugin{} }) + return c +} + +func testServerHTTPSMsg(t *testing.T, cfg *Config, req *dns.Msg) *dns.Msg { + t.Helper() + + s, err := NewServerHTTPS("127.0.0.1:443", []*Config{cfg}) + if err != nil { + t.Fatal("could not create HTTPS server:", err) + } + + buf, err := req.Pack() + if err != nil { + t.Fatal(err) + } + + r := httptest.NewRequest(http.MethodPost, "/dns-query", bytes.NewReader(buf)) + r.RemoteAddr = "127.0.0.1:12345" + w := httptest.NewRecorder() + + s.ServeHTTP(w, r) + + res := w.Result() + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + t.Fatalf("unexpected HTTP status: got %d", res.StatusCode) + } + + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + + m := new(dns.Msg) + if err := m.Unpack(body); err != nil { + t.Fatal(err) + } + return m +} + +func testServerHTTPSRaw(t *testing.T, cfg *Config, buf []byte) *dns.Msg { + t.Helper() + + s, err := NewServerHTTPS("127.0.0.1:443", []*Config{cfg}) + if err != nil { + t.Fatal("could not create HTTPS server:", err) + } + + r := httptest.NewRequest(http.MethodPost, "/dns-query", bytes.NewReader(buf)) + r.RemoteAddr = "127.0.0.1:12345" + w := httptest.NewRecorder() + + s.ServeHTTP(w, r) + + res := w.Result() + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + t.Fatalf("unexpected HTTP status: got %d", res.StatusCode) + } + + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + + m := new(dns.Msg) + if err := m.Unpack(body); err != nil { + t.Fatal(err) + } + return m +} + +func forgedTSIGMsg() *dns.Msg { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + + m.Extra = append(m.Extra, &dns.TSIG{ + Hdr: dns.RR_Header{ + Name: "bogus-key.", + Rrtype: dns.TypeTSIG, + Class: dns.ClassANY, + Ttl: 0, + }, + Algorithm: dns.HmacSHA256, + TimeSigned: uint64(time.Now().Unix()), + Fudge: 300, + MACSize: 32, + MAC: strings.Repeat("00", 32), + OrigId: m.Id, + Error: dns.RcodeSuccess, + }) + return m +} + +func mustSignedTSIGQueryBytes(t *testing.T, keyName, secret string) []byte { + t.Helper() + return mustPackSignedTSIGQuery(t, keyName, secret, time.Now().Unix()) +} + +func TestServeHTTPRejectsUnsignedTSIGRequiredRequest(t *testing.T) { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + + resp := testServerHTTPSMsg(t, testConfigWithTSIGStatusPlugin(), m) + if resp.Rcode != dns.RcodeRefused { + t.Fatalf("expected REFUSED for unsigned request, got %s", dns.RcodeToString[resp.Rcode]) + } +} + +func TestServeHTTPRejectsTSIGWithUnknownKey(t *testing.T) { + resp := testServerHTTPSMsg(t, testConfigWithTSIGStatusPlugin(), forgedTSIGMsg()) + + if resp.Rcode != dns.RcodeNotAuth { + t.Fatalf("expected NOTAUTH for unknown TSIG key, got %s", dns.RcodeToString[resp.Rcode]) + } +} + +func TestServeHTTPRejectsTSIGWithBadMAC(t *testing.T) { + buf := mustSignedTSIGQueryBytes(t, testTSIGKeyName, testTSIGWrongSecret) + + resp := testServerHTTPSRaw(t, testConfigWithTSIGStatusPlugin(), buf) + if resp.Rcode != dns.RcodeNotAuth { + t.Fatalf("expected NOTAUTH for bad TSIG MAC, got %s", dns.RcodeToString[resp.Rcode]) + } +} + +func TestServeHTTPAcceptsValidTSIG(t *testing.T) { + buf := mustSignedTSIGQueryBytes(t, testTSIGKeyName, testTSIGSecret) + + resp := testServerHTTPSRaw(t, testConfigWithTSIGStatusPlugin(), buf) + if resp.Rcode != dns.RcodeSuccess { + t.Fatalf("expected NOERROR for valid TSIG, got %s", dns.RcodeToString[resp.Rcode]) + } +} + +func TestDoHWriterTsigStatusReturnsStoredStatus(t *testing.T) { + dw := &DoHWriter{tsigStatus: dns.ErrSecret} + if dw.TsigStatus() != dns.ErrSecret { + t.Fatal("expected TsigStatus to return stored tsigStatus") + } +} diff --git a/plugin/pkg/doh/doh.go b/plugin/pkg/doh/doh.go index f9f4e8df8..94d18e7f4 100644 --- a/plugin/pkg/doh/doh.go +++ b/plugin/pkg/doh/doh.go @@ -77,6 +77,13 @@ func ResponseToMsg(resp *http.Response) (*dns.Msg, error) { // RequestToMsg converts a http.Request to a dns message. func RequestToMsg(req *http.Request) (*dns.Msg, error) { + msg, _, err := RequestToMsgWire(req) + return msg, err +} + +// RequestToMsgWire converts a http.Request to a dns message and returns the +// original DNS wire bytes from the request. +func RequestToMsgWire(req *http.Request) (*dns.Msg, []byte, error) { switch req.Method { case http.MethodGet: return requestToMsgGet(req) @@ -85,55 +92,60 @@ func RequestToMsg(req *http.Request) (*dns.Msg, error) { return requestToMsgPost(req) default: - return nil, fmt.Errorf("method not allowed: %s", req.Method) + return nil, nil, fmt.Errorf("method not allowed: %s", req.Method) } } // requestToMsgPost extracts the dns message from the request body. -func requestToMsgPost(req *http.Request) (*dns.Msg, error) { +func requestToMsgPost(req *http.Request) (*dns.Msg, []byte, error) { defer req.Body.Close() - return toMsg(req.Body) + return toMsgWire(req.Body) } const maxDNSQuerySize = 65536 const maxBase64Len = (maxDNSQuerySize*8 + 5) / 6 // requestToMsgGet extract the dns message from the GET request. -func requestToMsgGet(req *http.Request) (*dns.Msg, error) { +func requestToMsgGet(req *http.Request) (*dns.Msg, []byte, error) { values := req.URL.Query() b64, ok := values["dns"] if !ok { - return nil, fmt.Errorf("no 'dns' query parameter found") + return nil, nil, fmt.Errorf("no 'dns' query parameter found") } if len(b64) != 1 { - return nil, fmt.Errorf("multiple 'dns' query values found") + return nil, nil, fmt.Errorf("multiple 'dns' query values found") } if len(b64[0]) > maxBase64Len { - return nil, fmt.Errorf("dns query too large") + return nil, nil, fmt.Errorf("dns query too large") } - return base64ToMsg(b64[0]) + return base64ToMsgWire(b64[0]) } func toMsg(r io.ReadCloser) (*dns.Msg, error) { - buf, err := io.ReadAll(http.MaxBytesReader(nil, r, maxDNSQuerySize)) - if err != nil { - return nil, err - } - m := new(dns.Msg) - err = m.Unpack(buf) + m, _, err := toMsgWire(r) return m, err } -func base64ToMsg(b64 string) (*dns.Msg, error) { +func toMsgWire(r io.ReadCloser) (*dns.Msg, []byte, error) { + buf, err := io.ReadAll(http.MaxBytesReader(nil, r, maxDNSQuerySize)) + if err != nil { + return nil, nil, err + } + m := new(dns.Msg) + err = m.Unpack(buf) + return m, buf, err +} + +func base64ToMsgWire(b64 string) (*dns.Msg, []byte, error) { buf, err := b64Enc.DecodeString(b64) if err != nil { - return nil, err + return nil, nil, err } m := new(dns.Msg) err = m.Unpack(buf) - return m, err + return m, buf, err } var b64Enc = base64.RawURLEncoding